From c40cdeb4021fc1a638969563972f13c9f5e90d74 Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Fri, 09 Oct 2015 19:50:43 +0000
Subject: [PATCH] lots of comparator stuff

---
 src/convolutional_layer.c |   23 ++++++++++++++++++-----
 1 files changed, 18 insertions(+), 5 deletions(-)

diff --git a/src/convolutional_layer.c b/src/convolutional_layer.c
index 378e23f..f3609ea 100644
--- a/src/convolutional_layer.c
+++ b/src/convolutional_layer.c
@@ -61,7 +61,7 @@
 
     l.biases = calloc(n, sizeof(float));
     l.bias_updates = calloc(n, sizeof(float));
-    //float scale = 1./sqrt(size*size*c);
+    // float scale = 1./sqrt(size*size*c);
     float scale = sqrt(2./(size*size*c));
     for(i = 0; i < c*n*size*size; ++i) l.filters[i] = 2*scale*rand_uniform() - scale;
     for(i = 0; i < n; ++i){
@@ -122,9 +122,9 @@
     cuda_free(l->delta_gpu);
     cuda_free(l->output_gpu);
 
-    l->col_image_gpu = cuda_make_array(0, out_h*out_w*l->size*l->size*l->c);
-    l->delta_gpu = cuda_make_array(0, l->batch*out_h*out_w*l->n);
-    l->output_gpu = cuda_make_array(0, l->batch*out_h*out_w*l->n);
+    l->col_image_gpu = cuda_make_array(l->col_image, out_h*out_w*l->size*l->size*l->c);
+    l->delta_gpu =     cuda_make_array(l->delta, l->batch*out_h*out_w*l->n);
+    l->output_gpu =    cuda_make_array(l->output, l->batch*out_h*out_w*l->n);
     #endif
 }
 
@@ -242,13 +242,26 @@
     }
 }
 
+void rescale_filters(convolutional_layer l, float scale, float trans)
+{
+    int i;
+    for(i = 0; i < l.n; ++i){
+        image im = get_convolutional_filter(l, i);
+        if (im.c == 3) {
+            scale_image(im, scale);
+            float sum = sum_array(im.data, im.w*im.h*im.c);
+            l.biases[i] += sum*trans;
+        }
+    }
+}
+
 image *get_filters(convolutional_layer l)
 {
     image *filters = calloc(l.n, sizeof(image));
     int i;
     for(i = 0; i < l.n; ++i){
         filters[i] = copy_image(get_convolutional_filter(l, i));
-        normalize_image(filters[i]);
+        //normalize_image(filters[i]);
     }
     return filters;
 }

--
Gitblit v1.10.0