From 0948df52b850b908e7a74cb589d19fa29eb30368 Mon Sep 17 00:00:00 2001
From: Alexey <AlexeyAB@users.noreply.github.com>
Date: Tue, 08 May 2018 14:27:45 +0000
Subject: [PATCH] Merge pull request #741 from IlyaOvodov/Fix_detector_output
---
src/detector.c | 62 +++++++++++++++++++++++++------
1 files changed, 50 insertions(+), 12 deletions(-)
diff --git a/src/detector.c b/src/detector.c
index 407c8be..71ede10 100644
--- a/src/detector.c
+++ b/src/detector.c
@@ -1,3 +1,8 @@
+#ifdef _DEBUG
+#include <stdlib.h>
+#include <crtdbg.h>
+#endif
+
#include "network.h"
#include "region_layer.h"
#include "cost_layer.h"
@@ -86,6 +91,7 @@
args.n = imgs;
args.m = plist->size;
args.classes = classes;
+ args.flip = net.flip;
args.jitter = jitter;
args.num_boxes = l.max_boxes;
args.small_object = net.small_object;
@@ -644,6 +650,8 @@
truth_dif = read_boxes(labelpath_dif, &num_labels_dif);
}
+ const int checkpoint_detections_count = detections_count;
+
for (i = 0; i < nboxes; ++i) {
int class_id;
@@ -694,7 +702,13 @@
// calc avg IoU, true-positives, false-positives for required Threshold
if (prob > thresh_calc_avg_iou) {
- if (truth_index > -1) {
+ int z, found = 0;
+ for (z = checkpoint_detections_count; z < detections_count-1; ++z)
+ if (detections[z].unique_truth_index == truth_index) {
+ found = 1; break;
+ }
+
+ if(truth_index > -1 && found == 0) {
avg_iou += max_iou;
++tp_for_thresh;
}
@@ -714,7 +728,8 @@
}
}
- avg_iou = avg_iou / (tp_for_thresh + fp_for_thresh);
+ if((tp_for_thresh + fp_for_thresh) > 0)
+ avg_iou = avg_iou / (tp_for_thresh + fp_for_thresh);
// SORT(detections)
@@ -1026,7 +1041,8 @@
}
#endif // OPENCV
-void test_detector(char *datacfg, char *cfgfile, char *weightfile, char *filename, float thresh, float hier_thresh, int dont_show)
+void test_detector(char *datacfg, char *cfgfile, char *weightfile, char *filename, float thresh,
+ float hier_thresh, int dont_show, int ext_output)
{
list *options = read_data_cfg(datacfg);
char *name_list = option_find_str(options, "names", "data/names.list");
@@ -1048,7 +1064,8 @@
while(1){
if(filename){
strncpy(input, filename, 256);
- if (input[strlen(input) - 1] == 0x0d) input[strlen(input) - 1] = 0;
+ if(strlen(input) > 0)
+ if (input[strlen(input) - 1] == 0x0d) input[strlen(input) - 1] = 0;
} else {
printf("Enter Image Path: ");
fflush(stdout);
@@ -1058,8 +1075,8 @@
}
image im = load_image_color(input,0,0);
int letterbox = 0;
- image sized = resize_image(im, net.w, net.h);
- //image sized = letterbox_image(im, net.w, net.h); letterbox = 1;
+ //image sized = resize_image(im, net.w, net.h);
+ image sized = letterbox_image(im, net.w, net.h); letterbox = 1;
layer l = net.layers[net.n-1];
//box *boxes = calloc(l.w*l.h*l.n, sizeof(box));
@@ -1068,8 +1085,8 @@
float *X = sized.data;
time= what_time_is_it_now();
- //network_predict(net, X);
- network_predict_image(&net, im); letterbox = 1;
+ network_predict(net, X);
+ //network_predict_image(&net, im); letterbox = 1;
printf("%s: Predicted in %f seconds.\n", input, (what_time_is_it_now()-time));
//get_region_boxes(l, 1, 1, thresh, probs, boxes, 0, 0);
// if (nms) do_nms_sort_v2(boxes, probs, l.w*l.h*l.n, l.classes, nms);
@@ -1077,7 +1094,7 @@
int nboxes = 0;
detection *dets = get_network_boxes(&net, im.w, im.h, thresh, hier_thresh, 0, 1, &nboxes, letterbox);
if (nms) do_nms_sort(dets, nboxes, l.classes, nms);
- draw_detections_v3(im, dets, nboxes, thresh, names, alphabet, l.classes);
+ draw_detections_v3(im, dets, nboxes, thresh, names, alphabet, l.classes, ext_output);
free_detections(dets, nboxes);
save_image(im, "predictions");
if (!dont_show) {
@@ -1096,6 +1113,22 @@
#endif
if (filename) break;
}
+
+ // free memory
+ free_ptrs(names, net.layers[net.n - 1].classes);
+ free_list(options);
+
+ int i;
+ const int nsize = 8;
+ for (j = 0; j < nsize; ++j) {
+ for (i = 32; i < 127; ++i) {
+ free_image(alphabet[j][i]);
+ }
+ free(alphabet[j]);
+ }
+ free(alphabet);
+
+ free_network(net);
}
void run_detector(int argc, char **argv)
@@ -1113,6 +1146,9 @@
int num_of_clusters = find_int_arg(argc, argv, "-num_of_clusters", 5);
int width = find_int_arg(argc, argv, "-width", -1);
int height = find_int_arg(argc, argv, "-height", -1);
+ // extended output in test mode (output of rect bound coords)
+ // and for recall mode (extended output table-like format with results for best_class fit)
+ int ext_output = find_arg(argc, argv, "-ext_output");
if(argc < 4){
fprintf(stderr, "usage: %s %s [train/test/valid] [cfg] [weights (optional)]\n", argv[0], argv[1]);
return;
@@ -1146,9 +1182,10 @@
char *cfg = argv[4];
char *weights = (argc > 5) ? argv[5] : 0;
if(weights)
- if (weights[strlen(weights) - 1] == 0x0d) weights[strlen(weights) - 1] = 0;
+ if(strlen(weights) > 0)
+ if (weights[strlen(weights) - 1] == 0x0d) weights[strlen(weights) - 1] = 0;
char *filename = (argc > 6) ? argv[6]: 0;
- if(0==strcmp(argv[2], "test")) test_detector(datacfg, cfg, weights, filename, thresh, hier_thresh, dont_show);
+ if(0==strcmp(argv[2], "test")) test_detector(datacfg, cfg, weights, filename, thresh, hier_thresh, dont_show, ext_output);
else if(0==strcmp(argv[2], "train")) train_detector(datacfg, cfg, weights, gpus, ngpus, clear, dont_show);
else if(0==strcmp(argv[2], "valid")) validate_detector(datacfg, cfg, weights, outfile);
else if(0==strcmp(argv[2], "recall")) validate_detector_recall(datacfg, cfg, weights);
@@ -1160,7 +1197,8 @@
char *name_list = option_find_str(options, "names", "data/names.list");
char **names = get_labels(name_list);
if(filename)
- if (filename[strlen(filename) - 1] == 0x0d) filename[strlen(filename) - 1] = 0;
+ if(strlen(filename) > 0)
+ if (filename[strlen(filename) - 1] == 0x0d) filename[strlen(filename) - 1] = 0;
demo(cfg, weights, thresh, hier_thresh, cam_index, filename, names, classes, frame_skip, prefix, out_filename,
http_stream_port, dont_show);
}
--
Gitblit v1.10.0