From f047cfff99e00e28c02eb59b6d32386c122f9af6 Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Sun, 08 Mar 2015 18:31:12 +0000
Subject: [PATCH] renamed sigmoid to logistic
---
src/convolutional_layer.c | 59 ++++++++++++++++++++++++++++++-----------------------------
1 files changed, 30 insertions(+), 29 deletions(-)
diff --git a/src/convolutional_layer.c b/src/convolutional_layer.c
index 62118e4..7782e3d 100644
--- a/src/convolutional_layer.c
+++ b/src/convolutional_layer.c
@@ -44,7 +44,6 @@
convolutional_layer *make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, int pad, ACTIVATION activation, float learning_rate, float momentum, float decay)
{
int i;
- size = 2*(size/2)+1; //HA! And you thought you'd use an even sized filter...
convolutional_layer *layer = calloc(1, sizeof(convolutional_layer));
layer->learning_rate = learning_rate;
@@ -66,12 +65,9 @@
layer->biases = calloc(n, sizeof(float));
layer->bias_updates = calloc(n, sizeof(float));
float scale = 1./sqrt(size*size*c);
- //scale = .01;
for(i = 0; i < c*n*size*size; ++i) layer->filters[i] = scale*rand_normal();
for(i = 0; i < n; ++i){
- //layer->biases[i] = rand_normal()*scale + scale;
layer->biases[i] = scale;
- //layer->biases[i] = 1;
}
int out_h = convolutional_out_height(*layer);
int out_w = convolutional_out_width(*layer);
@@ -98,11 +94,10 @@
return layer;
}
-void resize_convolutional_layer(convolutional_layer *layer, int h, int w, int c)
+void resize_convolutional_layer(convolutional_layer *layer, int h, int w)
{
layer->h = h;
layer->w = w;
- layer->c = c;
int out_h = convolutional_out_height(*layer);
int out_w = convolutional_out_width(*layer);
@@ -112,29 +107,49 @@
layer->batch*out_h * out_w * layer->n*sizeof(float));
layer->delta = realloc(layer->delta,
layer->batch*out_h * out_w * layer->n*sizeof(float));
+
+ #ifdef GPU
+ cuda_free(layer->col_image_gpu);
+ cuda_free(layer->delta_gpu);
+ cuda_free(layer->output_gpu);
+
+ layer->col_image_gpu = cuda_make_array(layer->col_image, out_h*out_w*layer->size*layer->size*layer->c);
+ layer->delta_gpu = cuda_make_array(layer->delta, layer->batch*out_h*out_w*layer->n);
+ layer->output_gpu = cuda_make_array(layer->output, layer->batch*out_h*out_w*layer->n);
+ #endif
}
-void bias_output(const convolutional_layer layer)
+void bias_output(float *output, float *biases, int batch, int n, int size)
{
int i,j,b;
- int out_h = convolutional_out_height(layer);
- int out_w = convolutional_out_width(layer);
- for(b = 0; b < layer.batch; ++b){
- for(i = 0; i < layer.n; ++i){
- for(j = 0; j < out_h*out_w; ++j){
- layer.output[(b*layer.n + i)*out_h*out_w + j] = layer.biases[i];
+ for(b = 0; b < batch; ++b){
+ for(i = 0; i < n; ++i){
+ for(j = 0; j < size; ++j){
+ output[(b*n + i)*size + j] = biases[i];
}
}
}
}
+void backward_bias(float *bias_updates, float *delta, int batch, int n, int size)
+{
+ float alpha = 1./batch;
+ int i,b;
+ for(b = 0; b < batch; ++b){
+ for(i = 0; i < n; ++i){
+ bias_updates[i] += alpha * sum_array(delta+size*(i+b*n), size);
+ }
+ }
+}
+
+
void forward_convolutional_layer(const convolutional_layer layer, float *in)
{
int out_h = convolutional_out_height(layer);
int out_w = convolutional_out_width(layer);
int i;
- bias_output(layer);
+ bias_output(layer.output, layer.biases, layer.batch, layer.n, out_h*out_w);
int m = layer.n;
int k = layer.size*layer.size*layer.c;
@@ -154,19 +169,6 @@
activate_array(layer.output, m*n*layer.batch, layer.activation);
}
-void learn_bias_convolutional_layer(convolutional_layer layer)
-{
- float alpha = 1./layer.batch;
- int i,b;
- int size = convolutional_out_height(layer)
- *convolutional_out_width(layer);
- for(b = 0; b < layer.batch; ++b){
- for(i = 0; i < layer.n; ++i){
- layer.bias_updates[i] += alpha * sum_array(layer.delta+size*(i+b*layer.n), size);
- }
- }
-}
-
void backward_convolutional_layer(convolutional_layer layer, float *in, float *delta)
{
float alpha = 1./layer.batch;
@@ -177,8 +179,7 @@
convolutional_out_width(layer);
gradient_array(layer.output, m*k*layer.batch, layer.activation, layer.delta);
-
- learn_bias_convolutional_layer(layer);
+ backward_bias(layer.bias_updates, layer.delta, layer.batch, layer.n, k);
if(delta) memset(delta, 0, layer.batch*layer.h*layer.w*layer.c*sizeof(float));
--
Gitblit v1.10.0