From 00d483697a6e395ef6776320cd1e52a04f4367be Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Wed, 30 Apr 2014 23:17:40 +0000
Subject: [PATCH] Small updates

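Factor the per-type activation formulas into helper functions and add
activate_array()/gradient_array() to process whole buffers. Make
cpu_gemm() and the OpenCL kernel honor BETA, computing
C = ALPHA*A*B + BETA*C, and call gemm() with BETA = 0 from the layers
instead of memset-ing their outputs first. Route gemm() through the GPU
path, build with GPU=1 and debug flags (-O0 -g), delete dead
commented-out code, and point test_nist() at cfg/nist.cfg and the
data/mnist/ CSVs.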
---
 src/network.c             |    1 
 src/gemm.cl               |   29 -----
 src/mini_blas.c           |    2 
 src/cpu_gemm.c            |    7 +
 Makefile                  |    4 
 src/convolutional_layer.c |   86 +----------------
 src/activations.h         |    2 
 src/connected_layer.c     |   60 -----------
 src/mini_blas.h           |    1 
 src/activations.c         |   34 +++++-
 src/convolutional_layer.h |    4 
 src/tests.c               |   12 +-
 12 files changed, 53 insertions(+), 189 deletions(-)

diff --git a/Makefile b/Makefile
index 2c47fdf..3b01ab2 100644
--- a/Makefile
+++ b/Makefile
@@ -1,5 +1,5 @@
 CC=gcc
-GPU=0
+GPU=1
 COMMON=-Wall `pkg-config --cflags opencv` -I/usr/local/cuda/include/
 UNAME = $(shell uname)
 OPTS=-O3
@@ -15,7 +15,7 @@
 endif
 endif
 CFLAGS= $(COMMON) $(OPTS)
-#CFLAGS= $(COMMON) -O0 -g 
+CFLAGS= $(COMMON) -O0 -g 
 LDFLAGS+=`pkg-config --libs opencv` -lm
 VPATH=./src/
 EXEC=cnn
diff --git a/src/activations.c b/src/activations.c
index c81d6aa..24868a3 100644
--- a/src/activations.c
+++ b/src/activations.c
@@ -34,21 +34,37 @@
     return RELU;
 }
 
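+/* Scalar activation functions, one per ACTIVATION type. */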
+float linear_activate(float x){return x;}
+float sigmoid_activate(float x){return 1./(1. + exp(-x));}
+float relu_activate(float x){return x*(x>0);}
+float ramp_activate(float x){return x*(x>0)+.1*x;}
+float tanh_activate(float x){return (exp(2*x)-1)/(exp(2*x)+1);}
+
 float activate(float x, ACTIVATION a){
     switch(a){
         case LINEAR:
-            return x;
+            return linear_activate(x);
         case SIGMOID:
-            return 1./(1.+exp(-x));
+            return sigmoid_activate(x);
         case RELU:
-            return x*(x>0);
+            return relu_activate(x);
         case RAMP:
-            return x*(x>0) + .1*x;
+            return ramp_activate(x);
         case TANH:
-            return (exp(2*x)-1)/(exp(2*x)+1);
+            return tanh_activate(x);
     }
     return 0;
 }
+
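+/* Apply activation a to each of the n entries of x, in place. */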
+void activate_array(float *x, const int n, const ACTIVATION a)
+{
+    int i;
+    for(i = 0; i < n; ++i){
+        x[i] = activate(x[i], a);
+    }
+}
+
 float gradient(float x, ACTIVATION a){
     switch(a){
         case LINEAR:
@@ -65,3 +81,11 @@
     return 0;
 }
 
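+/* Scale each delta[i] by the gradient of the activation at x[i]. */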
+void gradient_array(const float *x, const int n, const ACTIVATION a, float *delta)
+{
+    int i;
+    for(i = 0; i < n; ++i){
+        delta[i] *= gradient(x[i], a);
+    }
+}
+
diff --git a/src/activations.h b/src/activations.h
index 9474121..68d2222 100644
--- a/src/activations.h
+++ b/src/activations.h
@@ -10,6 +10,8 @@
 char *get_activation_string(ACTIVATION a);
 float activate(float x, ACTIVATION a);
 float gradient(float x, ACTIVATION a);
+void gradient_array(const float *x, const int n, const ACTIVATION a, float *delta);
+void activate_array(float *x, const int n, const ACTIVATION a);
 
 #endif
 
diff --git a/src/connected_layer.c b/src/connected_layer.c
index 16a39be..792f20b 100644
--- a/src/connected_layer.c
+++ b/src/connected_layer.c
@@ -39,27 +39,6 @@
     return layer;
 }
 
-/*
-void update_connected_layer(connected_layer layer, float step, float momentum, float decay)
-{
-    int i;
-    for(i = 0; i < layer.outputs; ++i){
-        float delta = layer.bias_updates[i];
-        layer.bias_adapt[i] += delta*delta;
-        layer.bias_momentum[i] = step/sqrt(layer.bias_adapt[i])*(layer.bias_updates[i]) + momentum*layer.bias_momentum[i];
-        layer.biases[i] += layer.bias_momentum[i];
-    }
-    for(i = 0; i < layer.outputs*layer.inputs; ++i){
-        float delta = layer.weight_updates[i];
-        layer.weight_adapt[i] += delta*delta;
-        layer.weight_momentum[i] = step/sqrt(layer.weight_adapt[i])*(layer.weight_updates[i] - decay*layer.weights[i]) + momentum*layer.weight_momentum[i];
-        layer.weights[i] += layer.weight_momentum[i];
-    }
-    memset(layer.bias_updates, 0, layer.outputs*sizeof(float));
-    memset(layer.weight_updates, 0, layer.outputs*layer.inputs*sizeof(float));
-}
-*/
-
 void update_connected_layer(connected_layer layer, float step, float momentum, float decay)
 {
     int i;
@@ -89,7 +68,6 @@
     for(i = 0; i < layer.outputs*layer.batch; ++i){
         layer.output[i] = activate(layer.output[i], layer.activation);
     }
-    //for(i = 0; i < layer.outputs; ++i) if(i%(layer.outputs/10+1)==0) printf("%f, ", layer.output[i]); printf("\n");
 }
 
 void learn_connected_layer(connected_layer layer, float *input)
@@ -110,8 +88,6 @@
 
 void backward_connected_layer(connected_layer layer, float *input, float *delta)
 {
-    memset(delta, 0, layer.inputs*sizeof(float));
-
     int m = layer.inputs;
     int k = layer.outputs;
     int n = layer.batch;
@@ -120,40 +96,6 @@
     float *b = layer.delta;
     float *c = delta;
 
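+    // BETA = 0: gemm overwrites delta, so the memset is no longer needed.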
-    gemm(0,0,m,n,k,1,a,k,b,n,1,c,n);
+    gemm(0,0,m,n,k,1,a,k,b,n,0,c,n);
 }
-/*
-   void forward_connected_layer(connected_layer layer, float *input)
-   {
-   int i, j;
-   for(i = 0; i < layer.outputs; ++i){
-   layer.output[i] = layer.biases[i];
-   for(j = 0; j < layer.inputs; ++j){
-   layer.output[i] += input[j]*layer.weights[i*layer.inputs + j];
-   }
-   layer.output[i] = activate(layer.output[i], layer.activation);
-   }
-   }
-   void learn_connected_layer(connected_layer layer, float *input)
-   {
-   int i, j;
-   for(i = 0; i < layer.outputs; ++i){
-   layer.delta[i] *= gradient(layer.output[i], layer.activation);
-   layer.bias_updates[i] += layer.delta[i];
-   for(j = 0; j < layer.inputs; ++j){
-   layer.weight_updates[i*layer.inputs + j] += layer.delta[i]*input[j];
-   }
-   }
-   }
-   void backward_connected_layer(connected_layer layer, float *input, float *delta)
-   {
-   int i, j;
 
-   for(j = 0; j < layer.inputs; ++j){
-   delta[j] = 0;
-   for(i = 0; i < layer.outputs; ++i){
-   delta[j] += layer.delta[i]*layer.weights[i*layer.inputs + j];
-   }
-   }
-   }
- */
diff --git a/src/convolutional_layer.c b/src/convolutional_layer.c
index 6916eeb..45bb54a 100644
--- a/src/convolutional_layer.c
+++ b/src/convolutional_layer.c
@@ -96,33 +96,14 @@
             convolutional_out_width(layer)*
             layer.batch;
 
-    memset(layer.output, 0, m*n*sizeof(float));
-
     float *a = layer.filters;
     float *b = layer.col_image;
     float *c = layer.output;
     for(i = 0; i < layer.batch; ++i){
         im2col_cpu(in+i*(n/layer.batch),  layer.c,  layer.h,  layer.w,  layer.size,  layer.stride, b+i*(n/layer.batch));
     }
-    gemm(0,0,m,n,k,1,a,k,b,n,1,c,n);
-
-    for(i = 0; i < m*n; ++i){
-        layer.output[i] = activate(layer.output[i], layer.activation);
-    }
-    //for(i = 0; i < m*n; ++i) if(i%(m*n/10+1)==0) printf("%f, ", layer.output[i]); printf("\n");
-
-}
-
-void gradient_delta_convolutional_layer(convolutional_layer layer)
-{
-    int i;
-    int size = convolutional_out_height(layer)*
-                convolutional_out_width(layer)*
-                layer.n*
-                layer.batch;
-    for(i = 0; i < size; ++i){
-        layer.delta[i] *= gradient(layer.output[i], layer.activation);
-    }
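+    // BETA = 0: gemm overwrites layer.output directly; no memset needed.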
+    gemm(0,0,m,n,k,1,a,k,b,n,0,c,n);
+    activate_array(layer.output, m*n, layer.activation);
 }
 
 void learn_bias_convolutional_layer(convolutional_layer layer)
@@ -143,13 +124,13 @@
 
 void learn_convolutional_layer(convolutional_layer layer)
 {
-    gradient_delta_convolutional_layer(layer);
-    learn_bias_convolutional_layer(layer);
     int m = layer.n;
     int n = layer.size*layer.size*layer.c;
     int k = convolutional_out_height(layer)*
             convolutional_out_width(layer)*
             layer.batch;
+    gradient_array(layer.output, m*k, layer.activation, layer.delta);
+    learn_bias_convolutional_layer(layer);
 
     float *a = layer.delta;
     float *b = layer.col_image;
@@ -171,9 +152,7 @@
     float *b = layer.delta;
     float *c = layer.col_image;
 
-
-    memset(c, 0, m*n*sizeof(float));
-    gemm(1,0,m,n,k,1,a,m,b,n,1,c,n);
+    gemm(1,0,m,n,k,1,a,m,b,n,0,c,n);
 
     memset(delta, 0, layer.batch*layer.h*layer.w*layer.c*sizeof(float));
     for(i = 0; i < layer.batch; ++i){
@@ -194,61 +173,6 @@
         layer.filter_updates[i] *= momentum;
     }
 }
-/*
-
-void backward_convolutional_layer2(convolutional_layer layer, float *input, float *delta)
-{
-    image in_delta = float_to_image(layer.h, layer.w, layer.c, delta);
-    image out_delta = get_convolutional_delta(layer);
-    int i,j;
-    for(i = 0; i < layer.n; ++i){
-        rotate_image(layer.kernels[i]);
-    }
-
-    zero_image(in_delta);
-    upsample_image(out_delta, layer.stride, layer.upsampled);
-    for(j = 0; j < in_delta.c; ++j){
-        for(i = 0; i < layer.n; ++i){
-            two_d_convolve(layer.upsampled, i, layer.kernels[i], j, 1, in_delta, j, layer.edge);
-        }
-    }
-
-    for(i = 0; i < layer.n; ++i){
-        rotate_image(layer.kernels[i]);
-    }
-}
-
-
-void learn_convolutional_layer(convolutional_layer layer, float *input)
-{
-    int i;
-    image in_image = float_to_image(layer.h, layer.w, layer.c, input);
-    image out_delta = get_convolutional_delta(layer);
-    gradient_delta_convolutional_layer(layer);
-    for(i = 0; i < layer.n; ++i){
-        kernel_update(in_image, layer.kernel_updates[i], layer.stride, i, out_delta, layer.edge);
-        layer.bias_updates[i] += avg_image_layer(out_delta, i);
-    }
-}
-
-void update_convolutional_layer(convolutional_layer layer, float step, float momentum, float decay)
-{
-    int i,j;
-    for(i = 0; i < layer.n; ++i){
-        layer.bias_momentum[i] = step*(layer.bias_updates[i]) 
-                                + momentum*layer.bias_momentum[i];
-        layer.biases[i] += layer.bias_momentum[i];
-        layer.bias_updates[i] = 0;
-        int pixels = layer.kernels[i].h*layer.kernels[i].w*layer.kernels[i].c;
-        for(j = 0; j < pixels; ++j){
-            layer.kernel_momentum[i].data[j] = step*(layer.kernel_updates[i].data[j] - decay*layer.kernels[i].data[j]) 
-                                                + momentum*layer.kernel_momentum[i].data[j];
-            layer.kernels[i].data[j] += layer.kernel_momentum[i].data[j];
-        }
-        zero_image(layer.kernel_updates[i]);
-    }
-}
-*/
 
 void test_convolutional_layer()
 {
diff --git a/src/convolutional_layer.h b/src/convolutional_layer.h
index 7404def..ef08976 100644
--- a/src/convolutional_layer.h
+++ b/src/convolutional_layer.h
@@ -34,10 +34,6 @@
 
 void backward_convolutional_layer(convolutional_layer layer, float *delta);
 
-//void backpropagate_convolutional_layer_convolve(image input, convolutional_layer layer);
-//void visualize_convolutional_filters(convolutional_layer layer, char *window);
-//void visualize_convolutional_layer(convolutional_layer layer);
-
 image get_convolutional_image(convolutional_layer layer);
 image get_convolutional_delta(convolutional_layer layer);
 image get_convolutional_filter(convolutional_layer layer, int i);
diff --git a/src/cpu_gemm.c b/src/cpu_gemm.c
index 437b39a..29c9ff3 100644
--- a/src/cpu_gemm.c
+++ b/src/cpu_gemm.c
@@ -74,7 +74,12 @@
         float BETA,
         float *C, int ldc)
 {
-    // Assume beta = 1 LULZ
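+    /* Scale C by BETA up front so BETA != 1 is handled correctly. */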
+    int i, j;
+    for(i = 0; i < M; ++i){
+        for(j = 0; j < N; ++j){
+            C[i*ldc + j] *= BETA;
+        }
+    }
     if(!TA && !TB)
         cpu_gemm_nn( TA,  TB,  M, N, K, ALPHA,A,lda, B, ldb,BETA,C,ldc);
     else if(TA && !TB)
diff --git a/src/gemm.cl b/src/gemm.cl
index 7c868f4..91375a7 100644
--- a/src/gemm.cl
+++ b/src/gemm.cl
@@ -1,5 +1,4 @@
 
-
 __kernel void gemm(int TA, int TB, int M, int N, int K, float ALPHA, 
                     __global float *A, int lda, 
                     __global float *B, int ldb,
@@ -40,33 +39,7 @@
     }
 
     if(row < M && col < N){
-        C[row*ldc+col] = val;
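+        // C = ALPHA*A*B + BETA*C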
+        C[row*ldc+col] = ALPHA*val + BETA*C[row*ldc+col];
     }
 }
 
-/*
-__kernel void gemm_slow(int TA, int TB, int M, int N, int K, float ALPHA, 
-                    __global float *A, int lda, 
-                    __global float *B, int ldb,
-                    float BETA,
-                    __global float *C, int ldc)
-{
-    float val = 0;
-    int row = get_global_id(0);
-    int col = get_global_id(1);
-    int i;
-    for(i = 0; i < K; ++i){
-        float Aval;
-        if(TA) Aval = A[i*lda+row]; 
-        else Aval = A[row*lda+i];
-
-        float Bval;
-        if(TB) Bval = B[col*ldb+i];
-        else Bval = B[col+i*ldb];
-
-        val += Aval*Bval;
-    }
-    C[row*ldc+col] = val;
-}
-
-*/
diff --git a/src/mini_blas.c b/src/mini_blas.c
index 4c7d3d0..70dcb54 100644
--- a/src/mini_blas.c
+++ b/src/mini_blas.c
@@ -24,7 +24,7 @@
         float BETA,
         float *C, int ldc)
 {
-    cpu_gemm( TA,  TB,  M, N, K, ALPHA,A,lda, B, ldb,BETA,C,ldc);
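+    // Use the GPU gemm path (the Makefile now sets GPU=1).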
+    gpu_gemm( TA,  TB,  M, N, K, ALPHA,A,lda, B, ldb,BETA,C,ldc);
 }
 
 void im2row(float *image, int h, int w, int c, int size, int stride, float *matrix)
diff --git a/src/mini_blas.h b/src/mini_blas.h
index 56e4fa7..31af193 100644
--- a/src/mini_blas.h
+++ b/src/mini_blas.h
@@ -5,6 +5,7 @@
                     float BETA,
                     float *C, int ldc);
 float *random_matrix(int rows, int cols);
+void time_random_matrix(int TA, int TB, int m, int k, int n);
 void im2row(float *image, int h, int w, int c, int size, int stride, float *matrix);
 void im2col(float *image, int h, int w, int c, int size, int stride, float *matrix);
 void im2col_cpu(float* data_im, const int channels,
diff --git a/src/network.c b/src/network.c
index 7d4b1fa..a77a28e 100644
--- a/src/network.c
+++ b/src/network.c
@@ -6,7 +6,6 @@
 
 #include "connected_layer.h"
 #include "convolutional_layer.h"
-//#include "old_conv.h"
 #include "maxpool_layer.h"
 #include "normalization_layer.h"
 #include "softmax_layer.h"
diff --git a/src/tests.c b/src/tests.c
index 851d781..1c46b24 100644
--- a/src/tests.c
+++ b/src/tests.c
@@ -302,9 +302,9 @@
 {
 	srand(444444);
 	srand(888888);
-	network net = parse_network_cfg("cfg/nist_basic.cfg");
-	data train = load_categorical_data_csv("mnist/mnist_train.csv", 0, 10);
-	data test = load_categorical_data_csv("mnist/mnist_test.csv",0,10);
+	network net = parse_network_cfg("cfg/nist.cfg");
+	data train = load_categorical_data_csv("data/mnist/mnist_train.csv", 0, 10);
+	data test = load_categorical_data_csv("data/mnist/mnist_test.csv",0,10);
 	normalize_data_rows(train);
 	normalize_data_rows(test);
 	//randomize_data(train);
@@ -655,9 +655,7 @@
 	resize_network(net, im.h, im.w, im.c);
 	forward_network(net, im.data);
 
-	image out = get_network_image(net);
 	visualize_network(net);
-	cvWaitKey(1000);
 	cvWaitKey(0);
 }
 
@@ -784,14 +782,14 @@
 	//    test_im2row();
 	//test_split();
 	//test_ensemble();
-	//test_nist();
+	test_nist();
 	//test_cifar10();
 	//test_vince();
 	//test_full();
 	//train_VOC();
 	//features_VOC_image(argv[1], argv[2], argv[3], 0);
 	//features_VOC_image(argv[1], argv[2], argv[3], 1);
-	features_VOC_image_size(argv[1], atoi(argv[2]), atoi(argv[3]));
+	//features_VOC_image_size(argv[1], atoi(argv[2]), atoi(argv[3]));
 	//visualize_imagenet_features("data/assira/train.list");
 	//visualize_imagenet_topk("data/VOC2012.list");
 	//visualize_cat();

--
Gitblit v1.10.0