From 48586c8d4db5c00d3d4b9dabcc9a5d2294c5b15d Mon Sep 17 00:00:00 2001
From: AlexeyAB <alexeyab84@gmail.com>
Date: Sun, 04 Mar 2018 15:15:56 +0000
Subject: [PATCH] Compile fix

---
 src/blas_kernels.cu |   32 ++++++++++++++++++--------------
 1 files changed, 18 insertions(+), 14 deletions(-)

diff --git a/src/blas_kernels.cu b/src/blas_kernels.cu
index d940176..8e1cf19 100644
--- a/src/blas_kernels.cu
+++ b/src/blas_kernels.cu
@@ -23,7 +23,7 @@
     dim3 dimGrid((size-1)/BLOCK + 1, n, batch);
     dim3 dimBlock(BLOCK, 1, 1);
 
-    scale_bias_kernel<<<dimGrid, dimBlock>>>(output, biases, n, size);
+    scale_bias_kernel<<<dimGrid, dimBlock, 0, get_cuda_stream()>>>(output, biases, n, size);
     check_error(cudaPeekAtLastError());
 }
 
@@ -67,7 +67,7 @@
     dim3 dimGrid((size-1)/BLOCK + 1, n, batch);
     dim3 dimBlock(BLOCK, 1, 1);
 
-    add_bias_kernel<<<dimGrid, dimBlock>>>(output, biases, n, size);
+    add_bias_kernel<<<dimGrid, dimBlock, 0, get_cuda_stream()>>>(output, biases, n, size);
     check_error(cudaPeekAtLastError());
 }
 
@@ -223,6 +223,7 @@
             local[id] += (i+id < spatial) ? delta[index] : 0;
         }
     }
+	__syncthreads();
 
     if(id == 0){
         mean_delta[filter] = 0;
@@ -251,6 +252,7 @@
             local[id] += (i+id < spatial) ? delta[index]*(x[index] - mean[filter]) : 0;
         }
     }
+	__syncthreads();
 
     if(id == 0){
         variance_delta[filter] = 0;
@@ -425,7 +427,7 @@
 extern "C" void normalize_gpu(float *x, float *mean, float *variance, int batch, int filters, int spatial)
 {
     size_t N = batch*filters*spatial;
-    normalize_kernel<<<cuda_gridsize(N), BLOCK>>>(N, x, mean, variance, batch, filters, spatial);
+    normalize_kernel<<<cuda_gridsize(N), BLOCK, 0, get_cuda_stream()>>>(N, x, mean, variance, batch, filters, spatial);
     check_error(cudaPeekAtLastError());
 }
 
@@ -446,6 +448,7 @@
             local[id] += (i+id < spatial) ? x[index] : 0;
         }
     }
+	__syncthreads();
 
     if(id == 0){
         mean[filter] = 0;
@@ -474,6 +477,7 @@
             local[id] += (i+id < spatial) ? pow((x[index] - mean[filter]), 2) : 0;
         }
     }
+	__syncthreads();
 
     if(id == 0){
         variance[filter] = 0;
@@ -486,13 +490,13 @@
 
 extern "C" void fast_mean_gpu(float *x, int batch, int filters, int spatial, float *mean)
 {
-    fast_mean_kernel<<<filters, BLOCK>>>(x, batch, filters, spatial, mean);
+    fast_mean_kernel<<<filters, BLOCK, 0, get_cuda_stream()>>>(x, batch, filters, spatial, mean);
     check_error(cudaPeekAtLastError());
 }
 
 extern "C" void fast_variance_gpu(float *x, float *mean, int batch, int filters, int spatial, float *variance)
 {
-    fast_variance_kernel<<<filters, BLOCK>>>(x, mean, batch, filters, spatial, variance);
+    fast_variance_kernel<<<filters, BLOCK, 0, get_cuda_stream() >>>(x, mean, batch, filters, spatial, variance);
     check_error(cudaPeekAtLastError());
 }
 
@@ -516,13 +520,13 @@
 
 extern "C" void pow_ongpu(int N, float ALPHA, float * X, int INCX, float * Y, int INCY)
 {
-    pow_kernel<<<cuda_gridsize(N), BLOCK>>>(N, ALPHA, X, INCX, Y, INCY);
+    pow_kernel<<<cuda_gridsize(N), BLOCK, 0, get_cuda_stream() >>>(N, ALPHA, X, INCX, Y, INCY);
     check_error(cudaPeekAtLastError());
 }
 
 extern "C" void axpy_ongpu_offset(int N, float ALPHA, float * X, int OFFX, int INCX, float * Y, int OFFY, int INCY)
 {
-    axpy_kernel<<<cuda_gridsize(N), BLOCK>>>(N, ALPHA, X, OFFX, INCX, Y, OFFY, INCY);
+    axpy_kernel<<<cuda_gridsize(N), BLOCK, 0, get_cuda_stream()>>>(N, ALPHA, X, OFFX, INCX, Y, OFFY, INCY);
     check_error(cudaPeekAtLastError());
 }
 
@@ -539,7 +543,7 @@
 
 extern "C" void copy_ongpu_offset(int N, float * X, int OFFX, int INCX, float * Y, int OFFY, int INCY)
 {
-    copy_kernel<<<cuda_gridsize(N), BLOCK>>>(N, X, OFFX, INCX, Y, OFFY, INCY);
+    copy_kernel<<<cuda_gridsize(N), BLOCK, 0, get_cuda_stream()>>>(N, X, OFFX, INCX, Y, OFFY, INCY);
     check_error(cudaPeekAtLastError());
 }
 
@@ -563,20 +567,20 @@
 extern "C" void flatten_ongpu(float *x, int spatial, int layers, int batch, int forward, float *out)
 {
     int size = spatial*batch*layers;
-    flatten_kernel<<<cuda_gridsize(size), BLOCK>>>(size, x, spatial, layers, batch, forward, out);
+    flatten_kernel<<<cuda_gridsize(size), BLOCK, 0, get_cuda_stream()>>>(size, x, spatial, layers, batch, forward, out);
     check_error(cudaPeekAtLastError());
 }
 
 extern "C" void reorg_ongpu(float *x, int w, int h, int c, int batch, int stride, int forward, float *out)
 {
     int size = w*h*c*batch;
-    reorg_kernel<<<cuda_gridsize(size), BLOCK>>>(size, x, w, h, c, batch, stride, forward, out);
+    reorg_kernel<<<cuda_gridsize(size), BLOCK, 0, get_cuda_stream()>>>(size, x, w, h, c, batch, stride, forward, out);
     check_error(cudaPeekAtLastError());
 }
 
 extern "C" void mask_ongpu(int N, float * X, float mask_num, float * mask)
 {
-    mask_kernel<<<cuda_gridsize(N), BLOCK>>>(N, X, mask_num, mask);
+    mask_kernel<<<cuda_gridsize(N), BLOCK, 0, get_cuda_stream() >>>(N, X, mask_num, mask);
     check_error(cudaPeekAtLastError());
 }
 
@@ -595,7 +599,7 @@
 
 extern "C" void scal_ongpu(int N, float ALPHA, float * X, int INCX)
 {
-    scal_kernel<<<cuda_gridsize(N), BLOCK>>>(N, ALPHA, X, INCX);
+    scal_kernel<<<cuda_gridsize(N), BLOCK, 0, get_cuda_stream()>>>(N, ALPHA, X, INCX);
     check_error(cudaPeekAtLastError());
 }
 
@@ -607,7 +611,7 @@
 
 extern "C" void fill_ongpu(int N, float ALPHA, float * X, int INCX)
 {
-    fill_kernel<<<cuda_gridsize(N), BLOCK>>>(N, ALPHA, X, INCX);
+    fill_kernel<<<cuda_gridsize(N), BLOCK, 0, get_cuda_stream()>>>(N, ALPHA, X, INCX);
     check_error(cudaPeekAtLastError());
 }
 
@@ -762,6 +766,6 @@
 {
     int inputs = n;
     int batch = groups;
-    softmax_kernel<<<cuda_gridsize(batch), BLOCK>>>(inputs, offset, batch, input, temp, output);
+    softmax_kernel<<<cuda_gridsize(batch), BLOCK, 0, get_cuda_stream()>>>(inputs, offset, batch, input, temp, output);
     check_error(cudaPeekAtLastError());
 }

--
Gitblit v1.10.0