From c62b4f35aa2c59d7db0fd177affeed14b1ba4bcb Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Thu, 08 Sep 2016 07:04:39 +0000
Subject: [PATCH] adding coco models

---
 src/maxpool_layer.c |   29 +++++++++++++++--------------
 1 files changed, 15 insertions(+), 14 deletions(-)

diff --git a/src/maxpool_layer.c b/src/maxpool_layer.c
index ef06175..3e0ea15 100644
--- a/src/maxpool_layer.c
+++ b/src/maxpool_layer.c
@@ -18,7 +18,7 @@
     return float_to_image(w,h,c,l.delta);
 }
 
-maxpool_layer make_maxpool_layer(int batch, int h, int w, int c, int size, int stride)
+maxpool_layer make_maxpool_layer(int batch, int h, int w, int c, int size, int stride, int padding)
 {
     fprintf(stderr, "Maxpool Layer: %d x %d x %d image, %d size, %d stride\n", h,w,c,size,stride);
     maxpool_layer l = {0};
@@ -27,8 +27,9 @@
     l.h = h;
     l.w = w;
     l.c = c;
-    l.out_w = (w-1)/stride + 1;
-    l.out_h = (h-1)/stride + 1;
+    l.pad = padding;
+    l.out_w = (w + 2*padding - size + 1)/stride + 1;
+    l.out_h = (h + 2*padding - size + 1)/stride + 1;
     l.out_c = c;
     l.outputs = l.out_h * l.out_w * l.out_c;
     l.inputs = h*w*c;
@@ -48,12 +49,12 @@
 
 void resize_maxpool_layer(maxpool_layer *l, int w, int h)
 {
-    int stride = l->stride;
     l->h = h;
     l->w = w;
+    l->inputs = h*w*l->c;
 
-    l->out_w = (w-1)/stride + 1;
-    l->out_h = (h-1)/stride + 1;
+    l->out_w = (w + 2*l->pad - l->size + 1)/l->stride + 1;
+    l->out_h = (h + 2*l->pad - l->size + 1)/l->stride + 1;
     l->outputs = l->out_w * l->out_h * l->c;
     int output_size = l->outputs * l->batch;
 
@@ -66,19 +67,19 @@
     cuda_free(l->output_gpu);
     cuda_free(l->delta_gpu);
     l->indexes_gpu = cuda_make_int_array(output_size);
-    l->output_gpu  = cuda_make_array(0, output_size);
-    l->delta_gpu   = cuda_make_array(0, output_size);
+    l->output_gpu  = cuda_make_array(l->output, output_size);
+    l->delta_gpu   = cuda_make_array(l->delta,  output_size);
     #endif
 }
 
 void forward_maxpool_layer(const maxpool_layer l, network_state state)
 {
     int b,i,j,k,m,n;
-    int w_offset = (-l.size-1)/2 + 1;
-    int h_offset = (-l.size-1)/2 + 1;
+    int w_offset = -l.pad;
+    int h_offset = -l.pad;
 
-    int h = (l.h-1)/l.stride + 1;
-    int w = (l.w-1)/l.stride + 1;
+    int h = l.out_h;
+    int w = l.out_w;
     int c = l.c;
 
     for(b = 0; b < l.batch; ++b){
@@ -111,8 +112,8 @@
 void backward_maxpool_layer(const maxpool_layer l, network_state state)
 {
     int i;
-    int h = (l.h-1)/l.stride + 1;
-    int w = (l.w-1)/l.stride + 1;
+    int h = l.out_h;
+    int w = l.out_w;
     int c = l.c;
     for(i = 0; i < h*w*c*l.batch; ++i){
         int index = l.indexes[i];

--
Gitblit v1.10.0