From c7a700dc2249e8bd3a2c9120dfd09240e413c8bd Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Sat, 05 Nov 2016 21:09:21 +0000
Subject: [PATCH] new font strategy

---
 src/detector.c |   65 ++++++++++++++++++--------------
 1 files changed, 37 insertions(+), 28 deletions(-)

diff --git a/src/detector.c b/src/detector.c
index 1f48c61..e020be5 100644
--- a/src/detector.c
+++ b/src/detector.c
@@ -5,17 +5,18 @@
 #include "parser.h"
 #include "box.h"
 #include "demo.h"
+#include "option_list.h"
 
 #ifdef OPENCV
 #include "opencv2/highgui/highgui_c.h"
 #endif
 
-static char *voc_names[] = {"aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"};
-
-void train_detector(char *cfgfile, char *weightfile)
+void train_detector(char *datacfg, char *cfgfile, char *weightfile, int clear)
 {
-    char *train_images = "/data/voc/train.txt";
-    char *backup_directory = "/home/pjreddie/backup/";
+    list *options = read_data_cfg(datacfg);
+    char *train_images = option_find_str(options, "train", "data/train.list");
+    char *backup_directory = option_find_str(options, "backup", "/backup/");
+
     srand(time(0));
     char *base = basecfg(cfgfile);
     printf("%s\n", base);
@@ -24,6 +25,7 @@
     if(weightfile){
         load_weights(&net, weightfile);
     }
+    if(clear) *net.seen = 0;
     printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
     int imgs = net.batch*net.subdivisions;
     int i = *net.seen/imgs;
@@ -124,8 +126,13 @@
     }
 }
 
-void validate_detector(char *cfgfile, char *weightfile)
+void validate_detector(char *datacfg, char *cfgfile, char *weightfile)
 {
+    list *options = read_data_cfg(datacfg);
+    char *valid_images = option_find_str(options, "valid", "data/train.list");
+    char *name_list = option_find_str(options, "names", "data/names.list");
+    char **names = get_labels(name_list);
+
     network net = parse_network_cfg(cfgfile);
     if(weightfile){
         load_weights(&net, weightfile);
@@ -135,9 +142,7 @@
     srand(time(0));
 
     char *base = "results/comp4_det_test_";
-    //list *plist = get_paths("data/voc.2007.test");
-    list *plist = get_paths("/home/pjreddie/data/voc/2007_test.txt");
-    //list *plist = get_paths("data/voc.2012.test");
+    list *plist = get_paths(valid_images);
     char **paths = (char **)list_to_array(plist);
 
     layer l = net.layers[net.n-1];
@@ -147,7 +152,7 @@
     FILE **fps = calloc(classes, sizeof(FILE *));
     for(j = 0; j < classes; ++j){
         char buff[1024];
-        snprintf(buff, 1024, "%s%s.txt", base, voc_names[j]);
+        snprintf(buff, 1024, "%s%s.txt", base, names[j]);
         fps[j] = fopen(buff, "w");
     }
     box *boxes = calloc(l.w*l.h*l.n, sizeof(box));
@@ -224,7 +229,6 @@
     fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
     srand(time(0));
 
-    char *base = "results/comp4_det_test_";
     list *plist = get_paths("data/voc.2007.test");
     char **paths = (char **)list_to_array(plist);
 
@@ -232,12 +236,6 @@
     int classes = l.classes;
 
     int j, k;
-    FILE **fps = calloc(classes, sizeof(FILE *));
-    for(j = 0; j < classes; ++j){
-        char buff[1024];
-        snprintf(buff, 1024, "%s%s.txt", base, voc_names[j]);
-        fps[j] = fopen(buff, "w");
-    }
     box *boxes = calloc(l.w*l.h*l.n, sizeof(box));
     float **probs = calloc(l.w*l.h*l.n, sizeof(float *));
     for(j = 0; j < l.w*l.h*l.n; ++j) probs[j] = calloc(classes, sizeof(float *));
@@ -299,9 +297,13 @@
     }
 }
 
-void test_detector(char *cfgfile, char *weightfile, char *filename, float thresh)
+void test_detector(char *datacfg, char *cfgfile, char *weightfile, char *filename, float thresh)
 {
-    image *alphabet = load_alphabet();
+    list *options = read_data_cfg(datacfg);
+        char *name_list = option_find_str(options, "names", "data/names.list");
+        char **names = get_labels(name_list);
+
+    image **alphabet = load_alphabet();
     network net = parse_network_cfg(cfgfile);
     if(weightfile){
         load_weights(&net, weightfile);
@@ -335,8 +337,7 @@
         printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time));
         get_region_boxes(l, 1, 1, thresh, probs, boxes, 0);
         if (nms) do_nms_sort(boxes, probs, l.w*l.h*l.n, l.classes, nms);
-        //draw_detections(im, l.w*l.h*l.n, thresh, boxes, probs, voc_names, voc_labels, 20);
-        draw_detections(im, l.w*l.h*l.n, thresh, boxes, probs, voc_names, alphabet, 20);
+        draw_detections(im, l.w*l.h*l.n, thresh, boxes, probs, names, alphabet, l.classes);
         save_image(im, "predictions");
         show_image(im, "predictions");
 
@@ -360,13 +361,21 @@
         fprintf(stderr, "usage: %s %s [train/test/valid] [cfg] [weights (optional)]\n", argv[0], argv[1]);
         return;
     }
+    int clear = find_arg(argc, argv, "-clear");
 
-    char *cfg = argv[3];
-    char *weights = (argc > 4) ? argv[4] : 0;
-    char *filename = (argc > 5) ? argv[5]: 0;
-    if(0==strcmp(argv[2], "test")) test_detector(cfg, weights, filename, thresh);
-    else if(0==strcmp(argv[2], "train")) train_detector(cfg, weights);
-    else if(0==strcmp(argv[2], "valid")) validate_detector(cfg, weights);
+    char *datacfg = argv[3];
+    char *cfg = argv[4];
+    char *weights = (argc > 5) ? argv[5] : 0;
+    char *filename = (argc > 6) ? argv[6]: 0;
+    if(0==strcmp(argv[2], "test")) test_detector(datacfg, cfg, weights, filename, thresh);
+    else if(0==strcmp(argv[2], "train")) train_detector(datacfg, cfg, weights, clear);
+    else if(0==strcmp(argv[2], "valid")) validate_detector(datacfg, cfg, weights);
     else if(0==strcmp(argv[2], "recall")) validate_detector_recall(cfg, weights);
-    else if(0==strcmp(argv[2], "demo")) demo(cfg, weights, thresh, cam_index, filename, voc_names, 20, frame_skip, prefix);
+    else if(0==strcmp(argv[2], "demo")) {
+        list *options = read_data_cfg(datacfg);
+        int classes = option_find_int(options, "classes", 20);
+        char *name_list = option_find_str(options, "names", "data/names.list");
+        char **names = get_labels(name_list);
+        demo(cfg, weights, thresh, cam_index, filename, names, classes, frame_skip, prefix);
+    }
 }

--
Gitblit v1.10.0