Joseph Redmon
2015-03-25 e92f7d301c971b4d27aa3dcd1e4047e94f04b3fc
smaller gridsize in bias
1 files modified
6 ■■■■ changed files
src/convolutional_kernels.cu 6 ●●●● patch | view | raw | blame | history
src/convolutional_kernels.cu
@@ -11,16 +11,16 @@
__global__ void bias_output_kernel(float *output, float *biases, int n, int size)
{
    int offset = blockIdx.x * blockDim.x + threadIdx.x;
    int filter = blockIdx.y;
    int batch = blockIdx.z;
    int filter = blockIdx.y % n;
    int batch = blockIdx.y / n;
    if(offset < size) output[(batch*n+filter)*size + offset] = biases[filter];
}
void bias_output_gpu(float *output, float *biases, int batch, int n, int size)
{
    dim3 dimGrid((size-1)/BLOCK + 1, n*batch, 1);
    dim3 dimBlock(BLOCK, 1, 1);
    dim3 dimGrid((size-1)/BLOCK + 1, n, batch);
    bias_output_kernel<<<dimGrid, dimBlock>>>(output, biases, n, size);
    check_error(cudaPeekAtLastError());