From 0e610b056dbcd85affa23f64f9f8da4d197f110a Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Thu, 08 Sep 2016 05:46:10 +0000
Subject: [PATCH] and again

---
 src/convolutional_layer.c |  355 +++++++++++++++++++++++++++++++++++++++++------------------
 1 files changed, 246 insertions(+), 109 deletions(-)

diff --git a/src/convolutional_layer.c b/src/convolutional_layer.c
index 159951d..ad2d8a5 100644
--- a/src/convolutional_layer.c
+++ b/src/convolutional_layer.c
@@ -1,5 +1,6 @@
 #include "convolutional_layer.h"
 #include "utils.h"
+#include "batchnorm_layer.h"
 #include "im2col.h"
 #include "col2im.h"
 #include "blas.h"
@@ -7,20 +8,74 @@
 #include <stdio.h>
 #include <time.h>
 
+#ifdef AI2
+#include "xnor_layer.h"
+#endif
+
+#ifndef AI2
+#define AI2 0
+void forward_xnor_layer(layer l, network_state state);
+#endif
+
+void swap_binary(convolutional_layer *l)
+{
+    float *swap = l->filters;
+    l->filters = l->binary_filters;
+    l->binary_filters = swap;
+
+    #ifdef GPU
+    swap = l->filters_gpu;
+    l->filters_gpu = l->binary_filters_gpu;
+    l->binary_filters_gpu = swap;
+    #endif
+}
+
+void binarize_filters(float *filters, int n, int size, float *binary)
+{
+    int i, f;
+    for(f = 0; f < n; ++f){
+        float mean = 0;
+        for(i = 0; i < size; ++i){
+            mean += fabs(filters[f*size + i]);
+        }
+        mean = mean / size;
+        for(i = 0; i < size; ++i){
+            binary[f*size + i] = (filters[f*size + i] > 0) ? mean : -mean;
+        }
+    }
+}
+
+void binarize_cpu(float *input, int n, float *binary)
+{
+    int i;
+    for(i = 0; i < n; ++i){
+        binary[i] = (input[i] > 0) ? 1 : -1;
+    }
+}
+
+void binarize_input(float *input, int n, int size, float *binary)
+{
+    int i, s;
+    for(s = 0; s < size; ++s){
+        float mean = 0;
+        for(i = 0; i < n; ++i){
+            mean += fabs(input[i*size + s]);
+        }
+        mean = mean / n;
+        for(i = 0; i < n; ++i){
+            binary[i*size + s] = (input[i*size + s] > 0) ? mean : -mean;
+        }
+    }
+}
+
 int convolutional_out_height(convolutional_layer l)
 {
-    int h = l.h;
-    if (!l.pad) h -= l.size;
-    else h -= 1;
-    return h/l.stride + 1;
+    return (l.h + 2*l.pad - l.size) / l.stride + 1;
 }
 
 int convolutional_out_width(convolutional_layer l)
 {
-    int w = l.w;
-    if (!l.pad) w -= l.size;
-    else w -= 1;
-    return w/l.stride + 1;
+    return (l.w + 2*l.pad - l.size) / l.stride + 1;
 }
 
 image get_convolutional_image(convolutional_layer l)
@@ -41,65 +96,82 @@
     return float_to_image(w,h,c,l.delta);
 }
 
-void backward_scale_cpu(float *x_norm, float *delta, int batch, int n, int size, float *scale_updates)
-{
-    int i,b,f;
-    for(f = 0; f < n; ++f){
-        float sum = 0;
-        for(b = 0; b < batch; ++b){
-            for(i = 0; i < size; ++i){
-                int index = i + size*(f + n*b);
-                sum += delta[index] * x_norm[index];
-            }
-        }
-        scale_updates[f] += sum;
+size_t get_workspace_size(layer l){
+#ifdef CUDNN
+    if(gpu_index >= 0){
+        size_t most = 0;
+        size_t s = 0;
+        cudnnGetConvolutionForwardWorkspaceSize(cudnn_handle(),
+                l.srcTensorDesc,
+                l.filterDesc,
+                l.convDesc,
+                l.dstTensorDesc,
+                l.fw_algo,
+                &s);
+        if (s > most) most = s;
+        cudnnGetConvolutionBackwardFilterWorkspaceSize(cudnn_handle(),
+                l.srcTensorDesc,
+                l.ddstTensorDesc,
+                l.convDesc,
+                l.dfilterDesc,
+                l.bf_algo,
+                &s);
+        if (s > most) most = s;
+        cudnnGetConvolutionBackwardDataWorkspaceSize(cudnn_handle(),
+                l.filterDesc,
+                l.ddstTensorDesc,
+                l.convDesc,
+                l.dsrcTensorDesc,
+                l.bd_algo,
+                &s);
+        if (s > most) most = s;
+        return most;
     }
+    #endif
+    return (size_t)l.out_h*l.out_w*l.size*l.size*l.c*sizeof(float);
 }
 
-void mean_delta_cpu(float *delta, float *variance, int batch, int filters, int spatial, float *mean_delta)
+#ifdef GPU
+#ifdef CUDNN
+void cudnn_convolutional_setup(layer *l)
 {
+    cudnnSetTensor4dDescriptor(l->dsrcTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->c, l->h, l->w); 
+    cudnnSetTensor4dDescriptor(l->ddstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->out_c, l->out_h, l->out_w); 
+    cudnnSetFilter4dDescriptor(l->dfilterDesc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size); 
 
-    int i,j,k;
-    for(i = 0; i < filters; ++i){
-        mean_delta[i] = 0;
-        for (j = 0; j < batch; ++j) {
-            for (k = 0; k < spatial; ++k) {
-                int index = j*filters*spatial + i*spatial + k;
-                mean_delta[i] += delta[index];
-            }
-        }
-        mean_delta[i] *= (-1./sqrt(variance[i] + .00001f));
-    }
+    cudnnSetTensor4dDescriptor(l->srcTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->c, l->h, l->w); 
+    cudnnSetTensor4dDescriptor(l->dstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->out_c, l->out_h, l->out_w); 
+    cudnnSetFilter4dDescriptor(l->filterDesc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size); 
+    cudnnSetConvolution2dDescriptor(l->convDesc, l->pad, l->pad, l->stride, l->stride, 1, 1, CUDNN_CROSS_CORRELATION);
+    cudnnGetConvolutionForwardAlgorithm(cudnn_handle(),
+            l->srcTensorDesc,
+            l->filterDesc,
+            l->convDesc,
+            l->dstTensorDesc,
+            CUDNN_CONVOLUTION_FWD_PREFER_FASTEST,
+            0,
+            &l->fw_algo);
+    cudnnGetConvolutionBackwardDataAlgorithm(cudnn_handle(),
+            l->filterDesc,
+            l->ddstTensorDesc,
+            l->convDesc,
+            l->dsrcTensorDesc,
+            CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST,
+            0,
+            &l->bd_algo);
+    cudnnGetConvolutionBackwardFilterAlgorithm(cudnn_handle(),
+            l->srcTensorDesc,
+            l->ddstTensorDesc,
+            l->convDesc,
+            l->dfilterDesc,
+            CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST,
+            0,
+            &l->bf_algo);
 }
-void  variance_delta_cpu(float *x, float *delta, float *mean, float *variance, int batch, int filters, int spatial, float *variance_delta)
-{
+#endif
+#endif
 
-    int i,j,k;
-    for(i = 0; i < filters; ++i){
-        variance_delta[i] = 0;
-        for(j = 0; j < batch; ++j){
-            for(k = 0; k < spatial; ++k){
-                int index = j*filters*spatial + i*spatial + k;
-                variance_delta[i] += delta[index]*(x[index] - mean[i]);
-            }
-        }
-        variance_delta[i] *= -.5 * pow(variance[i] + .00001f, (float)(-3./2.));
-    }
-}
-void normalize_delta_cpu(float *x, float *mean, float *variance, float *mean_delta, float *variance_delta, int batch, int filters, int spatial, float *delta)
-{
-    int f, j, k;
-    for(j = 0; j < batch; ++j){
-        for(f = 0; f < filters; ++f){
-            for(k = 0; k < spatial; ++k){
-                int index = j*filters*spatial + f*spatial + k;
-                delta[index] = delta[index] * 1./(sqrt(variance[f]) + .00001f) + variance_delta[f] * 2. * (x[index] - mean[f]) / (spatial * batch) + mean_delta[f]/(spatial*batch);
-            }
-        }
-    }
-}
-
-convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, int pad, ACTIVATION activation, int batch_normalize, int binary)
+convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, int padding, ACTIVATION activation, int batch_normalize, int binary, int xnor)
 {
     int i;
     convolutional_layer l = {0};
@@ -110,10 +182,11 @@
     l.c = c;
     l.n = n;
     l.binary = binary;
+    l.xnor = xnor;
     l.batch = batch;
     l.stride = stride;
     l.size = size;
-    l.pad = pad;
+    l.pad = padding;
     l.batch_normalize = batch_normalize;
 
     l.filters = calloc(c*n*size*size, sizeof(float));
@@ -133,12 +206,17 @@
     l.outputs = l.out_h * l.out_w * l.out_c;
     l.inputs = l.w * l.h * l.c;
 
-    l.col_image = calloc(out_h*out_w*size*size*c, sizeof(float));
     l.output = calloc(l.batch*out_h * out_w * n, sizeof(float));
     l.delta  = calloc(l.batch*out_h * out_w * n, sizeof(float));
 
     if(binary){
         l.binary_filters = calloc(c*n*size*size, sizeof(float));
+        l.cfilters = calloc(c*n*size*size, sizeof(char));
+        l.scales = calloc(n, sizeof(float));
+    }
+    if(xnor){
+        l.binary_filters = calloc(c*n*size*size, sizeof(float));
+        l.binary_input = calloc(l.inputs*l.batch, sizeof(float));
     }
 
     if(batch_normalize){
@@ -156,37 +234,53 @@
     }
 
 #ifdef GPU
-    l.filters_gpu = cuda_make_array(l.filters, c*n*size*size);
-    l.filter_updates_gpu = cuda_make_array(l.filter_updates, c*n*size*size);
+    if(gpu_index >= 0){
+        l.filters_gpu = cuda_make_array(l.filters, c*n*size*size);
+        l.filter_updates_gpu = cuda_make_array(l.filter_updates, c*n*size*size);
 
-    l.biases_gpu = cuda_make_array(l.biases, n);
-    l.bias_updates_gpu = cuda_make_array(l.bias_updates, n);
+        l.biases_gpu = cuda_make_array(l.biases, n);
+        l.bias_updates_gpu = cuda_make_array(l.bias_updates, n);
 
-    l.scales_gpu = cuda_make_array(l.scales, n);
-    l.scale_updates_gpu = cuda_make_array(l.scale_updates, n);
+        l.scales_gpu = cuda_make_array(l.scales, n);
+        l.scale_updates_gpu = cuda_make_array(l.scale_updates, n);
 
-    l.col_image_gpu = cuda_make_array(l.col_image, out_h*out_w*size*size*c);
-    l.delta_gpu = cuda_make_array(l.delta, l.batch*out_h*out_w*n);
-    l.output_gpu = cuda_make_array(l.output, l.batch*out_h*out_w*n);
+        l.delta_gpu = cuda_make_array(l.delta, l.batch*out_h*out_w*n);
+        l.output_gpu = cuda_make_array(l.output, l.batch*out_h*out_w*n);
 
-    if(binary){
-        l.binary_filters_gpu = cuda_make_array(l.filters, c*n*size*size);
-    }
+        if(binary){
+            l.binary_filters_gpu = cuda_make_array(l.filters, c*n*size*size);
+        }
+        if(xnor){
+            l.binary_filters_gpu = cuda_make_array(l.filters, c*n*size*size);
+            l.binary_input_gpu = cuda_make_array(0, l.inputs*l.batch);
+        }
 
-    if(batch_normalize){
-        l.mean_gpu = cuda_make_array(l.mean, n);
-        l.variance_gpu = cuda_make_array(l.variance, n);
+        if(batch_normalize){
+            l.mean_gpu = cuda_make_array(l.mean, n);
+            l.variance_gpu = cuda_make_array(l.variance, n);
 
-        l.rolling_mean_gpu = cuda_make_array(l.mean, n);
-        l.rolling_variance_gpu = cuda_make_array(l.variance, n);
+            l.rolling_mean_gpu = cuda_make_array(l.mean, n);
+            l.rolling_variance_gpu = cuda_make_array(l.variance, n);
 
-        l.mean_delta_gpu = cuda_make_array(l.mean, n);
-        l.variance_delta_gpu = cuda_make_array(l.variance, n);
+            l.mean_delta_gpu = cuda_make_array(l.mean, n);
+            l.variance_delta_gpu = cuda_make_array(l.variance, n);
 
-        l.x_gpu = cuda_make_array(l.output, l.batch*out_h*out_w*n);
-        l.x_norm_gpu = cuda_make_array(l.output, l.batch*out_h*out_w*n);
+            l.x_gpu = cuda_make_array(l.output, l.batch*out_h*out_w*n);
+            l.x_norm_gpu = cuda_make_array(l.output, l.batch*out_h*out_w*n);
+        }
+#ifdef CUDNN
+        cudnnCreateTensorDescriptor(&l.srcTensorDesc);
+        cudnnCreateTensorDescriptor(&l.dstTensorDesc);
+        cudnnCreateFilterDescriptor(&l.filterDesc);
+        cudnnCreateTensorDescriptor(&l.dsrcTensorDesc);
+        cudnnCreateTensorDescriptor(&l.ddstTensorDesc);
+        cudnnCreateFilterDescriptor(&l.dfilterDesc);
+        cudnnCreateConvolutionDescriptor(&l.convDesc);
+        cudnn_convolutional_setup(&l);
+#endif
     }
 #endif
+    l.workspace_size = get_workspace_size(l);
     l.activation = activation;
 
     fprintf(stderr, "Convolutional Layer: %d x %d x %d image, %d filters -> %d x %d x %d image\n", h,w,c,n, out_h, out_w, n);
@@ -203,12 +297,15 @@
             l.filters[i*l.c*l.size*l.size + j] *= scale;
         }
         l.biases[i] -= l.rolling_mean[i] * scale;
+        l.scales[i] = 1;
+        l.rolling_mean[i] = 0;
+        l.rolling_variance[i] = 1;
     }
 }
 
 void test_convolutional_layer()
 {
-    convolutional_layer l = make_convolutional_layer(1, 5, 5, 3, 2, 5, 2, 1, LEAKY, 1, 0);
+    convolutional_layer l = make_convolutional_layer(1, 5, 5, 3, 2, 5, 2, 1, LEAKY, 1, 0, 0);
     l.batch_normalize = 1;
     float data[] = {1,1,1,1,1,
         1,1,1,1,1,
@@ -243,22 +340,22 @@
     l->outputs = l->out_h * l->out_w * l->out_c;
     l->inputs = l->w * l->h * l->c;
 
-    l->col_image = realloc(l->col_image,
-            out_h*out_w*l->size*l->size*l->c*sizeof(float));
     l->output = realloc(l->output,
             l->batch*out_h * out_w * l->n*sizeof(float));
     l->delta  = realloc(l->delta,
             l->batch*out_h * out_w * l->n*sizeof(float));
 
 #ifdef GPU
-    cuda_free(l->col_image_gpu);
     cuda_free(l->delta_gpu);
     cuda_free(l->output_gpu);
 
-    l->col_image_gpu = cuda_make_array(l->col_image, out_h*out_w*l->size*l->size*l->c);
     l->delta_gpu =     cuda_make_array(l->delta, l->batch*out_h*out_w*l->n);
     l->output_gpu =    cuda_make_array(l->output, l->batch*out_h*out_w*l->n);
+#ifdef CUDNN
+    cudnn_convolutional_setup(l);
 #endif
+#endif
+    l->workspace_size = get_workspace_size(*l);
 }
 
 void add_bias(float *output, float *biases, int batch, int n, int size)
@@ -295,43 +392,83 @@
     }
 }
 
-void forward_convolutional_layer(const convolutional_layer l, network_state state)
+void forward_convolutional_layer(convolutional_layer l, network_state state)
 {
     int out_h = convolutional_out_height(l);
     int out_w = convolutional_out_width(l);
     int i;
 
+
     fill_cpu(l.outputs*l.batch, 0, l.output, 1);
 
+    /*
+       if(l.binary){
+       binarize_filters(l.filters, l.n, l.c*l.size*l.size, l.binary_filters);
+       binarize_filters2(l.filters, l.n, l.c*l.size*l.size, l.cfilters, l.scales);
+       swap_binary(&l);
+       }
+     */
+
+    /*
+       if(l.binary){
+       int m = l.n;
+       int k = l.size*l.size*l.c;
+       int n = out_h*out_w;
+
+       char  *a = l.cfilters;
+       float *b = state.workspace;
+       float *c = l.output;
+
+       for(i = 0; i < l.batch; ++i){
+       im2col_cpu(state.input, l.c, l.h, l.w, 
+       l.size, l.stride, l.pad, b);
+       gemm_bin(m,n,k,1,a,k,b,n,c,n);
+       c += n*m;
+       state.input += l.c*l.h*l.w;
+       }
+       scale_bias(l.output, l.scales, l.batch, l.n, out_h*out_w);
+       add_bias(l.output, l.biases, l.batch, l.n, out_h*out_w);
+       activate_array(l.output, m*n*l.batch, l.activation);
+       return;
+       }
+     */
+
+    if(l.xnor){
+        binarize_filters(l.filters, l.n, l.c*l.size*l.size, l.binary_filters);
+        swap_binary(&l);
+        binarize_cpu(state.input, l.c*l.h*l.w*l.batch, l.binary_input);
+        state.input = l.binary_input;
+    }
+
     int m = l.n;
     int k = l.size*l.size*l.c;
     int n = out_h*out_w;
 
-    float *a = l.filters;
-    float *b = l.col_image;
-    float *c = l.output;
+    if (l.xnor && l.c%32 == 0 && AI2) {
+        forward_xnor_layer(l, state);
+        printf("xnor\n");
+    } else {
 
-    for(i = 0; i < l.batch; ++i){
-        im2col_cpu(state.input, l.c, l.h, l.w, 
-                l.size, l.stride, l.pad, b);
-        gemm(0,0,m,n,k,1,a,k,b,n,1,c,n);
-        c += n*m;
-        state.input += l.c*l.h*l.w;
+        float *a = l.filters;
+        float *b = state.workspace;
+        float *c = l.output;
+
+        for(i = 0; i < l.batch; ++i){
+            im2col_cpu(state.input, l.c, l.h, l.w, 
+                    l.size, l.stride, l.pad, b);
+            gemm(0,0,m,n,k,1,a,k,b,n,1,c,n);
+            c += n*m;
+            state.input += l.c*l.h*l.w;
+        }
     }
 
     if(l.batch_normalize){
-        if(state.train){
-            mean_cpu(l.output, l.batch, l.n, l.out_h*l.out_w, l.mean);   
-            variance_cpu(l.output, l.mean, l.batch, l.n, l.out_h*l.out_w, l.variance);   
-            normalize_cpu(l.output, l.mean, l.variance, l.batch, l.n, l.out_h*l.out_w);   
-        } else {
-            normalize_cpu(l.output, l.rolling_mean, l.rolling_variance, l.batch, l.n, l.out_h*l.out_w);
-        }
-        scale_bias(l.output, l.scales, l.batch, l.n, out_h*out_w);
+        forward_batchnorm_layer(l, state);
     }
     add_bias(l.output, l.biases, l.batch, l.n, out_h*out_w);
 
     activate_array(l.output, m*n*l.batch, l.activation);
+    if(l.binary || l.xnor) swap_binary(&l);
 }
 
 void backward_convolutional_layer(convolutional_layer l, network_state state)
@@ -347,7 +484,7 @@
 
     for(i = 0; i < l.batch; ++i){
         float *a = l.delta + i*m*k;
-        float *b = l.col_image;
+        float *b = state.workspace;
         float *c = l.filter_updates;
 
         float *im = state.input+i*l.c*l.h*l.w;
@@ -359,11 +496,11 @@
         if(state.delta){
             a = l.filters;
             b = l.delta + i*m*k;
-            c = l.col_image;
+            c = state.workspace;
 
             gemm(1,0,n,k,m,1,a,n,b,k,0,c,k);
 
-            col2im_cpu(l.col_image, l.c,  l.h,  l.w,  l.size,  l.stride, l.pad, state.delta+i*l.c*l.h*l.w);
+            col2im_cpu(state.workspace, l.c,  l.h,  l.w,  l.size,  l.stride, l.pad, state.delta+i*l.c*l.h*l.w);
         }
     }
 }

--
Gitblit v1.10.0