From 4625a16ffdcf3b9f7bfc37046e70f4ecb87234ab Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Mon, 06 Jun 2016 20:22:45 +0000
Subject: [PATCH] tactics

---
 src/convolutional_layer.c |  143 ++++++++++++++++++++++++++++++++++++++++++-----
 1 files changed, 126 insertions(+), 17 deletions(-)

diff --git a/src/convolutional_layer.c b/src/convolutional_layer.c
index d76dfcd..c377802 100644
--- a/src/convolutional_layer.c
+++ b/src/convolutional_layer.c
@@ -88,6 +88,40 @@
     return float_to_image(w,h,c,l.delta);
 }
 
+size_t get_workspace_size(layer l){
+    #ifdef CUDNN
+    size_t most = 0;
+    size_t s = 0;
+    cudnnGetConvolutionForwardWorkspaceSize(cudnn_handle(),
+            l.srcTensorDesc,
+            l.filterDesc,
+            l.convDesc,
+            l.dstTensorDesc,
+            l.fw_algo,
+            &s);
+    if (s > most) most = s;
+    cudnnGetConvolutionBackwardFilterWorkspaceSize(cudnn_handle(),
+            l.srcTensorDesc,
+            l.ddstTensorDesc,
+            l.convDesc,
+            l.dfilterDesc,
+            l.bf_algo,
+            &s);
+    if (s > most) most = s;
+    cudnnGetConvolutionBackwardDataWorkspaceSize(cudnn_handle(),
+            l.filterDesc,
+            l.ddstTensorDesc,
+            l.convDesc,
+            l.dsrcTensorDesc,
+            l.bd_algo,
+            &s);
+    if (s > most) most = s;
+    return most;
+    #else
+    return (size_t)l.out_h*l.out_w*l.size*l.size*l.c*sizeof(float);
+    #endif
+}
+
 convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, int pad, ACTIVATION activation, int batch_normalize, int binary, int xnor)
 {
     int i;
@@ -122,7 +156,6 @@
     l.outputs = l.out_h * l.out_w * l.out_c;
     l.inputs = l.w * l.h * l.c;
 
-    l.col_image = calloc(out_h*out_w*size*size*c, sizeof(float));
     l.output = calloc(l.batch*out_h * out_w * n, sizeof(float));
     l.delta  = calloc(l.batch*out_h * out_w * n, sizeof(float));
 
@@ -156,7 +189,6 @@
     l.scales_gpu = cuda_make_array(l.scales, n);
     l.scale_updates_gpu = cuda_make_array(l.scale_updates, n);
 
-    l.col_image_gpu = cuda_make_array(l.col_image, out_h*out_w*size*size*c);
     l.delta_gpu = cuda_make_array(l.delta, l.batch*out_h*out_w*n);
     l.output_gpu = cuda_make_array(l.output, l.batch*out_h*out_w*n);
 
@@ -182,7 +214,50 @@
         l.x_gpu = cuda_make_array(l.output, l.batch*out_h*out_w*n);
         l.x_norm_gpu = cuda_make_array(l.output, l.batch*out_h*out_w*n);
     }
+#ifdef CUDNN
+    cudnnCreateTensorDescriptor(&l.srcTensorDesc);
+    cudnnCreateTensorDescriptor(&l.dstTensorDesc);
+    cudnnCreateFilterDescriptor(&l.filterDesc);
+    cudnnCreateTensorDescriptor(&l.dsrcTensorDesc);
+    cudnnCreateTensorDescriptor(&l.ddstTensorDesc);
+    cudnnCreateFilterDescriptor(&l.dfilterDesc);
+    cudnnCreateConvolutionDescriptor(&l.convDesc);
+    cudnnSetTensor4dDescriptor(l.dsrcTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l.batch, l.c, l.h, l.w); 
+    cudnnSetTensor4dDescriptor(l.ddstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l.batch, l.out_c, l.out_h, l.out_w); 
+    cudnnSetFilter4dDescriptor(l.dfilterDesc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, l.n, l.c, l.size, l.size); 
+
+    cudnnSetTensor4dDescriptor(l.srcTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l.batch, l.c, l.h, l.w); 
+    cudnnSetTensor4dDescriptor(l.dstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l.batch, l.out_c, l.out_h, l.out_w); 
+    cudnnSetFilter4dDescriptor(l.filterDesc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, l.n, l.c, l.size, l.size); 
+    int padding = l.pad ? l.size/2 : 0;
+    cudnnSetConvolution2dDescriptor(l.convDesc, padding, padding, l.stride, l.stride, 1, 1, CUDNN_CROSS_CORRELATION);
+    cudnnGetConvolutionForwardAlgorithm(cudnn_handle(),
+            l.srcTensorDesc,
+            l.filterDesc,
+            l.convDesc,
+            l.dstTensorDesc,
+            CUDNN_CONVOLUTION_FWD_PREFER_FASTEST,
+            0,
+            &l.fw_algo);
+    cudnnGetConvolutionBackwardDataAlgorithm(cudnn_handle(),
+            l.filterDesc,
+            l.ddstTensorDesc,
+            l.convDesc,
+            l.dsrcTensorDesc,
+            CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST,
+            0,
+            &l.bd_algo);
+    cudnnGetConvolutionBackwardFilterAlgorithm(cudnn_handle(),
+            l.srcTensorDesc,
+            l.ddstTensorDesc,
+            l.convDesc,
+            l.dfilterDesc,
+            CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST,
+            0,
+            &l.bf_algo);
 #endif
+#endif
+    l.workspace_size = get_workspace_size(l);
     l.activation = activation;
 
     fprintf(stderr, "Convolutional Layer: %d x %d x %d image, %d filters -> %d x %d x %d image\n", h,w,c,n, out_h, out_w, n);
@@ -239,22 +314,54 @@
     l->outputs = l->out_h * l->out_w * l->out_c;
     l->inputs = l->w * l->h * l->c;
 
-    l->col_image = realloc(l->col_image,
-            out_h*out_w*l->size*l->size*l->c*sizeof(float));
     l->output = realloc(l->output,
             l->batch*out_h * out_w * l->n*sizeof(float));
     l->delta  = realloc(l->delta,
             l->batch*out_h * out_w * l->n*sizeof(float));
 
 #ifdef GPU
-    cuda_free(l->col_image_gpu);
     cuda_free(l->delta_gpu);
     cuda_free(l->output_gpu);
 
-    l->col_image_gpu = cuda_make_array(l->col_image, out_h*out_w*l->size*l->size*l->c);
     l->delta_gpu =     cuda_make_array(l->delta, l->batch*out_h*out_w*l->n);
     l->output_gpu =    cuda_make_array(l->output, l->batch*out_h*out_w*l->n);
+    #ifdef CUDNN
+    cudnnSetTensor4dDescriptor(l->dsrcTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->c, l->h, l->w); 
+    cudnnSetTensor4dDescriptor(l->ddstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->out_c, l->out_h, l->out_w); 
+    cudnnSetFilter4dDescriptor(l->dfilterDesc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size); 
+
+    cudnnSetTensor4dDescriptor(l->srcTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->c, l->h, l->w); 
+    cudnnSetTensor4dDescriptor(l->dstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->out_c, l->out_h, l->out_w); 
+    cudnnSetFilter4dDescriptor(l->filterDesc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size); 
+    int padding = l->pad ? l->size/2 : 0;
+    cudnnSetConvolution2dDescriptor(l->convDesc, padding, padding, l->stride, l->stride, 1, 1, CUDNN_CROSS_CORRELATION);
+    cudnnGetConvolutionForwardAlgorithm(cudnn_handle(),
+            l->srcTensorDesc,
+            l->filterDesc,
+            l->convDesc,
+            l->dstTensorDesc,
+            CUDNN_CONVOLUTION_FWD_PREFER_FASTEST,
+            0,
+            &l->fw_algo);
+    cudnnGetConvolutionBackwardDataAlgorithm(cudnn_handle(),
+            l->filterDesc,
+            l->ddstTensorDesc,
+            l->convDesc,
+            l->dsrcTensorDesc,
+            CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST,
+            0,
+            &l->bd_algo);
+    cudnnGetConvolutionBackwardFilterAlgorithm(cudnn_handle(),
+            l->srcTensorDesc,
+            l->ddstTensorDesc,
+            l->convDesc,
+            l->dfilterDesc,
+            CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST,
+            0,
+            &l->bf_algo);
+    #endif
 #endif
+    l->workspace_size = get_workspace_size(*l);
 }
 
 void add_bias(float *output, float *biases, int batch, int n, int size)
@@ -299,20 +406,21 @@
 
     fill_cpu(l.outputs*l.batch, 0, l.output, 1);
     /*
-    if(l.binary){
-        binarize_filters(l.filters, l.n, l.c*l.size*l.size, l.binary_filters);
-        binarize_filters2(l.filters, l.n, l.c*l.size*l.size, l.cfilters, l.scales);
-        swap_binary(&l);
-    }
-    */
+       if(l.binary){
+       binarize_filters(l.filters, l.n, l.c*l.size*l.size, l.binary_filters);
+       binarize_filters2(l.filters, l.n, l.c*l.size*l.size, l.cfilters, l.scales);
+       swap_binary(&l);
+       }
+     */
 
+/*
     if(l.binary){
         int m = l.n;
         int k = l.size*l.size*l.c;
         int n = out_h*out_w;
 
         char  *a = l.cfilters;
-        float *b = l.col_image;
+        float *b = state.workspace;
         float *c = l.output;
 
         for(i = 0; i < l.batch; ++i){
@@ -327,13 +435,14 @@
         activate_array(l.output, m*n*l.batch, l.activation);
         return;
     }
+    */
 
     int m = l.n;
     int k = l.size*l.size*l.c;
     int n = out_h*out_w;
 
     float *a = l.filters;
-    float *b = l.col_image;
+    float *b = state.workspace;
     float *c = l.output;
 
     for(i = 0; i < l.batch; ++i){
@@ -365,7 +474,7 @@
 
     for(i = 0; i < l.batch; ++i){
         float *a = l.delta + i*m*k;
-        float *b = l.col_image;
+        float *b = state.workspace;
         float *c = l.filter_updates;
 
         float *im = state.input+i*l.c*l.h*l.w;
@@ -377,11 +486,11 @@
         if(state.delta){
             a = l.filters;
             b = l.delta + i*m*k;
-            c = l.col_image;
+            c = state.workspace;
 
             gemm(1,0,n,k,m,1,a,n,b,k,0,c,k);
 
-            col2im_cpu(l.col_image, l.c,  l.h,  l.w,  l.size,  l.stride, l.pad, state.delta+i*l.c*l.h*l.w);
+            col2im_cpu(state.workspace, l.c,  l.h,  l.w,  l.size,  l.stride, l.pad, state.delta+i*l.c*l.h*l.w);
         }
     }
 }

--
Gitblit v1.10.0