From ec3d050a76ee8c41f35c4531d3fa07a2d9c28ed3 Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Thu, 02 Jun 2016 22:25:24 +0000
Subject: [PATCH] hope i didn't break anything
---
src/convolutional_layer.c | 214 ++++++++++++++++++++++++++++++++--------------------
1 files changed, 131 insertions(+), 83 deletions(-)
diff --git a/src/convolutional_layer.c b/src/convolutional_layer.c
index cdc8bd3..5575aac 100644
--- a/src/convolutional_layer.c
+++ b/src/convolutional_layer.c
@@ -1,5 +1,6 @@
#include "convolutional_layer.h"
#include "utils.h"
+#include "batchnorm_layer.h"
#include "im2col.h"
#include "col2im.h"
#include "blas.h"
@@ -87,65 +88,41 @@
return float_to_image(w,h,c,l.delta);
}
-void backward_scale_cpu(float *x_norm, float *delta, int batch, int n, int size, float *scale_updates)
-{
- int i,b,f;
- for(f = 0; f < n; ++f){
- float sum = 0;
- for(b = 0; b < batch; ++b){
- for(i = 0; i < size; ++i){
- int index = i + size*(f + n*b);
- sum += delta[index] * x_norm[index];
- }
- }
- scale_updates[f] += sum;
- }
+size_t get_workspace_size(layer l){
+ #ifdef CUDNN
+ size_t most = 0;
+ size_t s = 0;
+ cudnnGetConvolutionForwardWorkspaceSize(cudnn_handle(),
+ l.srcTensorDesc,
+ l.filterDesc,
+ l.convDesc,
+ l.dstTensorDesc,
+ l.fw_algo,
+ &s);
+ if (s > most) most = s;
+ cudnnGetConvolutionBackwardFilterWorkspaceSize(cudnn_handle(),
+ l.srcTensorDesc,
+ l.ddstTensorDesc,
+ l.convDesc,
+ l.dfilterDesc,
+ l.bf_algo,
+ &s);
+ if (s > most) most = s;
+ cudnnGetConvolutionBackwardDataWorkspaceSize(cudnn_handle(),
+ l.filterDesc,
+ l.ddstTensorDesc,
+ l.convDesc,
+ l.dsrcTensorDesc,
+ l.bd_algo,
+ &s);
+ if (s > most) most = s;
+ return most;
+ #else
+ return (size_t)l.out_h*l.out_w*l.size*l.size*l.c*sizeof(float);
+ #endif
}
-void mean_delta_cpu(float *delta, float *variance, int batch, int filters, int spatial, float *mean_delta)
-{
-
- int i,j,k;
- for(i = 0; i < filters; ++i){
- mean_delta[i] = 0;
- for (j = 0; j < batch; ++j) {
- for (k = 0; k < spatial; ++k) {
- int index = j*filters*spatial + i*spatial + k;
- mean_delta[i] += delta[index];
- }
- }
- mean_delta[i] *= (-1./sqrt(variance[i] + .00001f));
- }
-}
-void variance_delta_cpu(float *x, float *delta, float *mean, float *variance, int batch, int filters, int spatial, float *variance_delta)
-{
-
- int i,j,k;
- for(i = 0; i < filters; ++i){
- variance_delta[i] = 0;
- for(j = 0; j < batch; ++j){
- for(k = 0; k < spatial; ++k){
- int index = j*filters*spatial + i*spatial + k;
- variance_delta[i] += delta[index]*(x[index] - mean[i]);
- }
- }
- variance_delta[i] *= -.5 * pow(variance[i] + .00001f, (float)(-3./2.));
- }
-}
-void normalize_delta_cpu(float *x, float *mean, float *variance, float *mean_delta, float *variance_delta, int batch, int filters, int spatial, float *delta)
-{
- int f, j, k;
- for(j = 0; j < batch; ++j){
- for(f = 0; f < filters; ++f){
- for(k = 0; k < spatial; ++k){
- int index = j*filters*spatial + f*spatial + k;
- delta[index] = delta[index] * 1./(sqrt(variance[f]) + .00001f) + variance_delta[f] * 2. * (x[index] - mean[f]) / (spatial * batch) + mean_delta[f]/(spatial*batch);
- }
- }
- }
-}
-
-convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, int pad, ACTIVATION activation, int batch_normalize, int binary)
+convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, int pad, ACTIVATION activation, int batch_normalize, int binary, int xnor)
{
int i;
convolutional_layer l = {0};
@@ -179,7 +156,6 @@
l.outputs = l.out_h * l.out_w * l.out_c;
l.inputs = l.w * l.h * l.c;
- l.col_image = calloc(out_h*out_w*size*size*c, sizeof(float));
l.output = calloc(l.batch*out_h * out_w * n, sizeof(float));
l.delta = calloc(l.batch*out_h * out_w * n, sizeof(float));
@@ -213,13 +189,17 @@
l.scales_gpu = cuda_make_array(l.scales, n);
l.scale_updates_gpu = cuda_make_array(l.scale_updates, n);
- l.col_image_gpu = cuda_make_array(l.col_image, out_h*out_w*size*size*c);
l.delta_gpu = cuda_make_array(l.delta, l.batch*out_h*out_w*n);
l.output_gpu = cuda_make_array(l.output, l.batch*out_h*out_w*n);
if(binary){
l.binary_filters_gpu = cuda_make_array(l.filters, c*n*size*size);
}
+ if(xnor){
+ l.binary_filters_gpu = cuda_make_array(l.filters, c*n*size*size);
+ l.binary_input_gpu = cuda_make_array(0, l.inputs*l.batch);
+ }
+ l.xnor = xnor;
if(batch_normalize){
l.mean_gpu = cuda_make_array(l.mean, n);
@@ -234,7 +214,50 @@
l.x_gpu = cuda_make_array(l.output, l.batch*out_h*out_w*n);
l.x_norm_gpu = cuda_make_array(l.output, l.batch*out_h*out_w*n);
}
+#ifdef CUDNN
+ cudnnCreateTensorDescriptor(&l.srcTensorDesc);
+ cudnnCreateTensorDescriptor(&l.dstTensorDesc);
+ cudnnCreateFilterDescriptor(&l.filterDesc);
+ cudnnCreateTensorDescriptor(&l.dsrcTensorDesc);
+ cudnnCreateTensorDescriptor(&l.ddstTensorDesc);
+ cudnnCreateFilterDescriptor(&l.dfilterDesc);
+ cudnnCreateConvolutionDescriptor(&l.convDesc);
+ cudnnSetTensor4dDescriptor(l.dsrcTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l.batch, l.c, l.h, l.w);
+ cudnnSetTensor4dDescriptor(l.ddstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l.batch, l.out_c, l.out_h, l.out_w);
+ cudnnSetFilter4dDescriptor(l.dfilterDesc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, l.n, l.c, l.size, l.size);
+
+ cudnnSetTensor4dDescriptor(l.srcTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l.batch, l.c, l.h, l.w);
+ cudnnSetTensor4dDescriptor(l.dstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l.batch, l.out_c, l.out_h, l.out_w);
+ cudnnSetFilter4dDescriptor(l.filterDesc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, l.n, l.c, l.size, l.size);
+ int padding = l.pad ? l.size/2 : 0;
+ cudnnSetConvolution2dDescriptor(l.convDesc, padding, padding, l.stride, l.stride, 1, 1, CUDNN_CROSS_CORRELATION);
+ cudnnGetConvolutionForwardAlgorithm(cudnn_handle(),
+ l.srcTensorDesc,
+ l.filterDesc,
+ l.convDesc,
+ l.dstTensorDesc,
+ CUDNN_CONVOLUTION_FWD_PREFER_FASTEST,
+ 0,
+ &l.fw_algo);
+ cudnnGetConvolutionBackwardDataAlgorithm(cudnn_handle(),
+ l.filterDesc,
+ l.ddstTensorDesc,
+ l.convDesc,
+ l.dsrcTensorDesc,
+ CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST,
+ 0,
+ &l.bd_algo);
+ cudnnGetConvolutionBackwardFilterAlgorithm(cudnn_handle(),
+ l.srcTensorDesc,
+ l.ddstTensorDesc,
+ l.convDesc,
+ l.dfilterDesc,
+ CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST,
+ 0,
+ &l.bf_algo);
#endif
+#endif
+ l.workspace_size = get_workspace_size(l);
l.activation = activation;
fprintf(stderr, "Convolutional Layer: %d x %d x %d image, %d filters -> %d x %d x %d image\n", h,w,c,n, out_h, out_w, n);
@@ -256,7 +279,7 @@
void test_convolutional_layer()
{
- convolutional_layer l = make_convolutional_layer(1, 5, 5, 3, 2, 5, 2, 1, LEAKY, 1, 0);
+ convolutional_layer l = make_convolutional_layer(1, 5, 5, 3, 2, 5, 2, 1, LEAKY, 1, 0, 0);
l.batch_normalize = 1;
float data[] = {1,1,1,1,1,
1,1,1,1,1,
@@ -291,22 +314,54 @@
l->outputs = l->out_h * l->out_w * l->out_c;
l->inputs = l->w * l->h * l->c;
- l->col_image = realloc(l->col_image,
- out_h*out_w*l->size*l->size*l->c*sizeof(float));
l->output = realloc(l->output,
l->batch*out_h * out_w * l->n*sizeof(float));
l->delta = realloc(l->delta,
l->batch*out_h * out_w * l->n*sizeof(float));
#ifdef GPU
- cuda_free(l->col_image_gpu);
cuda_free(l->delta_gpu);
cuda_free(l->output_gpu);
- l->col_image_gpu = cuda_make_array(l->col_image, out_h*out_w*l->size*l->size*l->c);
l->delta_gpu = cuda_make_array(l->delta, l->batch*out_h*out_w*l->n);
l->output_gpu = cuda_make_array(l->output, l->batch*out_h*out_w*l->n);
+ #ifdef CUDNN
+ cudnnSetTensor4dDescriptor(l->dsrcTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->c, l->h, l->w);
+ cudnnSetTensor4dDescriptor(l->ddstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->out_c, l->out_h, l->out_w);
+ cudnnSetFilter4dDescriptor(l->dfilterDesc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size);
+
+ cudnnSetTensor4dDescriptor(l->srcTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->c, l->h, l->w);
+ cudnnSetTensor4dDescriptor(l->dstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->out_c, l->out_h, l->out_w);
+ cudnnSetFilter4dDescriptor(l->filterDesc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size);
+ int padding = l->pad ? l->size/2 : 0;
+ cudnnSetConvolution2dDescriptor(l->convDesc, padding, padding, l->stride, l->stride, 1, 1, CUDNN_CROSS_CORRELATION);
+ cudnnGetConvolutionForwardAlgorithm(cudnn_handle(),
+ l->srcTensorDesc,
+ l->filterDesc,
+ l->convDesc,
+ l->dstTensorDesc,
+ CUDNN_CONVOLUTION_FWD_PREFER_FASTEST,
+ 0,
+ &l->fw_algo);
+ cudnnGetConvolutionBackwardDataAlgorithm(cudnn_handle(),
+ l->filterDesc,
+ l->ddstTensorDesc,
+ l->convDesc,
+ l->dsrcTensorDesc,
+ CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST,
+ 0,
+ &l->bd_algo);
+ cudnnGetConvolutionBackwardFilterAlgorithm(cudnn_handle(),
+ l->srcTensorDesc,
+ l->ddstTensorDesc,
+ l->convDesc,
+ l->dfilterDesc,
+ CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST,
+ 0,
+ &l->bf_algo);
+ #endif
#endif
+ l->workspace_size = get_workspace_size(*l);
}
void add_bias(float *output, float *biases, int batch, int n, int size)
@@ -351,12 +406,12 @@
fill_cpu(l.outputs*l.batch, 0, l.output, 1);
/*
- if(l.binary){
- binarize_filters(l.filters, l.n, l.c*l.size*l.size, l.binary_filters);
- binarize_filters2(l.filters, l.n, l.c*l.size*l.size, l.cfilters, l.scales);
- swap_binary(&l);
- }
- */
+ if(l.binary){
+ binarize_filters(l.filters, l.n, l.c*l.size*l.size, l.binary_filters);
+ binarize_filters2(l.filters, l.n, l.c*l.size*l.size, l.cfilters, l.scales);
+ swap_binary(&l);
+ }
+ */
if(l.binary){
int m = l.n;
@@ -364,7 +419,7 @@
int n = out_h*out_w;
char *a = l.cfilters;
- float *b = l.col_image;
+ float *b = state.workspace;
float *c = l.output;
for(i = 0; i < l.batch; ++i){
@@ -385,7 +440,7 @@
int n = out_h*out_w;
float *a = l.filters;
- float *b = l.col_image;
+ float *b = state.workspace;
float *c = l.output;
for(i = 0; i < l.batch; ++i){
@@ -397,14 +452,7 @@
}
if(l.batch_normalize){
- if(state.train){
- mean_cpu(l.output, l.batch, l.n, l.out_h*l.out_w, l.mean);
- variance_cpu(l.output, l.mean, l.batch, l.n, l.out_h*l.out_w, l.variance);
- normalize_cpu(l.output, l.mean, l.variance, l.batch, l.n, l.out_h*l.out_w);
- } else {
- normalize_cpu(l.output, l.rolling_mean, l.rolling_variance, l.batch, l.n, l.out_h*l.out_w);
- }
- scale_bias(l.output, l.scales, l.batch, l.n, out_h*out_w);
+ forward_batchnorm_layer(l, state);
}
add_bias(l.output, l.biases, l.batch, l.n, out_h*out_w);
@@ -424,7 +472,7 @@
for(i = 0; i < l.batch; ++i){
float *a = l.delta + i*m*k;
- float *b = l.col_image;
+ float *b = state.workspace;
float *c = l.filter_updates;
float *im = state.input+i*l.c*l.h*l.w;
@@ -436,11 +484,11 @@
if(state.delta){
a = l.filters;
b = l.delta + i*m*k;
- c = l.col_image;
+ c = state.workspace;
gemm(1,0,n,k,m,1,a,n,b,k,0,c,k);
- col2im_cpu(l.col_image, l.c, l.h, l.w, l.size, l.stride, l.pad, state.delta+i*l.c*l.h*l.w);
+ col2im_cpu(state.workspace, l.c, l.h, l.w, l.size, l.stride, l.pad, state.delta+i*l.c*l.h*l.w);
}
}
}
--
Gitblit v1.10.0