From 729ce43e6ec45cfdb58e06e227428a0f81c5de0f Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Fri, 10 Jun 2016 00:20:31 +0000
Subject: [PATCH] stuff

---
 src/convolutional_layer.c |  279 ++++++++++++++++++++++++++++++++++++++++++++++++++-----
 1 files changed, 250 insertions(+), 29 deletions(-)

diff --git a/src/convolutional_layer.c b/src/convolutional_layer.c
index 871a84e..af867e5 100644
--- a/src/convolutional_layer.c
+++ b/src/convolutional_layer.c
@@ -1,5 +1,6 @@
 #include "convolutional_layer.h"
 #include "utils.h"
+#include "batchnorm_layer.h"
 #include "im2col.h"
 #include "col2im.h"
 #include "blas.h"
@@ -7,6 +8,57 @@
 #include <stdio.h>
 #include <time.h>
 
+#ifdef AI2
+#include "xnor_layer.h"
+#endif
+
+#ifndef AI2
+#define AI2 0
+#endif
+
+void swap_binary(convolutional_layer *l)
+{
+    float *swap = l->filters;
+    l->filters = l->binary_filters;
+    l->binary_filters = swap;
+
+    #ifdef GPU
+    swap = l->filters_gpu;
+    l->filters_gpu = l->binary_filters_gpu;
+    l->binary_filters_gpu = swap;
+    #endif
+}
+
+void binarize_filters(float *filters, int n, int size, float *binary)
+{
+    int i, f;
+    for(f = 0; f < n; ++f){
+        float mean = 0;
+        for(i = 0; i < size; ++i){
+            mean += fabs(filters[f*size + i]);
+        }
+        mean = mean / size;
+        for(i = 0; i < size; ++i){
+            binary[f*size + i] = (filters[f*size + i] > 0) ? mean : -mean;
+        }
+    }
+}
+
+void binarize_input(float *input, int n, int size, float *binary)
+{
+    int i, s;
+    for(s = 0; s < size; ++s){
+        float mean = 0;
+        for(i = 0; i < n; ++i){
+            mean += fabs(input[i*size + s]);
+        }
+        mean = mean / n;
+        for(i = 0; i < n; ++i){
+            binary[i*size + s] = (input[i*size + s] > 0) ? mean : -mean;
+        }
+    }
+}
+
 int convolutional_out_height(convolutional_layer l)
 {
     int h = l.h;
@@ -41,7 +93,41 @@
     return float_to_image(w,h,c,l.delta);
 }
 
-convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, int pad, ACTIVATION activation, int batch_normalize)
+size_t get_workspace_size(layer l){
+#ifdef CUDNN
+    size_t most = 0;
+    size_t s = 0;
+    cudnnGetConvolutionForwardWorkspaceSize(cudnn_handle(),
+            l.srcTensorDesc,
+            l.filterDesc,
+            l.convDesc,
+            l.dstTensorDesc,
+            l.fw_algo,
+            &s);
+    if (s > most) most = s;
+    cudnnGetConvolutionBackwardFilterWorkspaceSize(cudnn_handle(),
+            l.srcTensorDesc,
+            l.ddstTensorDesc,
+            l.convDesc,
+            l.dfilterDesc,
+            l.bf_algo,
+            &s);
+    if (s > most) most = s;
+    cudnnGetConvolutionBackwardDataWorkspaceSize(cudnn_handle(),
+            l.filterDesc,
+            l.ddstTensorDesc,
+            l.convDesc,
+            l.dsrcTensorDesc,
+            l.bd_algo,
+            &s);
+    if (s > most) most = s;
+    return most;
+#else
+    return (size_t)l.out_h*l.out_w*l.size*l.size*l.c*sizeof(float);
+#endif
+}
+
+convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, int pad, ACTIVATION activation, int batch_normalize, int binary, int xnor)
 {
     int i;
     convolutional_layer l = {0};
@@ -51,6 +137,8 @@
     l.w = w;
     l.c = c;
     l.n = n;
+    l.binary = binary;
+    l.xnor = xnor;
     l.batch = batch;
     l.stride = stride;
     l.size = size;
@@ -74,10 +162,19 @@
     l.outputs = l.out_h * l.out_w * l.out_c;
     l.inputs = l.w * l.h * l.c;
 
-    l.col_image = calloc(out_h*out_w*size*size*c, sizeof(float));
     l.output = calloc(l.batch*out_h * out_w * n, sizeof(float));
     l.delta  = calloc(l.batch*out_h * out_w * n, sizeof(float));
 
+    if(binary){
+        l.binary_filters = calloc(c*n*size*size, sizeof(float));
+        l.cfilters = calloc(c*n*size*size, sizeof(char));
+        l.scales = calloc(n, sizeof(float));
+    }
+    if(xnor){
+        l.binary_filters = calloc(c*n*size*size, sizeof(float));
+        l.binary_input = calloc(l.inputs*l.batch, sizeof(float));
+    }
+
     if(batch_normalize){
         l.scales = calloc(n, sizeof(float));
         l.scale_updates = calloc(n, sizeof(float));
@@ -102,10 +199,17 @@
     l.scales_gpu = cuda_make_array(l.scales, n);
     l.scale_updates_gpu = cuda_make_array(l.scale_updates, n);
 
-    l.col_image_gpu = cuda_make_array(l.col_image, out_h*out_w*size*size*c);
     l.delta_gpu = cuda_make_array(l.delta, l.batch*out_h*out_w*n);
     l.output_gpu = cuda_make_array(l.output, l.batch*out_h*out_w*n);
 
+    if(binary){
+        l.binary_filters_gpu = cuda_make_array(l.filters, c*n*size*size);
+    }
+    if(xnor){
+        l.binary_filters_gpu = cuda_make_array(l.filters, c*n*size*size);
+        l.binary_input_gpu = cuda_make_array(0, l.inputs*l.batch);
+    }
+
     if(batch_normalize){
         l.mean_gpu = cuda_make_array(l.mean, n);
         l.variance_gpu = cuda_make_array(l.variance, n);
@@ -119,7 +223,50 @@
         l.x_gpu = cuda_make_array(l.output, l.batch*out_h*out_w*n);
         l.x_norm_gpu = cuda_make_array(l.output, l.batch*out_h*out_w*n);
     }
+#ifdef CUDNN
+    cudnnCreateTensorDescriptor(&l.srcTensorDesc);
+    cudnnCreateTensorDescriptor(&l.dstTensorDesc);
+    cudnnCreateFilterDescriptor(&l.filterDesc);
+    cudnnCreateTensorDescriptor(&l.dsrcTensorDesc);
+    cudnnCreateTensorDescriptor(&l.ddstTensorDesc);
+    cudnnCreateFilterDescriptor(&l.dfilterDesc);
+    cudnnCreateConvolutionDescriptor(&l.convDesc);
+    cudnnSetTensor4dDescriptor(l.dsrcTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l.batch, l.c, l.h, l.w); 
+    cudnnSetTensor4dDescriptor(l.ddstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l.batch, l.out_c, l.out_h, l.out_w); 
+    cudnnSetFilter4dDescriptor(l.dfilterDesc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, l.n, l.c, l.size, l.size); 
+
+    cudnnSetTensor4dDescriptor(l.srcTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l.batch, l.c, l.h, l.w); 
+    cudnnSetTensor4dDescriptor(l.dstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l.batch, l.out_c, l.out_h, l.out_w); 
+    cudnnSetFilter4dDescriptor(l.filterDesc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, l.n, l.c, l.size, l.size); 
+    int padding = l.pad ? l.size/2 : 0;
+    cudnnSetConvolution2dDescriptor(l.convDesc, padding, padding, l.stride, l.stride, 1, 1, CUDNN_CROSS_CORRELATION);
+    cudnnGetConvolutionForwardAlgorithm(cudnn_handle(),
+            l.srcTensorDesc,
+            l.filterDesc,
+            l.convDesc,
+            l.dstTensorDesc,
+            CUDNN_CONVOLUTION_FWD_PREFER_FASTEST,
+            0,
+            &l.fw_algo);
+    cudnnGetConvolutionBackwardDataAlgorithm(cudnn_handle(),
+            l.filterDesc,
+            l.ddstTensorDesc,
+            l.convDesc,
+            l.dsrcTensorDesc,
+            CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST,
+            0,
+            &l.bd_algo);
+    cudnnGetConvolutionBackwardFilterAlgorithm(cudnn_handle(),
+            l.srcTensorDesc,
+            l.ddstTensorDesc,
+            l.convDesc,
+            l.dfilterDesc,
+            CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST,
+            0,
+            &l.bf_algo);
 #endif
+#endif
+    l.workspace_size = get_workspace_size(l);
     l.activation = activation;
 
     fprintf(stderr, "Convolutional Layer: %d x %d x %d image, %d filters -> %d x %d x %d image\n", h,w,c,n, out_h, out_w, n);
@@ -141,7 +288,7 @@
 
 void test_convolutional_layer()
 {
-    convolutional_layer l = make_convolutional_layer(1, 5, 5, 3, 2, 5, 2, 1, LEAKY, 1);
+    convolutional_layer l = make_convolutional_layer(1, 5, 5, 3, 2, 5, 2, 1, LEAKY, 1, 0, 0);
     l.batch_normalize = 1;
     float data[] = {1,1,1,1,1,
         1,1,1,1,1,
@@ -176,22 +323,54 @@
     l->outputs = l->out_h * l->out_w * l->out_c;
     l->inputs = l->w * l->h * l->c;
 
-    l->col_image = realloc(l->col_image,
-            out_h*out_w*l->size*l->size*l->c*sizeof(float));
     l->output = realloc(l->output,
             l->batch*out_h * out_w * l->n*sizeof(float));
     l->delta  = realloc(l->delta,
             l->batch*out_h * out_w * l->n*sizeof(float));
 
 #ifdef GPU
-    cuda_free(l->col_image_gpu);
     cuda_free(l->delta_gpu);
     cuda_free(l->output_gpu);
 
-    l->col_image_gpu = cuda_make_array(l->col_image, out_h*out_w*l->size*l->size*l->c);
     l->delta_gpu =     cuda_make_array(l->delta, l->batch*out_h*out_w*l->n);
     l->output_gpu =    cuda_make_array(l->output, l->batch*out_h*out_w*l->n);
+#ifdef CUDNN
+    cudnnSetTensor4dDescriptor(l->dsrcTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->c, l->h, l->w); 
+    cudnnSetTensor4dDescriptor(l->ddstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->out_c, l->out_h, l->out_w); 
+    cudnnSetFilter4dDescriptor(l->dfilterDesc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size); 
+
+    cudnnSetTensor4dDescriptor(l->srcTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->c, l->h, l->w); 
+    cudnnSetTensor4dDescriptor(l->dstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->out_c, l->out_h, l->out_w); 
+    cudnnSetFilter4dDescriptor(l->filterDesc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size); 
+    int padding = l->pad ? l->size/2 : 0;
+    cudnnSetConvolution2dDescriptor(l->convDesc, padding, padding, l->stride, l->stride, 1, 1, CUDNN_CROSS_CORRELATION);
+    cudnnGetConvolutionForwardAlgorithm(cudnn_handle(),
+            l->srcTensorDesc,
+            l->filterDesc,
+            l->convDesc,
+            l->dstTensorDesc,
+            CUDNN_CONVOLUTION_FWD_PREFER_FASTEST,
+            0,
+            &l->fw_algo);
+    cudnnGetConvolutionBackwardDataAlgorithm(cudnn_handle(),
+            l->filterDesc,
+            l->ddstTensorDesc,
+            l->convDesc,
+            l->dsrcTensorDesc,
+            CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST,
+            0,
+            &l->bd_algo);
+    cudnnGetConvolutionBackwardFilterAlgorithm(cudnn_handle(),
+            l->srcTensorDesc,
+            l->ddstTensorDesc,
+            l->convDesc,
+            l->dfilterDesc,
+            CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST,
+            0,
+            &l->bf_algo);
 #endif
+#endif
+    l->workspace_size = get_workspace_size(*l);
 }
 
 void add_bias(float *output, float *biases, int batch, int n, int size)
@@ -228,43 +407,85 @@
     }
 }
 
-void forward_convolutional_layer(const convolutional_layer l, network_state state)
+void forward_convolutional_layer(convolutional_layer l, network_state state)
 {
     int out_h = convolutional_out_height(l);
     int out_w = convolutional_out_width(l);
     int i;
 
+
     fill_cpu(l.outputs*l.batch, 0, l.output, 1);
 
+    /*
+       if(l.binary){
+       binarize_filters(l.filters, l.n, l.c*l.size*l.size, l.binary_filters);
+       binarize_filters2(l.filters, l.n, l.c*l.size*l.size, l.cfilters, l.scales);
+       swap_binary(&l);
+       }
+     */
+
+    /*
+       if(l.binary){
+       int m = l.n;
+       int k = l.size*l.size*l.c;
+       int n = out_h*out_w;
+
+       char  *a = l.cfilters;
+       float *b = state.workspace;
+       float *c = l.output;
+
+       for(i = 0; i < l.batch; ++i){
+       im2col_cpu(state.input, l.c, l.h, l.w, 
+       l.size, l.stride, l.pad, b);
+       gemm_bin(m,n,k,1,a,k,b,n,c,n);
+       c += n*m;
+       state.input += l.c*l.h*l.w;
+       }
+       scale_bias(l.output, l.scales, l.batch, l.n, out_h*out_w);
+       add_bias(l.output, l.biases, l.batch, l.n, out_h*out_w);
+       activate_array(l.output, m*n*l.batch, l.activation);
+       return;
+       }
+     */
+
+    if(l.xnor && (l.c%32 != 0 || !AI2)){
+        binarize_filters(l.filters, l.n, l.c*l.size*l.size, l.binary_filters);
+        swap_binary(&l);
+        for(i = 0; i < l.batch; ++i){
+            binarize_input(state.input + i*l.inputs, l.c, l.h*l.w, l.binary_input + i*l.inputs);
+        }
+        state.input = l.binary_input;
+    }
+
     int m = l.n;
     int k = l.size*l.size*l.c;
     int n = out_h*out_w;
 
-    float *a = l.filters;
-    float *b = l.col_image;
-    float *c = l.output;
+    if (l.xnor && l.c%32 == 0 && AI2) {
+        forward_xnor_layer(l, state);
+        printf("xnor\n");
+    } else {
 
-    for(i = 0; i < l.batch; ++i){
-        im2col_cpu(state.input, l.c, l.h, l.w, 
-                l.size, l.stride, l.pad, b);
-        gemm(0,0,m,n,k,1,a,k,b,n,1,c,n);
-        c += n*m;
-        state.input += l.c*l.h*l.w;
+        float *a = l.filters;
+        float *b = state.workspace;
+        float *c = l.output;
+
+        for(i = 0; i < l.batch; ++i){
+            im2col_cpu(state.input, l.c, l.h, l.w, 
+                    l.size, l.stride, l.pad, b);
+            gemm(0,0,m,n,k,1,a,k,b,n,1,c,n);
+            c += n*m;
+            state.input += l.c*l.h*l.w;
+        }
     }
 
     if(l.batch_normalize){
-        if(state.train){
-            mean_cpu(l.output, l.batch, l.n, l.out_h*l.out_w, l.mean);   
-            variance_cpu(l.output, l.mean, l.batch, l.n, l.out_h*l.out_w, l.variance);   
-            normalize_cpu(l.output, l.mean, l.variance, l.batch, l.n, l.out_h*l.out_w);   
-        } else {
-            normalize_cpu(l.output, l.rolling_mean, l.rolling_variance, l.batch, l.n, l.out_h*l.out_w);
-        }
-        scale_bias(l.output, l.scales, l.batch, l.n, out_h*out_w);
+        forward_batchnorm_layer(l, state);
     }
     add_bias(l.output, l.biases, l.batch, l.n, out_h*out_w);
 
     activate_array(l.output, m*n*l.batch, l.activation);
+    if(l.binary || l.xnor) swap_binary(&l);
 }
 
 void backward_convolutional_layer(convolutional_layer l, network_state state)
@@ -280,7 +501,7 @@
 
     for(i = 0; i < l.batch; ++i){
         float *a = l.delta + i*m*k;
-        float *b = l.col_image;
+        float *b = state.workspace;
         float *c = l.filter_updates;
 
         float *im = state.input+i*l.c*l.h*l.w;
@@ -292,11 +513,11 @@
         if(state.delta){
             a = l.filters;
             b = l.delta + i*m*k;
-            c = l.col_image;
+            c = state.workspace;
 
             gemm(1,0,n,k,m,1,a,n,b,k,0,c,k);
 
-            col2im_cpu(l.col_image, l.c,  l.h,  l.w,  l.size,  l.stride, l.pad, state.delta+i*l.c*l.h*l.w);
+            col2im_cpu(state.workspace, l.c,  l.h,  l.w,  l.size,  l.stride, l.pad, state.delta+i*l.c*l.h*l.w);
         }
     }
 }

--
Gitblit v1.10.0