From 5c067dc44785a761a0243d8cd634e3ac17d548ad Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Mon, 12 Sep 2016 20:55:20 +0000
Subject: [PATCH] good chance I didn't break anything
---
src/convolutional_layer.c | 110 +++++++++++++++++++++++++++---------------------------
1 files changed, 55 insertions(+), 55 deletions(-)
diff --git a/src/convolutional_layer.c b/src/convolutional_layer.c
index ad2d8a5..299af75 100644
--- a/src/convolutional_layer.c
+++ b/src/convolutional_layer.c
@@ -19,28 +19,28 @@
void swap_binary(convolutional_layer *l)
{
- float *swap = l->filters;
- l->filters = l->binary_filters;
- l->binary_filters = swap;
+ float *swap = l->weights;
+ l->weights = l->binary_weights;
+ l->binary_weights = swap;
#ifdef GPU
- swap = l->filters_gpu;
- l->filters_gpu = l->binary_filters_gpu;
- l->binary_filters_gpu = swap;
+ swap = l->weights_gpu;
+ l->weights_gpu = l->binary_weights_gpu;
+ l->binary_weights_gpu = swap;
#endif
}
-void binarize_filters(float *filters, int n, int size, float *binary)
+void binarize_weights(float *weights, int n, int size, float *binary)
{
int i, f;
for(f = 0; f < n; ++f){
float mean = 0;
for(i = 0; i < size; ++i){
- mean += fabs(filters[f*size + i]);
+ mean += fabs(weights[f*size + i]);
}
mean = mean / size;
for(i = 0; i < size; ++i){
- binary[f*size + i] = (filters[f*size + i] > 0) ? mean : -mean;
+ binary[f*size + i] = (weights[f*size + i] > 0) ? mean : -mean;
}
}
}
@@ -103,7 +103,7 @@
size_t s = 0;
cudnnGetConvolutionForwardWorkspaceSize(cudnn_handle(),
l.srcTensorDesc,
- l.filterDesc,
+ l.weightDesc,
l.convDesc,
l.dstTensorDesc,
l.fw_algo,
@@ -113,12 +113,12 @@
l.srcTensorDesc,
l.ddstTensorDesc,
l.convDesc,
- l.dfilterDesc,
+ l.dweightDesc,
l.bf_algo,
&s);
if (s > most) most = s;
cudnnGetConvolutionBackwardDataWorkspaceSize(cudnn_handle(),
- l.filterDesc,
+ l.weightDesc,
l.ddstTensorDesc,
l.convDesc,
l.dsrcTensorDesc,
@@ -137,22 +137,22 @@
{
cudnnSetTensor4dDescriptor(l->dsrcTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->c, l->h, l->w);
cudnnSetTensor4dDescriptor(l->ddstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->out_c, l->out_h, l->out_w);
- cudnnSetFilter4dDescriptor(l->dfilterDesc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size);
+ cudnnSetFilter4dDescriptor(l->dweightDesc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size);
cudnnSetTensor4dDescriptor(l->srcTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->c, l->h, l->w);
cudnnSetTensor4dDescriptor(l->dstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->out_c, l->out_h, l->out_w);
- cudnnSetFilter4dDescriptor(l->filterDesc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size);
+ cudnnSetFilter4dDescriptor(l->weightDesc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size);
cudnnSetConvolution2dDescriptor(l->convDesc, l->pad, l->pad, l->stride, l->stride, 1, 1, CUDNN_CROSS_CORRELATION);
cudnnGetConvolutionForwardAlgorithm(cudnn_handle(),
l->srcTensorDesc,
- l->filterDesc,
+ l->weightDesc,
l->convDesc,
l->dstTensorDesc,
CUDNN_CONVOLUTION_FWD_PREFER_FASTEST,
0,
&l->fw_algo);
cudnnGetConvolutionBackwardDataAlgorithm(cudnn_handle(),
- l->filterDesc,
+ l->weightDesc,
l->ddstTensorDesc,
l->convDesc,
l->dsrcTensorDesc,
@@ -163,7 +163,7 @@
l->srcTensorDesc,
l->ddstTensorDesc,
l->convDesc,
- l->dfilterDesc,
+ l->dweightDesc,
CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST,
0,
&l->bf_algo);
@@ -189,15 +189,15 @@
l.pad = padding;
l.batch_normalize = batch_normalize;
- l.filters = calloc(c*n*size*size, sizeof(float));
- l.filter_updates = calloc(c*n*size*size, sizeof(float));
+ l.weights = calloc(c*n*size*size, sizeof(float));
+ l.weight_updates = calloc(c*n*size*size, sizeof(float));
l.biases = calloc(n, sizeof(float));
l.bias_updates = calloc(n, sizeof(float));
// float scale = 1./sqrt(size*size*c);
float scale = sqrt(2./(size*size*c));
- for(i = 0; i < c*n*size*size; ++i) l.filters[i] = scale*rand_uniform(-1, 1);
+ for(i = 0; i < c*n*size*size; ++i) l.weights[i] = scale*rand_uniform(-1, 1);
int out_h = convolutional_out_height(l);
int out_w = convolutional_out_width(l);
l.out_h = out_h;
@@ -210,12 +210,12 @@
l.delta = calloc(l.batch*out_h * out_w * n, sizeof(float));
if(binary){
- l.binary_filters = calloc(c*n*size*size, sizeof(float));
- l.cfilters = calloc(c*n*size*size, sizeof(char));
+ l.binary_weights = calloc(c*n*size*size, sizeof(float));
+ l.cweights = calloc(c*n*size*size, sizeof(char));
l.scales = calloc(n, sizeof(float));
}
if(xnor){
- l.binary_filters = calloc(c*n*size*size, sizeof(float));
+ l.binary_weights = calloc(c*n*size*size, sizeof(float));
l.binary_input = calloc(l.inputs*l.batch, sizeof(float));
}
@@ -235,8 +235,8 @@
#ifdef GPU
if(gpu_index >= 0){
- l.filters_gpu = cuda_make_array(l.filters, c*n*size*size);
- l.filter_updates_gpu = cuda_make_array(l.filter_updates, c*n*size*size);
+ l.weights_gpu = cuda_make_array(l.weights, c*n*size*size);
+ l.weight_updates_gpu = cuda_make_array(l.weight_updates, c*n*size*size);
l.biases_gpu = cuda_make_array(l.biases, n);
l.bias_updates_gpu = cuda_make_array(l.bias_updates, n);
@@ -248,10 +248,10 @@
l.output_gpu = cuda_make_array(l.output, l.batch*out_h*out_w*n);
if(binary){
- l.binary_filters_gpu = cuda_make_array(l.filters, c*n*size*size);
+ l.binary_weights_gpu = cuda_make_array(l.weights, c*n*size*size);
}
if(xnor){
- l.binary_filters_gpu = cuda_make_array(l.filters, c*n*size*size);
+ l.binary_weights_gpu = cuda_make_array(l.weights, c*n*size*size);
l.binary_input_gpu = cuda_make_array(0, l.inputs*l.batch);
}
@@ -271,10 +271,10 @@
#ifdef CUDNN
cudnnCreateTensorDescriptor(&l.srcTensorDesc);
cudnnCreateTensorDescriptor(&l.dstTensorDesc);
- cudnnCreateFilterDescriptor(&l.filterDesc);
+ cudnnCreateFilterDescriptor(&l.weightDesc);
cudnnCreateTensorDescriptor(&l.dsrcTensorDesc);
cudnnCreateTensorDescriptor(&l.ddstTensorDesc);
- cudnnCreateFilterDescriptor(&l.dfilterDesc);
+ cudnnCreateFilterDescriptor(&l.dweightDesc);
cudnnCreateConvolutionDescriptor(&l.convDesc);
cudnn_convolutional_setup(&l);
#endif
@@ -294,7 +294,7 @@
for(i = 0; i < l.n; ++i){
float scale = l.scales[i]/sqrt(l.rolling_variance[i] + .00001);
for(j = 0; j < l.c*l.size*l.size; ++j){
- l.filters[i*l.c*l.size*l.size + j] *= scale;
+ l.weights[i*l.c*l.size*l.size + j] *= scale;
}
l.biases[i] -= l.rolling_mean[i] * scale;
l.scales[i] = 1;
@@ -403,8 +403,8 @@
/*
if(l.binary){
- binarize_filters(l.filters, l.n, l.c*l.size*l.size, l.binary_filters);
- binarize_filters2(l.filters, l.n, l.c*l.size*l.size, l.cfilters, l.scales);
+ binarize_weights(l.weights, l.n, l.c*l.size*l.size, l.binary_weights);
+ binarize_weights2(l.weights, l.n, l.c*l.size*l.size, l.cweights, l.scales);
swap_binary(&l);
}
*/
@@ -415,7 +415,7 @@
int k = l.size*l.size*l.c;
int n = out_h*out_w;
- char *a = l.cfilters;
+ char *a = l.cweights;
float *b = state.workspace;
float *c = l.output;
@@ -434,7 +434,7 @@
*/
if(l.xnor){
- binarize_filters(l.filters, l.n, l.c*l.size*l.size, l.binary_filters);
+ binarize_weights(l.weights, l.n, l.c*l.size*l.size, l.binary_weights);
swap_binary(&l);
binarize_cpu(state.input, l.c*l.h*l.w*l.batch, l.binary_input);
state.input = l.binary_input;
@@ -449,7 +449,7 @@
printf("xnor\n");
} else {
- float *a = l.filters;
+ float *a = l.weights;
float *b = state.workspace;
float *c = l.output;
@@ -485,7 +485,7 @@
for(i = 0; i < l.batch; ++i){
float *a = l.delta + i*m*k;
float *b = state.workspace;
- float *c = l.filter_updates;
+ float *c = l.weight_updates;
float *im = state.input+i*l.c*l.h*l.w;
@@ -494,7 +494,7 @@
gemm(0,1,m,n,k,1,a,k,b,k,1,c,n);
if(state.delta){
- a = l.filters;
+ a = l.weights;
b = l.delta + i*m*k;
c = state.workspace;
@@ -511,36 +511,36 @@
axpy_cpu(l.n, learning_rate/batch, l.bias_updates, 1, l.biases, 1);
scal_cpu(l.n, momentum, l.bias_updates, 1);
- axpy_cpu(size, -decay*batch, l.filters, 1, l.filter_updates, 1);
- axpy_cpu(size, learning_rate/batch, l.filter_updates, 1, l.filters, 1);
- scal_cpu(size, momentum, l.filter_updates, 1);
+ axpy_cpu(size, -decay*batch, l.weights, 1, l.weight_updates, 1);
+ axpy_cpu(size, learning_rate/batch, l.weight_updates, 1, l.weights, 1);
+ scal_cpu(size, momentum, l.weight_updates, 1);
}
-image get_convolutional_filter(convolutional_layer l, int i)
+image get_convolutional_weight(convolutional_layer l, int i)
{
int h = l.size;
int w = l.size;
int c = l.c;
- return float_to_image(w,h,c,l.filters+i*h*w*c);
+ return float_to_image(w,h,c,l.weights+i*h*w*c);
}
-void rgbgr_filters(convolutional_layer l)
+void rgbgr_weights(convolutional_layer l)
{
int i;
for(i = 0; i < l.n; ++i){
- image im = get_convolutional_filter(l, i);
+ image im = get_convolutional_weight(l, i);
if (im.c == 3) {
rgbgr_image(im);
}
}
}
-void rescale_filters(convolutional_layer l, float scale, float trans)
+void rescale_weights(convolutional_layer l, float scale, float trans)
{
int i;
for(i = 0; i < l.n; ++i){
- image im = get_convolutional_filter(l, i);
+ image im = get_convolutional_weight(l, i);
if (im.c == 3) {
scale_image(im, scale);
float sum = sum_array(im.data, im.w*im.h*im.c);
@@ -549,21 +549,21 @@
}
}
-image *get_filters(convolutional_layer l)
+image *get_weights(convolutional_layer l)
{
- image *filters = calloc(l.n, sizeof(image));
+ image *weights = calloc(l.n, sizeof(image));
int i;
for(i = 0; i < l.n; ++i){
- filters[i] = copy_image(get_convolutional_filter(l, i));
- //normalize_image(filters[i]);
+ weights[i] = copy_image(get_convolutional_weight(l, i));
+ //normalize_image(weights[i]);
}
- return filters;
+ return weights;
}
-image *visualize_convolutional_layer(convolutional_layer l, char *window, image *prev_filters)
+image *visualize_convolutional_layer(convolutional_layer l, char *window, image *prev_weights)
{
- image *single_filters = get_filters(l);
- show_images(single_filters, l.n, window);
+ image *single_weights = get_weights(l);
+ show_images(single_weights, l.n, window);
image delta = get_convolutional_image(l);
image dc = collapse_image_layers(delta, 1);
@@ -572,6 +572,6 @@
//show_image(dc, buff);
//save_image(dc, buff);
free_image(dc);
- return single_filters;
+ return single_weights;
}
--
Gitblit v1.10.0