Joseph Redmon
2016-11-19 62235e9aa3d0c15d87d49bf340625d075cba3e65
src/maxpool_layer_kernels.cu
@@ -9,8 +9,8 @@
__global__ void forward_maxpool_layer_kernel(int n, int in_h, int in_w, int in_c, int stride, int size, int pad, float *input, float *output, int *indexes)
{
    int h = (in_h + 2*pad - size + 1)/stride + 1;
    int w = (in_w + 2*pad - size + 1)/stride + 1;
    int h = (in_h + 2*pad)/stride;
    int w = (in_w + 2*pad)/stride;
    int c = in_c;
    int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
@@ -49,8 +49,8 @@
__global__ void backward_maxpool_layer_kernel(int n, int in_h, int in_w, int in_c, int stride, int size, int pad, float *delta, float *prev_delta, int *indexes)
{
    int h = (in_h + 2*pad - size + 1)/stride + 1;
    int w = (in_w + 2*pad - size + 1)/stride + 1;
    int h = (in_h + 2*pad)/stride;
    int w = (in_w + 2*pad)/stride;
    int c = in_c;
    int area = (size-1)/stride;