From ad9dbfe16495204453b1b7f8593d320751f76ca0 Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Tue, 10 Dec 2013 18:30:42 +0000
Subject: [PATCH] CSE546 submission

---
 src/tests.c |   46 ++++++++++++++++++++++++++--------------------
 1 files changed, 26 insertions(+), 20 deletions(-)

diff --git a/src/tests.c b/src/tests.c
index 0b9b5db..4638645 100644
--- a/src/tests.c
+++ b/src/tests.c
@@ -195,26 +195,31 @@
 void test_nist()
 {
     srand(444444);
+    srand(888888);
     network net = parse_network_cfg("nist.cfg");
     data train = load_categorical_data_csv("mnist/mnist_train.csv", 0, 10);
     data test = load_categorical_data_csv("mnist/mnist_test.csv",0,10);
     normalize_data_rows(train);
     normalize_data_rows(test);
-    randomize_data(train);
+    //randomize_data(train);
     int count = 0;
     double lr = .0005;
-    while(++count <= 1){
-        double acc = train_network_sgd(net, train, 10000, lr, .9, .001);
-        printf("Training Accuracy: %lf\n", acc);
-        lr /= 2; 
+    double momentum = .9;
+    double decay = 0.01;
+    while(++count <= 1000){
+        double acc = train_network_sgd(net, train, 1000, lr, momentum, decay);
+        printf("Training Accuracy: %lf, Params: %f %f %f\n", acc, lr, momentum, decay);
+        visualize_network(net);
+        cvWaitKey(100);
+        //lr /= 2; 
+        if(count%5 == 0 && 0){
+            double train_acc = network_accuracy(net, train);
+            fprintf(stderr, "\nTRAIN: %f\n", train_acc);
+            double test_acc = network_accuracy(net, test);
+            fprintf(stderr, "TEST: %f\n\n", test_acc);
+            printf("%d, %f, %f\n", count, train_acc, test_acc);
+        }
     }
-    double train_acc = network_accuracy(net, train);
-    fprintf(stderr, "\nTRAIN: %f\n", train_acc);
-    double test_acc = network_accuracy(net, test);
-    fprintf(stderr, "TEST: %f\n\n", test_acc);
-    printf("%d, %f, %f\n", count, train_acc, test_acc);
-    //end = clock();
-    //printf("Neural Net Learning: %lf seconds\n", (double)(end-start)/CLOCKS_PER_SEC);
 }
 
 void test_ensemble()
@@ -223,24 +228,25 @@
     srand(888888);
     data d = load_categorical_data_csv("mnist/mnist_train.csv", 0, 10);
     normalize_data_rows(d);
-    randomize_data(d);
     data test = load_categorical_data_csv("mnist/mnist_test.csv", 0,10);
     normalize_data_rows(test);
     data train = d;
     /*
-    data *split = split_data(d, 1, 10);
-    data train = split[0];
-    data test = split[1];
-    */
+       data *split = split_data(d, 1, 10);
+       data train = split[0];
+       data test = split[1];
+     */
     matrix prediction = make_matrix(test.y.rows, test.y.cols);
     int n = 30;
     for(i = 0; i < n; ++i){
         int count = 0;
         double lr = .0005;
+        double momentum = .9;
+        double decay = .01;
         network net = parse_network_cfg("nist.cfg");
-        while(++count <= 5){
-            double acc = train_network_sgd(net, train, train.X.rows, lr, .9, .001);
-            printf("Training Accuracy: %lf\n", acc);
+        while(++count <= 15){
+            double acc = train_network_sgd(net, train, train.X.rows, lr, momentum, decay);
+            printf("Training Accuracy: %lf Learning Rate: %f Momentum: %f Decay: %f\n", acc, lr, momentum, decay );
             lr /= 2; 
         }
         matrix partial = network_predict_data(net, test);

--
Gitblit v1.10.0