From 043289426b2d08d925fc1c980b0d2a01e2360e93 Mon Sep 17 00:00:00 2001
From: AlexeyAB <alexeyab84@gmail.com>
Date: Sat, 04 Aug 2018 00:11:10 +0000
Subject: [PATCH] max pool layer is fixed

---
 src/maxpool_layer.c |  161 ++++++++++++++++++++++++++++++++---------------------
 1 files changed, 97 insertions(+), 64 deletions(-)

diff --git a/src/maxpool_layer.c b/src/maxpool_layer.c
index 070eaba..928102f 100644
--- a/src/maxpool_layer.c
+++ b/src/maxpool_layer.c
@@ -1,95 +1,128 @@
 #include "maxpool_layer.h"
+#include "cuda.h"
 #include <stdio.h>
 
-image get_maxpool_image(maxpool_layer layer)
+image get_maxpool_image(maxpool_layer l)
 {
-    int h = (layer.h-1)/layer.stride + 1;
-    int w = (layer.w-1)/layer.stride + 1;
-    int c = layer.c;
-    return float_to_image(h,w,c,layer.output);
+    int h = l.out_h;
+    int w = l.out_w;
+    int c = l.c;
+    return float_to_image(w,h,c,l.output);
 }
 
-image get_maxpool_delta(maxpool_layer layer)
+image get_maxpool_delta(maxpool_layer l)
 {
-    int h = (layer.h-1)/layer.stride + 1;
-    int w = (layer.w-1)/layer.stride + 1;
-    int c = layer.c;
-    return float_to_image(h,w,c,layer.delta);
+    int h = l.out_h;
+    int w = l.out_w;
+    int c = l.c;
+    return float_to_image(w,h,c,l.delta);
 }
 
-maxpool_layer *make_maxpool_layer(int batch, int h, int w, int c, int size, int stride)
+maxpool_layer make_maxpool_layer(int batch, int h, int w, int c, int size, int stride, int padding)
 {
-    fprintf(stderr, "Maxpool Layer: %d x %d x %d image, %d size, %d stride\n", h,w,c,size,stride);
-    maxpool_layer *layer = calloc(1, sizeof(maxpool_layer));
-    layer->batch = batch;
-    layer->h = h;
-    layer->w = w;
-    layer->c = c;
-    layer->size = size;
-    layer->stride = stride;
-    layer->max_indexes = calloc(((h-1)/stride+1) * ((w-1)/stride+1) * c*batch, sizeof(int));
-    layer->output = calloc(((h-1)/stride+1) * ((w-1)/stride+1) * c*batch, sizeof(float));
-    layer->delta = calloc(((h-1)/stride+1) * ((w-1)/stride+1) * c*batch, sizeof(float));
-    return layer;
+    maxpool_layer l = {0};
+    l.type = MAXPOOL;
+    l.batch = batch;
+    l.h = h;
+    l.w = w;
+    l.c = c;
+    l.pad = padding;
+    l.out_w = (w + padding - size) / stride + 1;
+    l.out_h = (h + padding - size) / stride + 1;
+    l.out_c = c;
+    l.outputs = l.out_h * l.out_w * l.out_c;
+    l.inputs = h*w*c;
+    l.size = size;
+    l.stride = stride;
+    int output_size = l.out_h * l.out_w * l.out_c * batch;
+    l.indexes = calloc(output_size, sizeof(int));
+    l.output =  calloc(output_size, sizeof(float));
+    l.delta =   calloc(output_size, sizeof(float));
+    l.forward = forward_maxpool_layer;
+    l.backward = backward_maxpool_layer;
+    #ifdef GPU
+    l.forward_gpu = forward_maxpool_layer_gpu;
+    l.backward_gpu = backward_maxpool_layer_gpu;
+    l.indexes_gpu = cuda_make_int_array(output_size);
+    l.output_gpu  = cuda_make_array(l.output, output_size);
+    l.delta_gpu   = cuda_make_array(l.delta, output_size);
+    #endif
+	l.bflops = (l.size*l.size*l.c * l.out_h*l.out_w) / 1000000000.;
+    fprintf(stderr, "max          %d x %d / %d  %4d x%4d x%4d   ->  %4d x%4d x%4d %5.3f BF\n", size, size, stride, w, h, c, l.out_w, l.out_h, l.out_c, l.bflops);
+    return l;
 }
 
-void resize_maxpool_layer(maxpool_layer *layer, int h, int w, int c)
+void resize_maxpool_layer(maxpool_layer *l, int w, int h)
 {
-    layer->h = h;
-    layer->w = w;
-    layer->c = c;
-    layer->output = realloc(layer->output, ((h-1)/layer->stride+1) * ((w-1)/layer->stride+1) * c * layer->batch* sizeof(float));
-    layer->delta = realloc(layer->delta, ((h-1)/layer->stride+1) * ((w-1)/layer->stride+1) * c * layer->batch*sizeof(float));
+    l->h = h;
+    l->w = w;
+    l->inputs = h*w*l->c;
+
+    l->out_w = (w + l->pad - l->size) / l->stride + 1;
+    l->out_h = (h + l->pad - l->size) / l->stride + 1;
+    l->outputs = l->out_w * l->out_h * l->c;
+    int output_size = l->outputs * l->batch;
+
+    l->indexes = realloc(l->indexes, output_size * sizeof(int));
+    l->output = realloc(l->output, output_size * sizeof(float));
+    l->delta = realloc(l->delta, output_size * sizeof(float));
+
+    #ifdef GPU
+    cuda_free((float *)l->indexes_gpu);
+    cuda_free(l->output_gpu);
+    cuda_free(l->delta_gpu);
+    l->indexes_gpu = cuda_make_int_array(output_size);
+    l->output_gpu  = cuda_make_array(l->output, output_size);
+    l->delta_gpu   = cuda_make_array(l->delta,  output_size);
+    #endif
 }
 
-void forward_maxpool_layer(const maxpool_layer layer, float *input)
+void forward_maxpool_layer(const maxpool_layer l, network_state state)
 {
-    int b;
-    for(b = 0; b < layer.batch; ++b){
-        int h = (layer.h-1)/layer.stride + 1;
-        int w = (layer.w-1)/layer.stride + 1;
-        int c = layer.c;
+    int b,i,j,k,m,n;
+    int w_offset = -l.pad / l.stride;
+    int h_offset = -l.pad / l.stride;
 
-        int i,j,k,l,m;
-        for(k = 0; k < layer.c; ++k){
-            for(i = 0; i < layer.h; i += layer.stride){
-                for(j = 0; j < layer.w; j += layer.stride){
-                    int out_index = j/layer.stride + w*(i/layer.stride + h*(k + c*b));
-                    layer.output[out_index] = -FLT_MAX;
-                    int lower = (-layer.size-1)/2 + 1;
-                    int upper = layer.size/2 + 1;
+    int h = l.out_h;
+    int w = l.out_w;
+    int c = l.c;
 
-                    int lh = (i+lower < 0)       ? 0 : i+lower;
-                    int uh = (i+upper > layer.h) ? layer.h : i+upper;
-
-                    int lw = (j+lower < 0)       ? 0 : j+lower;
-                    int uw = (j+upper > layer.w) ? layer.w : j+upper;
-                    for(l = lh; l < uh; ++l){
-                        for(m = lw; m < uw; ++m){
-                            //printf("%d %d\n", l, m);
-                            int index = m + layer.w*(l + layer.h*(k + b*layer.c));
-                            if(input[index] > layer.output[out_index]){
-                                layer.output[out_index] = input[index];
-                                layer.max_indexes[out_index] = index;
-                            }
+    for(b = 0; b < l.batch; ++b){
+        for(k = 0; k < c; ++k){
+            for(i = 0; i < h; ++i){
+                for(j = 0; j < w; ++j){
+                    int out_index = j + w*(i + h*(k + c*b));
+                    float max = -FLT_MAX;
+                    int max_i = -1;
+                    for(n = 0; n < l.size; ++n){
+                        for(m = 0; m < l.size; ++m){
+                            int cur_h = h_offset + i*l.stride + n;
+                            int cur_w = w_offset + j*l.stride + m;
+                            int index = cur_w + l.w*(cur_h + l.h*(k + b*l.c));
+                            int valid = (cur_h >= 0 && cur_h < l.h &&
+                                         cur_w >= 0 && cur_w < l.w);
+                            float val = (valid != 0) ? state.input[index] : -FLT_MAX;
+                            max_i = (val > max) ? index : max_i;
+                            max   = (val > max) ? val   : max;
                         }
                     }
+                    l.output[out_index] = max;
+                    l.indexes[out_index] = max_i;
                 }
             }
         }
     }
 }
 
-void backward_maxpool_layer(const maxpool_layer layer, float *input, float *delta)
+void backward_maxpool_layer(const maxpool_layer l, network_state state)
 {
     int i;
-    int h = (layer.h-1)/layer.stride + 1;
-    int w = (layer.w-1)/layer.stride + 1;
-    int c = layer.c;
-    memset(delta, 0, layer.batch*layer.h*layer.w*layer.c*sizeof(float));
-    for(i = 0; i < h*w*c*layer.batch; ++i){
-        int index = layer.max_indexes[i];
-        delta[index] += layer.delta[i];
+    int h = l.out_h;
+    int w = l.out_w;
+    int c = l.c;
+    for(i = 0; i < h*w*c*l.batch; ++i){
+        int index = l.indexes[i];
+        state.delta[index] += l.delta[i];
     }
 }
 

--
Gitblit v1.10.0