From 796e464d43274415603e6f27a4bb81b6c1ef8cf3 Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Fri, 24 Jan 2014 22:49:02 +0000
Subject: [PATCH] Connected layers use matrices
---
src/image.c | 1
dog.jpg | 0
src/mini_blas.c | 80 ++++++++++++++++
Makefile | 2
src/connected_layer.c | 109 +++++++++++++++------
src/mini_blas.h | 10 ++
src/tests.c | 84 ++++++++++++++++
7 files changed, 249 insertions(+), 37 deletions(-)
diff --git a/Makefile b/Makefile
index 44c930f..415c522 100644
--- a/Makefile
+++ b/Makefile
@@ -11,7 +11,7 @@
LDFLAGS=`pkg-config --libs opencv` -lm
VPATH=./src/
-OBJ=network.o image.o tests.o convolutional_layer.o connected_layer.o maxpool_layer.o activations.o list.o option_list.o parser.o utils.o data.o matrix.o softmax_layer.o
+OBJ=network.o image.o tests.o convolutional_layer.o connected_layer.o maxpool_layer.o activations.o list.o option_list.o parser.o utils.o data.o matrix.o softmax_layer.o mini_blas.o
all: cnn
diff --git a/dog.jpg b/dog.jpg
index 16d05ab..3b9f7ab 100644
--- a/dog.jpg
+++ b/dog.jpg
Binary files differ
diff --git a/src/connected_layer.c b/src/connected_layer.c
index 0344c71..6871b2e 100644
--- a/src/connected_layer.c
+++ b/src/connected_layer.c
@@ -1,5 +1,6 @@
#include "connected_layer.h"
#include "utils.h"
+#include "mini_blas.h"
#include <math.h>
#include <stdio.h>
@@ -35,55 +36,99 @@
return layer;
}
+void update_connected_layer(connected_layer layer, double step, double momentum, double decay)
+{
+ int i;
+ for(i = 0; i < layer.outputs; ++i){
+ layer.bias_momentum[i] = step*(layer.bias_updates[i]) + momentum*layer.bias_momentum[i];
+ layer.biases[i] += layer.bias_momentum[i];
+ }
+ for(i = 0; i < layer.outputs*layer.inputs; ++i){
+ layer.weight_momentum[i] = step*(layer.weight_updates[i] - decay*layer.weights[i]) + momentum*layer.weight_momentum[i];
+ layer.weights[i] += layer.weight_momentum[i];
+ }
+ memset(layer.bias_updates, 0, layer.outputs*sizeof(double));
+ memset(layer.weight_updates, 0, layer.outputs*layer.inputs*sizeof(double));
+}
+
void forward_connected_layer(connected_layer layer, double *input)
{
- int i, j;
+ int i;
+ memcpy(layer.output, layer.biases, layer.outputs*sizeof(double));
+ int m = 1;
+ int k = layer.inputs;
+ int n = layer.outputs;
+ double *a = input;
+ double *b = layer.weights;
+ double *c = layer.output;
+ gemm(0,0,m,n,k,1,a,k,b,n,1,c,n);
for(i = 0; i < layer.outputs; ++i){
- layer.output[i] = layer.biases[i];
- for(j = 0; j < layer.inputs; ++j){
- layer.output[i] += input[j]*layer.weights[i*layer.inputs + j];
- }
layer.output[i] = activate(layer.output[i], layer.activation);
}
}
void learn_connected_layer(connected_layer layer, double *input)
{
- int i, j;
+ int i;
for(i = 0; i < layer.outputs; ++i){
layer.delta[i] *= gradient(layer.output[i], layer.activation);
layer.bias_updates[i] += layer.delta[i];
- for(j = 0; j < layer.inputs; ++j){
- layer.weight_updates[i*layer.inputs + j] += layer.delta[i]*input[j];
- }
}
-}
-
-void update_connected_layer(connected_layer layer, double step, double momentum, double decay)
-{
- int i,j;
- for(i = 0; i < layer.outputs; ++i){
- layer.bias_momentum[i] = step*(layer.bias_updates[i]) + momentum*layer.bias_momentum[i];
- layer.biases[i] += layer.bias_momentum[i];
- for(j = 0; j < layer.inputs; ++j){
- int index = i*layer.inputs+j;
- layer.weight_momentum[index] = step*(layer.weight_updates[index] - decay*layer.weights[index]) + momentum*layer.weight_momentum[index];
- layer.weights[index] += layer.weight_momentum[index];
- }
- }
- memset(layer.bias_updates, 0, layer.outputs*sizeof(double));
- memset(layer.weight_updates, 0, layer.outputs*layer.inputs*sizeof(double));
+ int m = layer.inputs;
+ int k = 1;
+ int n = layer.outputs;
+ double *a = input;
+ double *b = layer.delta;
+ double *c = layer.weight_updates;
+ gemm(0,0,m,n,k,1,a,k,b,n,1,c,n);
}
void backward_connected_layer(connected_layer layer, double *input, double *delta)
{
- int i, j;
+ memset(delta, 0, layer.inputs*sizeof(double));
- for(j = 0; j < layer.inputs; ++j){
- delta[j] = 0;
- for(i = 0; i < layer.outputs; ++i){
- delta[j] += layer.delta[i]*layer.weights[i*layer.inputs + j];
- }
- }
+ int m = layer.inputs;
+ int k = layer.outputs;
+ int n = 1;
+
+ double *a = layer.weights;
+ double *b = layer.delta;
+ double *c = delta;
+
+ gemm(0,0,m,n,k,1,a,k,b,n,1,c,n);
}
+/*
+ void forward_connected_layer(connected_layer layer, double *input)
+ {
+ int i, j;
+ for(i = 0; i < layer.outputs; ++i){
+ layer.output[i] = layer.biases[i];
+ for(j = 0; j < layer.inputs; ++j){
+ layer.output[i] += input[j]*layer.weights[i*layer.inputs + j];
+ }
+ layer.output[i] = activate(layer.output[i], layer.activation);
+ }
+ }
+ void learn_connected_layer(connected_layer layer, double *input)
+ {
+ int i, j;
+ for(i = 0; i < layer.outputs; ++i){
+ layer.delta[i] *= gradient(layer.output[i], layer.activation);
+ layer.bias_updates[i] += layer.delta[i];
+ for(j = 0; j < layer.inputs; ++j){
+ layer.weight_updates[i*layer.inputs + j] += layer.delta[i]*input[j];
+ }
+ }
+ }
+ void backward_connected_layer(connected_layer layer, double *input, double *delta)
+ {
+ int i, j;
+ for(j = 0; j < layer.inputs; ++j){
+ delta[j] = 0;
+ for(i = 0; i < layer.outputs; ++i){
+ delta[j] += layer.delta[i]*layer.weights[i*layer.inputs + j];
+ }
+ }
+ }
+ */
diff --git a/src/image.c b/src/image.c
index 74b8832..df8e1b8 100644
--- a/src/image.c
+++ b/src/image.c
@@ -207,6 +207,7 @@
int i;
for(i = 0; i < h*w*c; ++i){
out.data[i] = rand_normal();
+ //out.data[i] = rand()%3;
}
return out;
}
diff --git a/src/mini_blas.c b/src/mini_blas.c
new file mode 100644
index 0000000..b15ba8e
--- /dev/null
+++ b/src/mini_blas.c
@@ -0,0 +1,80 @@
+
+void gemm(int TA, int TB, int M, int N, int K, double ALPHA,
+ double *A, int lda,
+ double *B, int ldb,
+ double BETA,
+ double *C, int ldc)
+{
+ // Assume TA = TB = 0, beta = 1 LULZ
+ int i,j,k;
+ for(i = 0; i < M; ++i){
+ for(k = 0; k < K; ++k){
+ for(j = 0; j < N; ++j){
+ C[i*ldc+j] += ALPHA*A[i*lda+k]*B[k*ldb+j];
+ }
+ }
+ }
+}
+
+void im2row(double *image, int h, int w, int c, int size, int stride, double *matrix)
+{
+ int i;
+ int mc = c;
+ int mw = (size*size);
+ int mh = ((h-size)/stride+1)*((w-size)/stride+1);
+ int msize = mc*mw*mh;
+ for(i = 0; i < msize; ++i){
+ int channel = i/(mh*mw);
+ int block = (i%(mh*mw))/mw;
+ int position = i%mw;
+ int block_h = block/((w-size)/stride+1);
+ int block_w = block%((w-size)/stride+1);
+ int ph, pw, pc;
+ ph = position/size+block_h;
+ pw = position%size+block_w;
+ pc = channel;
+ matrix[i] = image[pc*h*w+ph*w+pw];
+ }
+}
+void im2col(double *image, int h, int w, int c, int size, int stride, double *matrix)
+{
+ int b,p;
+ int blocks = ((h-size)/stride+1)*((w-size)/stride+1);
+ int pixels = (size*size*c);
+ for(b = 0; b < blocks; ++b){
+ int block_h = b/((w-size)/stride+1);
+ int block_w = b%((w-size)/stride+1);
+ for(p = 0; p < pixels; ++p){
+ int ph, pw, pc;
+ int position = p%(size*size);
+ pc = p/(size*size);
+ ph = position/size+block_h;
+ pw = position%size+block_w;
+ matrix[b+p*blocks] = image[pc*h*w+ph*w+pw];
+ }
+ }
+}
+
+//From Berkeley Vision's Caffe!
+void im2col_cpu(double* data_im, const int channels,
+ const int height, const int width, const int ksize, const int stride,
+ double* data_col)
+ {
+ int c,h,w;
+ int height_col = (height - ksize) / stride + 1;
+ int width_col = (width - ksize) / stride + 1;
+ int channels_col = channels * ksize * ksize;
+ for ( c = 0; c < channels_col; ++c) {
+ int w_offset = c % ksize;
+ int h_offset = (c / ksize) % ksize;
+ int c_im = c / ksize / ksize;
+ for ( h = 0; h < height_col; ++h) {
+ for ( w = 0; w < width_col; ++w) {
+ data_col[(c * height_col + h) * width_col + w] =
+ data_im[(c_im * height + h * stride + h_offset) * width
+ + w * stride + w_offset];
+ }
+ }
+ }
+}
+
diff --git a/src/mini_blas.h b/src/mini_blas.h
new file mode 100644
index 0000000..cdf3a86
--- /dev/null
+++ b/src/mini_blas.h
@@ -0,0 +1,10 @@
+void gemm(int TA, int TB, int M, int N, int K, double ALPHA,
+ double *A, int lda,
+ double *B, int ldb,
+ double BETA,
+ double *C, int ldc);
+void im2row(double *image, int h, int w, int c, int size, int stride, double *matrix);
+void im2col(double *image, int h, int w, int c, int size, int stride, double *matrix);
+void im2col_cpu(double* data_im, const int channels,
+ const int height, const int width, const int ksize, const int stride,
+ double* data_col);
diff --git a/src/tests.c b/src/tests.c
index 2a50bac..ce131e7 100644
--- a/src/tests.c
+++ b/src/tests.c
@@ -7,6 +7,7 @@
#include "data.h"
#include "matrix.h"
#include "utils.h"
+#include "mini_blas.h"
#include <time.h>
#include <stdlib.h>
@@ -28,6 +29,35 @@
show_image_layers(edge, "Test Convolve");
}
+void test_convolve_matrix()
+{
+ image dog = load_image("dog.jpg");
+ printf("dog channels %d\n", dog.c);
+
+ int size = 11;
+ int stride = 1;
+ int n = 40;
+ double *filters = make_random_image(size, size, dog.c*n).data;
+
+ int mw = ((dog.h-size)/stride+1)*((dog.w-size)/stride+1);
+ int mh = (size*size*dog.c);
+ double *matrix = calloc(mh*mw, sizeof(double));
+
+ image edge = make_image((dog.h-size)/stride+1, (dog.w-size)/stride+1, n);
+
+
+ int i;
+ clock_t start = clock(), end;
+ for(i = 0; i < 1000; ++i){
+ im2col_cpu(dog.data, dog.c, dog.h, dog.w, size, stride, matrix);
+ gemm(0,0,n,mw,mh,1,filters,mh,matrix,mw,1,edge.data,mw);
+ }
+ end = clock();
+ printf("Convolutions: %lf seconds\n", (double)(end-start)/CLOCKS_PER_SEC);
+ show_image_layers(edge, "Test Convolve");
+ cvWaitKey(0);
+}
+
void test_color()
{
image dog = load_image("test_color.png");
@@ -199,7 +229,7 @@
{
srand(444444);
srand(888888);
- network net = parse_network_cfg("nist.cfg");
+ network net = parse_network_cfg("nist_basic.cfg");
data train = load_categorical_data_csv("mnist/mnist_train.csv", 0, 10);
data test = load_categorical_data_csv("mnist/mnist_test.csv",0,10);
normalize_data_rows(train);
@@ -216,8 +246,8 @@
end = clock();
printf("Time: %lf seconds\n", (double)(end-start)/CLOCKS_PER_SEC);
start=end;
- visualize_network(net);
- cvWaitKey(100);
+ //visualize_network(net);
+ //cvWaitKey(100);
//lr /= 2;
if(count%5 == 0 && 0){
double train_acc = network_accuracy(net, train);
@@ -336,13 +366,59 @@
printf("%d, %d, %d\n", train.X.rows, split[0].X.rows, split[1].X.rows);
}
+double *random_matrix(int rows, int cols)
+{
+ int i, j;
+ double *m = calloc(rows*cols, sizeof(double));
+ for(i = 0; i < rows; ++i){
+ for(j = 0; j < cols; ++j){
+ m[i*cols+j] = (double)rand()/RAND_MAX;
+ }
+ }
+ return m;
+}
+
+void test_blas()
+{
+ int m = 6025, n = 20, k = 11*11*3;
+ double *a = random_matrix(m,k);
+ double *b = random_matrix(k,n);
+ double *c = random_matrix(m,n);
+ int i;
+ for(i = 0; i<1000; ++i){
+ gemm(0,0,m,n,k,1,a,k,b,n,1,c,n);
+ }
+}
+
+void test_im2row()
+{
+ int h = 20;
+ int w = 20;
+ int c = 3;
+ int stride = 1;
+ int size = 11;
+ image test = make_random_image(h,w,c);
+ int mc = 1;
+ int mw = ((h-size)/stride+1)*((w-size)/stride+1);
+ int mh = (size*size*c);
+ int msize = mc*mw*mh;
+ double *matrix = calloc(msize, sizeof(double));
+ int i;
+ for(i = 0; i < 1000; ++i){
+ im2col_cpu(test.data, c, h, w, size, stride, matrix);
+ image render = double_to_image(mh, mw, mc, matrix);
+ }
+}
int main()
{
+ //test_blas();
+ test_convolve_matrix();
+// test_im2row();
//test_kernel_update();
//test_split();
//test_ensemble();
- test_nist();
+ //test_nist();
//test_full();
//test_random_preprocess();
//test_random_classify();
--
Gitblit v1.10.0