From 1d53b6414e0cd81043d7c76aa89f4f97da5e479f Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Thu, 23 Jan 2014 19:24:37 +0000
Subject: [PATCH] Stable on MNIST, about to change a lot

---
 src/network.c |   18 ++++++++++++++++++
 full.cfg      |    4 ----
 src/network.h |    1 +
 nist.cfg      |    4 ++--
 src/tests.c   |   17 ++++++++++++-----
 5 files changed, 33 insertions(+), 11 deletions(-)

diff --git a/full.cfg b/full.cfg
index a18da17..78e938f 100644
--- a/full.cfg
+++ b/full.cfg
@@ -11,10 +11,6 @@
 stride=2
 
 [conn]
-output = 100
-activation=ramp
-
-[conn]
 output = 2
 activation=ramp
 
diff --git a/nist.cfg b/nist.cfg
index 5b0541c..46e3223 100644
--- a/nist.cfg
+++ b/nist.cfg
@@ -2,7 +2,7 @@
 width=28
 height=28
 channels=1
-filters=5
+filters=20
 size=5
 stride=1
 activation=ramp
@@ -20,7 +20,7 @@
 stride=2
 
 [conn]
-output = 100
+output = 500
 activation=ramp
 
 [conn]
diff --git a/src/network.c b/src/network.c
index 10ad110..07ac621 100644
--- a/src/network.c
+++ b/src/network.c
@@ -187,6 +187,24 @@
     }
     return (double)correct/n;
 }
+double train_network_batch(network net, data d, int n, double step, double momentum,double decay)
+{
+    int i;
+    int correct = 0;
+    for(i = 0; i < n; ++i){
+        int index = rand()%d.X.rows;
+        double *x = d.X.vals[index];
+        double *y = d.y.vals[index];
+        forward_network(net, x);
+        int class = get_predicted_class_network(net);
+        backward_network(net, x, y);
+        correct += (y[class]?1:0);
+    }
+    update_network(net, step, momentum, decay);
+    return (double)correct/n;
+
+}
+
 
 void train_network(network net, data d, double step, double momentum, double decay)
 {
diff --git a/src/network.h b/src/network.h
index 2ffc76b..975c3dd 100644
--- a/src/network.h
+++ b/src/network.h
@@ -25,6 +25,7 @@
 void backward_network(network net, double *input, double *truth);
 void update_network(network net, double step, double momentum, double decay);
 double train_network_sgd(network net, data d, int n, double step, double momentum,double decay);
+double train_network_batch(network net, data d, int n, double step, double momentum,double decay);
 void train_network(network net, data d, double step, double momentum, double decay);
 matrix network_predict_data(network net, data test);
 double network_accuracy(network net, data d);
diff --git a/src/tests.c b/src/tests.c
index 4638645..2a50bac 100644
--- a/src/tests.c
+++ b/src/tests.c
@@ -184,9 +184,12 @@
     srand(0);
     int i = 0;
     char *labels[] = {"cat","dog"};
+    double lr = .00001;
+    double momentum = .9;
+    double decay = 0.01;
     while(i++ < 1000 || 1){
         data train = load_data_image_pathfile_random("train_paths.txt", 1000, labels, 2);
-        train_network(net, train, .0005, 0, 0);
+        train_network(net, train, lr, momentum, decay);
         free_data(train);
         printf("Round %d\n", i);
     }
@@ -206,9 +209,13 @@
     double lr = .0005;
     double momentum = .9;
     double decay = 0.01;
+    clock_t start = clock(), end;
     while(++count <= 1000){
-        double acc = train_network_sgd(net, train, 1000, lr, momentum, decay);
-        printf("Training Accuracy: %lf, Params: %f %f %f\n", acc, lr, momentum, decay);
+        double acc = train_network_sgd(net, train, 6400, lr, momentum, decay);
+        printf("%5d Training Loss: %lf, Params: %f %f %f, ",count*100, 1.-acc, lr, momentum, decay);
+        end = clock();
+        printf("Time: %lf seconds\n", (double)(end-start)/CLOCKS_PER_SEC);
+        start=end;
         visualize_network(net);
         cvWaitKey(100);
         //lr /= 2; 
@@ -334,8 +341,8 @@
 {
     //test_kernel_update();
     //test_split();
-    test_ensemble();
-    //test_nist();
+    //test_ensemble();
+    test_nist();
     //test_full();
     //test_random_preprocess();
     //test_random_classify();

--
Gitblit v1.10.0