From 11c72b1132feca7c1252ea01d02da4cb497e723f Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Thu, 11 Jun 2015 22:38:58 +0000
Subject: [PATCH] testing on one image

---
 src/imagenet.c |   32 +++++++++++++++++++-------------
 1 files changed, 19 insertions(+), 13 deletions(-)

diff --git a/src/imagenet.c b/src/imagenet.c
index 3f88b36..a564f33 100644
--- a/src/imagenet.c
+++ b/src/imagenet.c
@@ -14,6 +14,7 @@
         load_weights(&net, weightfile);
     }
     printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
+    //net.seen=0;
     int imgs = 1024;
     int i = net.seen/imgs;
     char **labels = get_labels("/home/pjreddie/data/imagenet/cls.labels.list");
@@ -31,7 +32,7 @@
         pthread_join(load_thread, 0);
         train = buffer;
 
-/*
+        /*
         image im = float_to_image(256, 256, 3, train.X.vals[114]);
         show_image(im, "training");
         cvWaitKey(0);
@@ -46,6 +47,7 @@
         avg_loss = avg_loss*.9 + loss*.1;
         printf("%d: %f, %f avg, %lf seconds, %d images\n", i, loss, avg_loss, sec(clock()-time), net.seen);
         free_data(train);
+        if((i % 15000) == 0) net.learning_rate *= .1;
         //if(i%100 == 0 && net.learning_rate > .00001) net.learning_rate *= .97;
         if(i%1000==0){
             char buff[256];
@@ -99,37 +101,40 @@
     }
 }
 
-void test_imagenet(char *cfgfile, char *weightfile)
+void test_imagenet(char *cfgfile, char *weightfile, char *filename)
 {
     network net = parse_network_cfg(cfgfile);
     if(weightfile){
         load_weights(&net, weightfile);
     }
     set_batch_network(&net, 1);
-    //imgs=1;
     srand(2222222);
     int i = 0;
-    char **names = get_labels("cfg/shortnames.txt");
+    char **names = get_labels("data/shortnames.txt");
     clock_t time;
-    char filename[256];
+    char input[256];
     int indexes[10];
     while(1){
-        fgets(filename, 256, stdin);
-        strtok(filename, "\n");
-        image im = load_image_color(filename, 256, 256);
-        scale_image(im, 2.);
-        translate_image(im, -1.);
-        printf("%d %d %d\n", im.h, im.w, im.c);
+        if(filename){
+            strncpy(input, filename, 256);
+        }else{
+            printf("Enter Image Path: ");
+            fflush(stdout);
+            fgets(input, 256, stdin);
+            strtok(input, "\n");
+        }
+        image im = load_image_color(input, 256, 256);
         float *X = im.data;
         time=clock();
         float *predictions = network_predict(net, X);
         top_predictions(net, 10, indexes);
-        printf("%s: Predicted in %f seconds.\n", filename, sec(clock()-time));
+        printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time));
         for(i = 0; i < 10; ++i){
             int index = indexes[i];
             printf("%s: %f\n", names[index], predictions[index]);
         }
         free_image(im);
+        if (filename) break;
     }
 }
 
@@ -142,7 +147,8 @@
 
     char *cfg = argv[3];
     char *weights = (argc > 4) ? argv[4] : 0;
-    if(0==strcmp(argv[2], "test")) test_imagenet(cfg, weights);
+    char *filename = (argc > 5) ? argv[5]: 0;
+    if(0==strcmp(argv[2], "test")) test_imagenet(cfg, weights, filename);
     else if(0==strcmp(argv[2], "train")) train_imagenet(cfg, weights);
     else if(0==strcmp(argv[2], "valid")) validate_imagenet(cfg, weights);
 }

--
Gitblit v1.10.0