From d9f1b0b16edeb59281355a855e18a8be343fc33c Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Fri, 08 Aug 2014 19:04:15 +0000
Subject: [PATCH] probably how maxpool layers should be

---
 src/convolutional_layer.c |  302 ++++++++++++++++++++++---------------------------
 1 files changed, 135 insertions(+), 167 deletions(-)

diff --git a/src/convolutional_layer.c b/src/convolutional_layer.c
index 40d5858..6c7f947 100644
--- a/src/convolutional_layer.c
+++ b/src/convolutional_layer.c
@@ -5,12 +5,18 @@
 
 int convolutional_out_height(convolutional_layer layer)
 {
-    return (layer.h-layer.size)/layer.stride + 1;
+    int h = layer.h;
+    if (!layer.pad) h -= layer.size;
+    else h -= 1;
+    return h/layer.stride + 1;
 }
 
 int convolutional_out_width(convolutional_layer layer)
 {
-    return (layer.w-layer.size)/layer.stride + 1;
+    int w = layer.w;
+    if (!layer.pad) w -= layer.size;
+    else w -= 1;
+    return w/layer.stride + 1;
 }
 
 image get_convolutional_image(convolutional_layer layer)
@@ -31,11 +37,16 @@
     return float_to_image(h,w,c,layer.delta);
 }
 
-convolutional_layer *make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, ACTIVATION activation)
+convolutional_layer *make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, int pad, ACTIVATION activation, float learning_rate, float momentum, float decay)
 {
     int i;
     size = 2*(size/2)+1; //HA! And you thought you'd use an even sized filter...
     convolutional_layer *layer = calloc(1, sizeof(convolutional_layer));
+
+    layer->learning_rate = learning_rate;
+    layer->momentum = momentum;
+    layer->decay = decay;
+
     layer->h = h;
     layer->w = w;
     layer->c = c;
@@ -43,6 +54,7 @@
     layer->batch = batch;
     layer->stride = stride;
     layer->size = size;
+    layer->pad = pad;
 
     layer->filters = calloc(c*n*size*size, sizeof(float));
     layer->filter_updates = calloc(c*n*size*size, sizeof(float));
@@ -52,10 +64,11 @@
     layer->bias_updates = calloc(n, sizeof(float));
     layer->bias_momentum = calloc(n, sizeof(float));
     float scale = 1./(size*size*c);
-    for(i = 0; i < c*n*size*size; ++i) layer->filters[i] = scale*(rand_uniform());
+    //scale = .0001;
+    for(i = 0; i < c*n*size*size; ++i) layer->filters[i] = scale*(rand_uniform()-.5);
     for(i = 0; i < n; ++i){
         //layer->biases[i] = rand_normal()*scale + scale;
-        layer->biases[i] = 0;
+        layer->biases[i] = .5;
     }
     int out_h = convolutional_out_height(*layer);
     int out_w = convolutional_out_width(*layer);
@@ -63,10 +76,22 @@
     layer->col_image = calloc(layer->batch*out_h*out_w*size*size*c, sizeof(float));
     layer->output = calloc(layer->batch*out_h * out_w * n, sizeof(float));
     layer->delta  = calloc(layer->batch*out_h * out_w * n, sizeof(float));
+    #ifdef GPU
+    layer->filters_cl = cl_make_array(layer->filters, c*n*size*size);
+    layer->filter_updates_cl = cl_make_array(layer->filter_updates, c*n*size*size);
+    layer->filter_momentum_cl = cl_make_array(layer->filter_momentum, c*n*size*size);
+
+    layer->biases_cl = cl_make_array(layer->biases, n);
+    layer->bias_updates_cl = cl_make_array(layer->bias_updates, n);
+    layer->bias_momentum_cl = cl_make_array(layer->bias_momentum, n);
+
+    layer->col_image_cl = cl_make_array(layer->col_image, layer->batch*out_h*out_w*size*size*c);
+    layer->delta_cl = cl_make_array(layer->delta, layer->batch*out_h*out_w*n);
+    layer->output_cl = cl_make_array(layer->output, layer->batch*out_h*out_w*n);
+    #endif
     layer->activation = activation;
 
     fprintf(stderr, "Convolutional Layer: %d x %d x %d image, %d filters -> %d x %d x %d image\n", h,w,c,n, out_h, out_w, n);
-    srand(0);
 
     return layer;
 }
@@ -87,196 +112,118 @@
                                 layer->batch*out_h * out_w * layer->n*sizeof(float));
 }
 
-void forward_convolutional_layer(const convolutional_layer layer, float *in)
-{
-    int i;
-    int m = layer.n;
-    int k = layer.size*layer.size*layer.c;
-    int n = convolutional_out_height(layer)*
-            convolutional_out_width(layer)*
-            layer.batch;
-
-    memset(layer.output, 0, m*n*sizeof(float));
-
-    float *a = layer.filters;
-    float *b = layer.col_image;
-    float *c = layer.output;
-    for(i = 0; i < layer.batch; ++i){
-        im2col_cpu(in+i*(n/layer.batch),  layer.c,  layer.h,  layer.w,  layer.size,  layer.stride, b+i*(n/layer.batch));
-    }
-    gemm(0,0,m,n,k,1,a,k,b,n,1,c,n);
-
-    for(i = 0; i < m*n; ++i){
-        layer.output[i] = activate(layer.output[i], layer.activation);
-    }
-    //for(i = 0; i < m*n; ++i) if(i%(m*n/10+1)==0) printf("%f, ", layer.output[i]); printf("\n");
-
-}
-
-void gradient_delta_convolutional_layer(convolutional_layer layer)
-{
-    int i;
-    int size = convolutional_out_height(layer)*
-                convolutional_out_width(layer)*
-                layer.n*
-                layer.batch;
-    for(i = 0; i < size; ++i){
-        layer.delta[i] *= gradient(layer.output[i], layer.activation);
-    }
-}
-
-void learn_bias_convolutional_layer(convolutional_layer layer)
+void bias_output(const convolutional_layer layer)
 {
     int i,j,b;
-    int size = convolutional_out_height(layer)
-                *convolutional_out_width(layer);
+    int out_h = convolutional_out_height(layer);
+    int out_w = convolutional_out_width(layer);
     for(b = 0; b < layer.batch; ++b){
         for(i = 0; i < layer.n; ++i){
-            float sum = 0;
-            for(j = 0; j < size; ++j){
-                sum += layer.delta[j+size*(i+b*layer.n)];
+            for(j = 0; j < out_h*out_w; ++j){
+                layer.output[(b*layer.n + i)*out_h*out_w + j] = layer.biases[i];
             }
-            layer.bias_updates[i] += sum/size;
         }
     }
 }
 
-void learn_convolutional_layer(convolutional_layer layer)
+void forward_convolutional_layer(const convolutional_layer layer, float *in)
 {
-    gradient_delta_convolutional_layer(layer);
-    learn_bias_convolutional_layer(layer);
+    int out_h = convolutional_out_height(layer);
+    int out_w = convolutional_out_width(layer);
+    int i;
+
+    bias_output(layer);
+
     int m = layer.n;
-    int n = layer.size*layer.size*layer.c;
-    int k = convolutional_out_height(layer)*
-            convolutional_out_width(layer)*
-            layer.batch;
+    int k = layer.size*layer.size*layer.c;
+    int n = out_h*out_w;
 
-    float *a = layer.delta;
+    float *a = layer.filters;
     float *b = layer.col_image;
-    float *c = layer.filter_updates;
+    float *c = layer.output;
 
-    gemm(0,1,m,n,k,1,a,k,b,k,1,c,n);
+    im2col_cpu(in, layer.batch, layer.c, layer.h, layer.w, 
+        layer.size, layer.stride, layer.pad, b);
+
+    for(i = 0; i < layer.batch; ++i){
+        gemm(0,0,m,n,k,1,a,k,b,n,1,c,n);
+        c += n*m;
+        in += layer.h*layer.w*layer.c;
+        b += k*n;
+    }
+    /*
+    int i;
+    for(i = 0; i < m*n; ++i) printf("%f, ", layer.output[i]);
+    printf("\n");
+    */
+    activate_array(layer.output, m*n*layer.batch, layer.activation);
+}
+
+void learn_bias_convolutional_layer(convolutional_layer layer)
+{
+    int i,b;
+    int size = convolutional_out_height(layer)
+        *convolutional_out_width(layer);
+    for(b = 0; b < layer.batch; ++b){
+        for(i = 0; i < layer.n; ++i){
+            layer.bias_updates[i] += mean_array(layer.delta+size*(i+b*layer.n), size);
+        }
+    }
 }
 
 void backward_convolutional_layer(convolutional_layer layer, float *delta)
 {
     int i;
-    int m = layer.size*layer.size*layer.c;
-    int k = layer.n;
-    int n = convolutional_out_height(layer)*
-            convolutional_out_width(layer)*
-            layer.batch;
+    int m = layer.n;
+    int n = layer.size*layer.size*layer.c;
+    int k = convolutional_out_height(layer)*
+        convolutional_out_width(layer);
+    gradient_array(layer.output, m*k*layer.batch, layer.activation, layer.delta);
+    learn_bias_convolutional_layer(layer);
 
-    float *a = layer.filters;
-    float *b = layer.delta;
-    float *c = layer.col_image;
+    float *a = layer.delta;
+    float *b = layer.col_image;
+    float *c = layer.filter_updates;
 
-
-    memset(c, 0, m*n*sizeof(float));
-    gemm(1,0,m,n,k,1,a,m,b,n,1,c,n);
-
-    memset(delta, 0, layer.batch*layer.h*layer.w*layer.c*sizeof(float));
     for(i = 0; i < layer.batch; ++i){
-        col2im_cpu(c+i*n/layer.batch,  layer.c,  layer.h,  layer.w,  layer.size,  layer.stride, delta+i*n/layer.batch);
+        gemm(0,1,m,n,k,1,a,k,b,k,1,c,n);
+        a += m*k;
+        b += k*n;
+    }
+
+    if(delta){
+        m = layer.size*layer.size*layer.c;
+        k = layer.n;
+        n = convolutional_out_height(layer)*
+            convolutional_out_width(layer);
+
+        a = layer.filters;
+        b = layer.delta;
+        c = layer.col_image;
+
+        memset(delta, 0, layer.batch*layer.h*layer.w*layer.c*sizeof(float));
+
+        for(i = 0; i < layer.batch; ++i){
+            gemm(1,0,m,n,k,1,a,m,b,n,0,c,n);
+            col2im_cpu(c, layer.c,  layer.h,  layer.w,  layer.size,  layer.stride, layer.pad, delta);
+            c += k*n;
+            delta += layer.h*layer.w*layer.c;
+        }
     }
 }
 
-void update_convolutional_layer(convolutional_layer layer, float step, float momentum, float decay)
+void update_convolutional_layer(convolutional_layer layer)
 {
-    int i;
     int size = layer.size*layer.size*layer.c*layer.n;
-    for(i = 0; i < layer.n; ++i){
-        layer.biases[i] += step*layer.bias_updates[i];
-        layer.bias_updates[i] *= momentum;
-    }
-    for(i = 0; i < size; ++i){
-        layer.filters[i] += step*(layer.filter_updates[i] - decay*layer.filters[i]);
-        layer.filter_updates[i] *= momentum;
-    }
-}
-/*
+    axpy_cpu(layer.n, layer.learning_rate, layer.bias_updates, 1, layer.biases, 1);
+    scal_cpu(layer.n,layer.momentum, layer.bias_updates, 1);
 
-void backward_convolutional_layer2(convolutional_layer layer, float *input, float *delta)
-{
-    image in_delta = float_to_image(layer.h, layer.w, layer.c, delta);
-    image out_delta = get_convolutional_delta(layer);
-    int i,j;
-    for(i = 0; i < layer.n; ++i){
-        rotate_image(layer.kernels[i]);
-    }
-
-    zero_image(in_delta);
-    upsample_image(out_delta, layer.stride, layer.upsampled);
-    for(j = 0; j < in_delta.c; ++j){
-        for(i = 0; i < layer.n; ++i){
-            two_d_convolve(layer.upsampled, i, layer.kernels[i], j, 1, in_delta, j, layer.edge);
-        }
-    }
-
-    for(i = 0; i < layer.n; ++i){
-        rotate_image(layer.kernels[i]);
-    }
+    scal_cpu(size, 1.-layer.learning_rate*layer.decay, layer.filters, 1);
+    axpy_cpu(size, layer.learning_rate, layer.filter_updates, 1, layer.filters, 1);
+    scal_cpu(size, layer.momentum, layer.filter_updates, 1);
 }
 
 
-void learn_convolutional_layer(convolutional_layer layer, float *input)
-{
-    int i;
-    image in_image = float_to_image(layer.h, layer.w, layer.c, input);
-    image out_delta = get_convolutional_delta(layer);
-    gradient_delta_convolutional_layer(layer);
-    for(i = 0; i < layer.n; ++i){
-        kernel_update(in_image, layer.kernel_updates[i], layer.stride, i, out_delta, layer.edge);
-        layer.bias_updates[i] += avg_image_layer(out_delta, i);
-    }
-}
-
-void update_convolutional_layer(convolutional_layer layer, float step, float momentum, float decay)
-{
-    int i,j;
-    for(i = 0; i < layer.n; ++i){
-        layer.bias_momentum[i] = step*(layer.bias_updates[i]) 
-                                + momentum*layer.bias_momentum[i];
-        layer.biases[i] += layer.bias_momentum[i];
-        layer.bias_updates[i] = 0;
-        int pixels = layer.kernels[i].h*layer.kernels[i].w*layer.kernels[i].c;
-        for(j = 0; j < pixels; ++j){
-            layer.kernel_momentum[i].data[j] = step*(layer.kernel_updates[i].data[j] - decay*layer.kernels[i].data[j]) 
-                                                + momentum*layer.kernel_momentum[i].data[j];
-            layer.kernels[i].data[j] += layer.kernel_momentum[i].data[j];
-        }
-        zero_image(layer.kernel_updates[i]);
-    }
-}
-*/
-
-void test_convolutional_layer()
-{
-    convolutional_layer l = *make_convolutional_layer(1,4,4,1,1,3,1,LINEAR);
-    float input[] =    {1,2,3,4,
-                        5,6,7,8,
-                        9,10,11,12,
-                        13,14,15,16};
-    float filter[] =   {.5, 0, .3,
-                        0  , 1,  0,
-                        .2 , 0,  1};
-    float delta[] =    {1, 2,
-                        3,  4};
-    float in_delta[] = {.5,1,.3,.6,
-                        5,6,7,8,
-                        9,10,11,12,
-                        13,14,15,16};
-    l.filters = filter;
-    forward_convolutional_layer(l, input);
-    l.delta = delta;
-    learn_convolutional_layer(l);
-    image filter_updates = float_to_image(3,3,1,l.filter_updates);
-    print_image(filter_updates);
-    printf("Delta:\n");
-    backward_convolutional_layer(l, in_delta);
-    pm(4,4,in_delta);
-}
-
 image get_convolutional_filter(convolutional_layer layer, int i)
 {
     int h = layer.size;
@@ -320,12 +267,33 @@
     image *single_filters = weighted_sum_filters(layer, 0);
     show_images(single_filters, layer.n, window);
 
-    image delta = get_convolutional_delta(layer);
+    image delta = get_convolutional_image(layer);
     image dc = collapse_image_layers(delta, 1);
     char buff[256];
-    sprintf(buff, "%s: Delta", window);
+    sprintf(buff, "%s: Output", window);
     //show_image(dc, buff);
+    //save_image(dc, buff);
     free_image(dc);
     return single_filters;
 }
 
+#ifdef GPU
+void forward_convolutional_layer_gpu(convolutional_layer layer, cl_mem in)
+{
+    int m = layer.n;
+    int k = layer.size*layer.size*layer.c;
+    int n = convolutional_out_height(layer)*
+        convolutional_out_width(layer)*
+        layer.batch;
+
+    cl_write_array(layer.filters_cl, layer.filters, m*k);
+    cl_mem a = layer.filters_cl;
+    cl_mem b = layer.col_image_cl;
+    cl_mem c = layer.output_cl;
+    im2col_ongpu(in, layer.batch, layer.c,  layer.h,  layer.w,  layer.size,  layer.stride, b);
+    gemm_ongpu(0,0,m,n,k,1,a,k,b,n,0,c,n);
+    activate_array_ongpu(layer.output_cl, m*n, layer.activation);
+    cl_read_array(layer.output_cl, layer.output, m*n);
+}
+#endif
+

--
Gitblit v1.10.0