From c62b4f35aa2c59d7db0fd177affeed14b1ba4bcb Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Thu, 08 Sep 2016 07:04:39 +0000
Subject: [PATCH] adding coco models

---
 src/convolutional_kernels.cu |   34 +++++++++++++++-------------------
 1 files changed, 15 insertions(+), 19 deletions(-)

diff --git a/src/convolutional_kernels.cu b/src/convolutional_kernels.cu
index 1de9dc0..43b3f9a 100644
--- a/src/convolutional_kernels.cu
+++ b/src/convolutional_kernels.cu
@@ -17,7 +17,7 @@
 {
     int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
     if (i >= n) return;
-    binary[i] = (x[i] > 0) ? 1 : -1;
+    binary[i] = (x[i] >= 0) ? 1 : -1;
 }
 
 void binarize_gpu(float *x, int n, float *binary)
@@ -60,6 +60,7 @@
     mean = mean / size;
     for(i = 0; i < size; ++i){
         binary[f*size + i] = (filters[f*size + i] > 0) ? mean : -mean;
+        //binary[f*size + i] = filters[f*size + i];
     }
 }
 
@@ -71,12 +72,6 @@
 
 void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
 {
-    int i;
-    int m = l.n;
-    int k = l.size*l.size*l.c;
-    int n = convolutional_out_height(l)*
-        convolutional_out_width(l);
-
     fill_ongpu(l.outputs*l.batch, 0, l.output_gpu, 1);
     if(l.binary){
         binarize_filters_gpu(l.filters_gpu, l.n, l.c*l.size*l.size, l.binary_filters_gpu);
@@ -86,9 +81,7 @@
     if(l.xnor){
         binarize_filters_gpu(l.filters_gpu, l.n, l.c*l.size*l.size, l.binary_filters_gpu);
         swap_binary(&l);
-        for(i = 0; i < l.batch; ++i){
-            binarize_input_gpu(state.input + i*l.inputs, l.c, l.h*l.w, l.binary_input_gpu + i*l.inputs);
-        }
+        binarize_gpu(state.input, l.c*l.h*l.w*l.batch, l.binary_input_gpu);
         state.input = l.binary_input_gpu;
     }
 
@@ -109,6 +102,10 @@
                 l.output_gpu);
 
 #else
+    int i;
+    int m = l.n;
+    int k = l.size*l.size*l.c;
+    int n = l.out_w*l.out_h;
     for(i = 0; i < l.batch; ++i){
         im2col_ongpu(state.input + i*l.c*l.h*l.w, l.c,  l.h,  l.w,  l.size,  l.stride, l.pad, state.workspace);
         float * a = l.filters_gpu;
@@ -121,23 +118,18 @@
     if (l.batch_normalize) {
         forward_batchnorm_layer_gpu(l, state);
     }
-    add_bias_gpu(l.output_gpu, l.biases_gpu, l.batch, l.n, n);
+    add_bias_gpu(l.output_gpu, l.biases_gpu, l.batch, l.n, l.out_w*l.out_h);
 
-    activate_array_ongpu(l.output_gpu, m*n*l.batch, l.activation);
+    activate_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation);
     //if(l.dot > 0) dot_error_gpu(l);
     if(l.binary || l.xnor) swap_binary(&l);
 }
 
 void backward_convolutional_layer_gpu(convolutional_layer l, network_state state)
 {
-    int m = l.n;
-    int n = l.size*l.size*l.c;
-    int k = convolutional_out_height(l)*
-        convolutional_out_width(l);
+    gradient_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation, l.delta_gpu);
 
-    gradient_array_ongpu(l.output_gpu, m*k*l.batch, l.activation, l.delta_gpu);
-
-    backward_bias_gpu(l.bias_updates_gpu, l.delta_gpu, l.batch, l.n, k);
+    backward_bias_gpu(l.bias_updates_gpu, l.delta_gpu, l.batch, l.n, l.out_w*l.out_h);
 
     if(l.batch_normalize){
         backward_batchnorm_layer_gpu(l, state);
@@ -181,6 +173,10 @@
     }
 
 #else
+    int m = l.n;
+    int n = l.size*l.size*l.c;
+    int k = l.out_w*l.out_h;
+
     int i;
     for(i = 0; i < l.batch; ++i){
         float * a = l.delta_gpu;

--
Gitblit v1.10.0