From c6ecf1e0420737eafeb99b27b1d716b46a6cbb7a Mon Sep 17 00:00:00 2001
From: Jud White <github@judsonwhite.com>
Date: Sun, 25 Mar 2018 20:41:48 +0000
Subject: [PATCH] README.md: add notes to How to compile on Windows

---
 src/maxpool_layer_kernels.cu |   38 +++++++++++++++++++++-----------------
 1 files changed, 21 insertions(+), 17 deletions(-)

diff --git a/src/maxpool_layer_kernels.cu b/src/maxpool_layer_kernels.cu
index a5c8209..d40d3c0 100644
--- a/src/maxpool_layer_kernels.cu
+++ b/src/maxpool_layer_kernels.cu
@@ -1,12 +1,16 @@
+#include "cuda_runtime.h"
+#include "curand.h"
+#include "cublas_v2.h"
+
 extern "C" {
 #include "maxpool_layer.h"
 #include "cuda.h"
 }
 
-__global__ void forward_maxpool_layer_kernel(int n, int in_h, int in_w, int in_c, int stride, int size, float *input, float *output, int *indexes)
+__global__ void forward_maxpool_layer_kernel(int n, int in_h, int in_w, int in_c, int stride, int size, int pad, float *input, float *output, int *indexes)
 {
-    int h = (in_h-1)/stride + 1;
-    int w = (in_w-1)/stride + 1;
+    int h = (in_h + 2*pad)/stride;
+    int w = (in_w + 2*pad)/stride;
     int c = in_c;
 
     int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
@@ -20,8 +24,8 @@
     id /= c;
     int b = id;
 
-    int w_offset = (-size-1)/2 + 1;
-    int h_offset = (-size-1)/2 + 1;
+    int w_offset = -pad;
+    int h_offset = -pad;
 
     int out_index = j + w*(i + h*(k + c*b));
     float max = -INFINITY;
@@ -43,10 +47,10 @@
     indexes[out_index] = max_i;
 }
 
-__global__ void backward_maxpool_layer_kernel(int n, int in_h, int in_w, int in_c, int stride, int size, float *delta, float *prev_delta, int *indexes)
+__global__ void backward_maxpool_layer_kernel(int n, int in_h, int in_w, int in_c, int stride, int size, int pad, float *delta, float *prev_delta, int *indexes)
 {
-    int h = (in_h-1)/stride + 1;
-    int w = (in_w-1)/stride + 1;
+    int h = (in_h + 2*pad)/stride;
+    int w = (in_w + 2*pad)/stride;
     int c = in_c;
     int area = (size-1)/stride;
 
@@ -62,8 +66,8 @@
     id /= in_c;
     int b = id;
 
-    int w_offset = (-size-1)/2 + 1;
-    int h_offset = (-size-1)/2 + 1;
+    int w_offset = -pad;
+    int h_offset = -pad;
 
     float d = 0;
     int l, m;
@@ -77,26 +81,26 @@
             d += (valid && indexes[out_index] == index) ? delta[out_index] : 0;
         }
     }
-    prev_delta[index] = d;
+    prev_delta[index] += d;
 }
 
-extern "C" void forward_maxpool_layer_gpu(maxpool_layer layer, float *input)
+extern "C" void forward_maxpool_layer_gpu(maxpool_layer layer, network_state state)
 {
-    int h = (layer.h-1)/layer.stride + 1;
-    int w = (layer.w-1)/layer.stride + 1;
+    int h = layer.out_h;
+    int w = layer.out_w;
     int c = layer.c;
 
     size_t n = h*w*c*layer.batch;
 
-    forward_maxpool_layer_kernel<<<cuda_gridsize(n), BLOCK>>>(n, layer.h, layer.w, layer.c, layer.stride, layer.size, input, layer.output_gpu, layer.indexes_gpu);
+    forward_maxpool_layer_kernel<<<cuda_gridsize(n), BLOCK, 0, get_cuda_stream()>>>(n, layer.h, layer.w, layer.c, layer.stride, layer.size, layer.pad, state.input, layer.output_gpu, layer.indexes_gpu);
     check_error(cudaPeekAtLastError());
 }
 
-extern "C" void backward_maxpool_layer_gpu(maxpool_layer layer, float * delta)
+extern "C" void backward_maxpool_layer_gpu(maxpool_layer layer, network_state state)
 {
     size_t n = layer.h*layer.w*layer.c*layer.batch;
 
-    backward_maxpool_layer_kernel<<<cuda_gridsize(n), BLOCK>>>(n, layer.h, layer.w, layer.c, layer.stride, layer.size, layer.delta_gpu, delta, layer.indexes_gpu);
+    backward_maxpool_layer_kernel<<<cuda_gridsize(n), BLOCK>>>(n, layer.h, layer.w, layer.c, layer.stride, layer.size, layer.pad, layer.delta_gpu, state.delta, layer.indexes_gpu);
     check_error(cudaPeekAtLastError());
 }
 

--
Gitblit v1.10.0