From 252e3b1916cfaca0783c9e90efaa55eb07b1a8cd Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Sun, 06 Nov 2016 00:27:31 +0000
Subject: [PATCH] :charizard: :charizard: :charizard:
---
src/detector.c | 65 ++++++++++++++++++--------------
1 files changed, 37 insertions(+), 28 deletions(-)
diff --git a/src/detector.c b/src/detector.c
index 1f48c61..e020be5 100644
--- a/src/detector.c
+++ b/src/detector.c
@@ -5,17 +5,18 @@
#include "parser.h"
#include "box.h"
#include "demo.h"
+#include "option_list.h"
#ifdef OPENCV
#include "opencv2/highgui/highgui_c.h"
#endif
-static char *voc_names[] = {"aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"};
-
-void train_detector(char *cfgfile, char *weightfile)
+void train_detector(char *datacfg, char *cfgfile, char *weightfile, int clear)
{
- char *train_images = "/data/voc/train.txt";
- char *backup_directory = "/home/pjreddie/backup/";
+ list *options = read_data_cfg(datacfg);
+ char *train_images = option_find_str(options, "train", "data/train.list");
+ char *backup_directory = option_find_str(options, "backup", "/backup/");
+
srand(time(0));
char *base = basecfg(cfgfile);
printf("%s\n", base);
@@ -24,6 +25,7 @@
if(weightfile){
load_weights(&net, weightfile);
}
+ if(clear) *net.seen = 0;
printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
int imgs = net.batch*net.subdivisions;
int i = *net.seen/imgs;
@@ -124,8 +126,13 @@
}
}
-void validate_detector(char *cfgfile, char *weightfile)
+void validate_detector(char *datacfg, char *cfgfile, char *weightfile)
{
+ list *options = read_data_cfg(datacfg);
+ char *valid_images = option_find_str(options, "valid", "data/train.list");
+ char *name_list = option_find_str(options, "names", "data/names.list");
+ char **names = get_labels(name_list);
+
network net = parse_network_cfg(cfgfile);
if(weightfile){
load_weights(&net, weightfile);
@@ -135,9 +142,7 @@
srand(time(0));
char *base = "results/comp4_det_test_";
- //list *plist = get_paths("data/voc.2007.test");
- list *plist = get_paths("/home/pjreddie/data/voc/2007_test.txt");
- //list *plist = get_paths("data/voc.2012.test");
+ list *plist = get_paths(valid_images);
char **paths = (char **)list_to_array(plist);
layer l = net.layers[net.n-1];
@@ -147,7 +152,7 @@
FILE **fps = calloc(classes, sizeof(FILE *));
for(j = 0; j < classes; ++j){
char buff[1024];
- snprintf(buff, 1024, "%s%s.txt", base, voc_names[j]);
+ snprintf(buff, 1024, "%s%s.txt", base, names[j]);
fps[j] = fopen(buff, "w");
}
box *boxes = calloc(l.w*l.h*l.n, sizeof(box));
@@ -224,7 +229,6 @@
fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
srand(time(0));
- char *base = "results/comp4_det_test_";
list *plist = get_paths("data/voc.2007.test");
char **paths = (char **)list_to_array(plist);
@@ -232,12 +236,6 @@
int classes = l.classes;
int j, k;
- FILE **fps = calloc(classes, sizeof(FILE *));
- for(j = 0; j < classes; ++j){
- char buff[1024];
- snprintf(buff, 1024, "%s%s.txt", base, voc_names[j]);
- fps[j] = fopen(buff, "w");
- }
box *boxes = calloc(l.w*l.h*l.n, sizeof(box));
float **probs = calloc(l.w*l.h*l.n, sizeof(float *));
for(j = 0; j < l.w*l.h*l.n; ++j) probs[j] = calloc(classes, sizeof(float *));
@@ -299,9 +297,13 @@
}
}
-void test_detector(char *cfgfile, char *weightfile, char *filename, float thresh)
+void test_detector(char *datacfg, char *cfgfile, char *weightfile, char *filename, float thresh)
{
- image *alphabet = load_alphabet();
+ list *options = read_data_cfg(datacfg);
+ char *name_list = option_find_str(options, "names", "data/names.list");
+ char **names = get_labels(name_list);
+
+ image **alphabet = load_alphabet();
network net = parse_network_cfg(cfgfile);
if(weightfile){
load_weights(&net, weightfile);
@@ -335,8 +337,7 @@
printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time));
get_region_boxes(l, 1, 1, thresh, probs, boxes, 0);
if (nms) do_nms_sort(boxes, probs, l.w*l.h*l.n, l.classes, nms);
- //draw_detections(im, l.w*l.h*l.n, thresh, boxes, probs, voc_names, voc_labels, 20);
- draw_detections(im, l.w*l.h*l.n, thresh, boxes, probs, voc_names, alphabet, 20);
+ draw_detections(im, l.w*l.h*l.n, thresh, boxes, probs, names, alphabet, l.classes);
save_image(im, "predictions");
show_image(im, "predictions");
@@ -360,13 +361,21 @@
fprintf(stderr, "usage: %s %s [train/test/valid] [cfg] [weights (optional)]\n", argv[0], argv[1]);
return;
}
+ int clear = find_arg(argc, argv, "-clear");
- char *cfg = argv[3];
- char *weights = (argc > 4) ? argv[4] : 0;
- char *filename = (argc > 5) ? argv[5]: 0;
- if(0==strcmp(argv[2], "test")) test_detector(cfg, weights, filename, thresh);
- else if(0==strcmp(argv[2], "train")) train_detector(cfg, weights);
- else if(0==strcmp(argv[2], "valid")) validate_detector(cfg, weights);
+ char *datacfg = argv[3];
+ char *cfg = argv[4];
+ char *weights = (argc > 5) ? argv[5] : 0;
+ char *filename = (argc > 6) ? argv[6]: 0;
+ if(0==strcmp(argv[2], "test")) test_detector(datacfg, cfg, weights, filename, thresh);
+ else if(0==strcmp(argv[2], "train")) train_detector(datacfg, cfg, weights, clear);
+ else if(0==strcmp(argv[2], "valid")) validate_detector(datacfg, cfg, weights);
else if(0==strcmp(argv[2], "recall")) validate_detector_recall(cfg, weights);
- else if(0==strcmp(argv[2], "demo")) demo(cfg, weights, thresh, cam_index, filename, voc_names, 20, frame_skip, prefix);
+ else if(0==strcmp(argv[2], "demo")) {
+ list *options = read_data_cfg(datacfg);
+ int classes = option_find_int(options, "classes", 20);
+ char *name_list = option_find_str(options, "names", "data/names.list");
+ char **names = get_labels(name_list);
+ demo(cfg, weights, thresh, cam_index, filename, names, classes, frame_skip, prefix);
+ }
}
--
Gitblit v1.10.0