From 153705226d8ca746478b69eeac9bc854766daa11 Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Tue, 27 Jan 2015 21:31:06 +0000
Subject: [PATCH] Bias updates bug fix

---
 src/cnn.c |   33 ++++++++++++++++++---------------
 1 files changed, 18 insertions(+), 15 deletions(-)

diff --git a/src/cnn.c b/src/cnn.c
index c3b7b2c..4f575dc 100644
--- a/src/cnn.c
+++ b/src/cnn.c
@@ -212,7 +212,8 @@
     //network net = parse_network_cfg("/home/pjreddie/imagenet_backup/alexnet_1270.cfg");
     srand(time(0));
     network net = parse_network_cfg(cfgfile);
-    set_learning_network(&net, net.learning_rate, net.momentum, net.decay);
+    //test_learn_bias(*(convolutional_layer *)net.layers[1]);
+    //set_learning_network(&net, net.learning_rate, 0, net.decay);
     printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
     int imgs = 3072;
     int i = net.seen/imgs;
@@ -383,25 +384,26 @@
     cvWaitKey(0);
 }
 
-void test_cifar10()
+void test_cifar10(char *cfgfile)
 {
-    network net = parse_network_cfg("cfg/cifar10_part5.cfg");
+    network net = parse_network_cfg(cfgfile);
     data test = load_cifar10_data("data/cifar10/test_batch.bin");
     clock_t start = clock(), end;
-    float test_acc = network_accuracy(net, test);
+    float test_acc = network_accuracy_multi(net, test, 10);
     end = clock();
-    printf("%f in %f Sec\n", test_acc, (float)(end-start)/CLOCKS_PER_SEC);
-    visualize_network(net);
-    cvWaitKey(0);
+    printf("%f in %f Sec\n", test_acc, sec(end-start));
+    //visualize_network(net);
+    //cvWaitKey(0);
 }
 
-void train_cifar10()
+void train_cifar10(char *cfgfile)
 {
     srand(555555);
-    network net = parse_network_cfg("cfg/cifar10.cfg");
+    srand(time(0));
+    network net = parse_network_cfg(cfgfile);
     data test = load_cifar10_data("data/cifar10/test_batch.bin");
     int count = 0;
-    int iters = 10000/net.batch;
+    int iters = 50000/net.batch;
     data train = load_all_cifar10();
     while(++count <= 10000){
         clock_t time = clock();
@@ -410,9 +412,9 @@
         if(count%10 == 0){
             float test_acc = network_accuracy(net, test);
             printf("%d: Loss: %f, Test Acc: %f, Time: %lf seconds\n", count, loss, test_acc,sec(clock()-time));
-            //char buff[256];
-            //sprintf(buff, "unikitty/cifar10_%d.cfg", count);
-            //save_network(net, buff);
+            char buff[256];
+            sprintf(buff, "/home/pjreddie/imagenet_backup/cifar10_%d.cfg", count);
+            save_network(net, buff);
         }else{
             printf("%d: Loss: %f, Time: %lf seconds\n", count, loss, sec(clock()-time));
         }
@@ -709,8 +711,7 @@
     }
 #endif
 
-    if(0==strcmp(argv[1], "cifar")) train_cifar10();
-    else if(0==strcmp(argv[1], "test_correct")) test_correct_alexnet();
+    if(0==strcmp(argv[1], "test_correct")) test_correct_alexnet();
     else if(0==strcmp(argv[1], "test_correct_nist")) test_correct_nist();
     else if(0==strcmp(argv[1], "test")) test_imagenet();
     //else if(0==strcmp(argv[1], "server")) run_server();
@@ -724,7 +725,9 @@
         return 0;
     }
     else if(0==strcmp(argv[1], "detection")) train_detection_net(argv[2]);
+    else if(0==strcmp(argv[1], "ctrain")) train_cifar10(argv[2]);
     else if(0==strcmp(argv[1], "nist")) train_nist(argv[2]);
+    else if(0==strcmp(argv[1], "ctest")) test_cifar10(argv[2]);
     else if(0==strcmp(argv[1], "train")) train_imagenet(argv[2]);
     //else if(0==strcmp(argv[1], "client")) train_imagenet_distributed(argv[2]);
     else if(0==strcmp(argv[1], "detect")) test_detection(argv[2]);

--
Gitblit v1.10.0