From 70d622ea54c55aa5489e71b769a92447a586c879 Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Mon, 14 Jul 2014 05:07:51 +0000
Subject: [PATCH] Added batch to col2im, padding option
---
src/cnn.c | 42 ++++++++++++++++++++++++------------------
1 files changed, 24 insertions(+), 18 deletions(-)
diff --git a/src/tests.c b/src/cnn.c
similarity index 95%
rename from src/tests.c
rename to src/cnn.c
index 8105404..96b9463 100644
--- a/src/tests.c
+++ b/src/cnn.c
@@ -52,7 +52,7 @@
int i;
clock_t start = clock(), end;
for(i = 0; i < 1000; ++i){
- im2col_cpu(dog.data, 1, dog.c, dog.h, dog.w, size, stride, matrix);
+ im2col_cpu(dog.data, 1, dog.c, dog.h, dog.w, size, stride, 0, matrix);
gemm(0,0,n,mw,mh,1,filters,mh,matrix,mw,1,edge.data,mw);
}
end = clock();
@@ -76,7 +76,7 @@
int size = 3;
float eps = .00000001;
image test = make_random_image(5,5, 1);
- convolutional_layer layer = *make_convolutional_layer(1,test.h,test.w,test.c, n, size, stride, RELU);
+ convolutional_layer layer = *make_convolutional_layer(1,test.h,test.w,test.c, n, size, stride, 0, RELU);
image out = get_convolutional_image(layer);
float **jacobian = calloc(test.h*test.w*test.c, sizeof(float));
@@ -301,7 +301,7 @@
void test_nist()
{
srand(444444);
- srand(888888);
+ srand(222222);
network net = parse_network_cfg("cfg/nist.cfg");
data train = load_categorical_data_csv("data/mnist/mnist_train.csv", 0, 10);
data test = load_categorical_data_csv("data/mnist/mnist_test.csv",0,10);
@@ -309,22 +309,26 @@
normalize_data_rows(test);
//randomize_data(train);
int count = 0;
- float lr = .00005;
+ float lr = .000075;
float momentum = .9;
float decay = 0.0001;
decay = 0;
//clock_t start = clock(), end;
- int batch = 10000;
- while(++count <= 10000){
- float loss = train_network_sgd(net, train, batch, lr, momentum, decay);
+ int iters = 100;
+ while(++count <= 10){
+ clock_t start = clock(), end;
+ float loss = train_network_sgd(net, train, iters, lr, momentum, decay);
+ end = clock();
float test_acc = network_accuracy(net, test);
- printf("%3d %5f %5f\n",count, loss, test_acc);
+ printf("%d: %f %f, Time: %lf seconds, LR: %f, Momentum: %f, Decay: %f\n", count, loss, test_acc,(float)(end-start)/CLOCKS_PER_SEC, lr, momentum, decay);
+
//printf("%5d Training Loss: %lf, Params: %f %f %f, ",count*1000, loss, lr, momentum, decay);
//end = clock();
//printf("Time: %lf seconds\n", (float)(end-start)/CLOCKS_PER_SEC);
//start=end;
//lr *= .5;
}
+ //save_network(net, "cfg/nist_basic_trained.cfg");
}
void test_ensemble()
@@ -431,7 +435,7 @@
float *matrix = calloc(msize, sizeof(float));
int i;
for(i = 0; i < 1000; ++i){
- im2col_cpu(test.data, 1, c, h, w, size, stride, matrix);
+ im2col_cpu(test.data, 1, c, h, w, size, stride, 0, matrix);
//image render = float_to_image(mh, mw, mc, matrix);
}
}
@@ -442,34 +446,36 @@
save_network(net, "cfg/voc_imagenet_rev.cfg");
}
-void train_VOC()
+void tune_VOC()
{
network net = parse_network_cfg("cfg/voc_start.cfg");
srand(2222222);
int i = 20;
char *labels[] = {"aeroplane","bicycle","bird","boat","bottle","bus","car","cat","chair","cow","diningtable","dog","horse","motorbike","person","pottedplant","sheep","sofa","train","tvmonitor"};
- float lr = .00001;
+ float lr = .000005;
float momentum = .9;
- float decay = 0.01;
+ float decay = 0.0001;
while(i++ < 1000 || 1){
- data train = load_data_image_pathfile_random("images/VOC2012/val_paths.txt", 1000, labels, 20, 300, 400);
+ data train = load_data_image_pathfile_random("/home/pjreddie/VOC2012/trainval_paths.txt", 10, labels, 20, 256, 256);
- image im = float_to_image(300, 400, 3,train.X.vals[0]);
+ image im = float_to_image(256, 256, 3,train.X.vals[0]);
show_image(im, "input");
visualize_network(net);
cvWaitKey(100);
- normalize_data_rows(train);
+ translate_data_rows(train, -144);
clock_t start = clock(), end;
- float loss = train_network_sgd(net, train, 1000, lr, momentum, decay);
+ float loss = train_network_sgd(net, train, 10, lr, momentum, decay);
end = clock();
printf("%d: %f, Time: %lf seconds, LR: %f, Momentum: %f, Decay: %f\n", i, loss, (float)(end-start)/CLOCKS_PER_SEC, lr, momentum, decay);
free_data(train);
+ /*
if(i%10==0){
char buff[256];
- sprintf(buff, "cfg/voc_clean_ramp_%d.cfg", i);
+ sprintf(buff, "/home/pjreddie/voc_cfg/voc_ramp_%d.cfg", i);
save_network(net, buff);
}
+ */
//lr *= .99;
}
}
@@ -778,7 +784,7 @@
//test_cifar10();
//test_vince();
//test_full();
- //train_VOC();
+ //tune_VOC();
//features_VOC_image(argv[1], argv[2], argv[3], 0);
//features_VOC_image(argv[1], argv[2], argv[3], 1);
//features_VOC_image_size(argv[1], atoi(argv[2]), atoi(argv[3]));
--
Gitblit v1.10.0