From e92f7d301c971b4d27aa3dcd1e4047e94f04b3fc Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Wed, 25 Mar 2015 01:27:12 +0000
Subject: [PATCH] smaller gridsize in bias

---
 src/convolutional_kernels.cu |    6 +++---
 1 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/convolutional_kernels.cu b/src/convolutional_kernels.cu
index a9a6837..5b49091 100644
--- a/src/convolutional_kernels.cu
+++ b/src/convolutional_kernels.cu
@@ -11,16 +11,16 @@
 __global__ void bias_output_kernel(float *output, float *biases, int n, int size)
 {
     int offset = blockIdx.x * blockDim.x + threadIdx.x;
-    int filter = blockIdx.y;
-    int batch = blockIdx.z;
+    int filter = blockIdx.y % n;
+    int batch = blockIdx.y / n;
 
     if(offset < size) output[(batch*n+filter)*size + offset] = biases[filter];
 }
 
 void bias_output_gpu(float *output, float *biases, int batch, int n, int size)
 {
+    dim3 dimGrid((size-1)/BLOCK + 1, n*batch, 1);
     dim3 dimBlock(BLOCK, 1, 1);
-    dim3 dimGrid((size-1)/BLOCK + 1, n, batch);
 
     bias_output_kernel<<<dimGrid, dimBlock>>>(output, biases, n, size);
     check_error(cudaPeekAtLastError());

--
Gitblit v1.10.0