From e96a454ca11f140a7f7fb82daefe4cc9555a0f26 Mon Sep 17 00:00:00 2001
From: Alexey <AlexeyAB@users.noreply.github.com>
Date: Sun, 25 Feb 2018 22:13:55 +0000
Subject: [PATCH] Update Readme.md

---
 src/detector.c |   80 ++++++++++++++++++++++++++++++++++------
 1 files changed, 68 insertions(+), 12 deletions(-)

diff --git a/src/detector.c b/src/detector.c
index 3111c19..c851b38 100644
--- a/src/detector.c
+++ b/src/detector.c
@@ -499,9 +499,9 @@
 {
 	int j;
 	list *options = read_data_cfg(datacfg);
-	char *valid_images = option_find_str(options, "valid", "data/train.list");
+	char *valid_images = option_find_str(options, "valid", "data/train.txt");
+	char *difficult_valid_images = option_find_str(options, "difficult", NULL);
 	char *name_list = option_find_str(options, "names", "data/names.list");
-	//char *prefix = option_find_str(options, "results", "results");
 	char **names = get_labels(name_list);
 	char *mapf = option_find_str(options, "map", 0);
 	int *map = 0;
@@ -515,10 +515,16 @@
 	fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
 	srand(time(0));
 
-	char *base = "comp4_det_test_";
 	list *plist = get_paths(valid_images);
 	char **paths = (char **)list_to_array(plist);
 
+	char **paths_dif = NULL;
+	if (difficult_valid_images) {
+		list *plist_dif = get_paths(difficult_valid_images);
+		paths_dif = (char **)list_to_array(plist_dif);
+	}
+	
+
 	layer l = net.layers[net.n - 1];
 	int classes = l.classes;
 
@@ -546,9 +552,14 @@
 	args.h = net.h;
 	args.type = IMAGE_DATA;
 
+	const float thresh_calc_avg_iou = 0.24;
+	float avg_iou = 0;
+	int tp_for_thresh = 0;
+	int fp_for_thresh = 0;
+
 	box_prob *detections = calloc(1, sizeof(box_prob));
 	int detections_count = 0;
-	int unique_truth_index = 0;
+	int unique_truth_count = 0;
 
 	int *truth_classes_count = calloc(classes, sizeof(int));
 
@@ -574,7 +585,7 @@
 		}
 		for (t = 0; t < nthreads && i + t - nthreads < m; ++t) {
 			const int image_index = i + t - nthreads;
-			char *path = paths[i + t - nthreads];
+			char *path = paths[image_index];
 			char *id = basecfg(path);
 			float *X = val_resized[t].data;
 			network_predict(net, X);
@@ -594,6 +605,22 @@
 				truth_classes_count[truth[j].id]++;
 			}
 
+			// difficult
+			box_label *truth_dif = NULL;
+			int num_labels_dif = 0;
+			if (paths_dif)
+			{
+				char *path_dif = paths_dif[image_index];
+
+				char labelpath_dif[4096];
+				find_replace(path_dif, "images", "labels", labelpath_dif);
+				find_replace(labelpath_dif, "JPEGImages", "labels", labelpath_dif);
+				find_replace(labelpath_dif, ".jpg", ".txt", labelpath_dif);
+				find_replace(labelpath_dif, ".JPEG", ".txt", labelpath_dif);
+				find_replace(labelpath_dif, ".png", ".txt", labelpath_dif);				
+				truth_dif = read_boxes(labelpath_dif, &num_labels_dif);
+			}
+
 			for (i = 0; i < (l.w*l.h*l.n); ++i) {
 
 				int class_id;
@@ -606,6 +633,8 @@
 						detections[detections_count - 1].p = prob;
 						detections[detections_count - 1].image_index = image_index;
 						detections[detections_count - 1].class_id = class_id;
+						detections[detections_count - 1].truth_flag = 0;
+						detections[detections_count - 1].unique_truth_index = -1;
 
 						int truth_index = -1;
 						float max_iou = 0;
@@ -617,21 +646,43 @@
 							float current_iou = box_iou(boxes[i], t);
 							if (current_iou > iou_thresh && class_id == truth[j].id) {
 								if (current_iou > max_iou) {
-									current_iou = max_iou;
-									truth_index = unique_truth_index + j;
+									max_iou = current_iou;
+									truth_index = unique_truth_count + j;
 								}
 							}
 						}
+
 						// best IoU
 						if (truth_index > -1) {
 							detections[detections_count - 1].truth_flag = 1;
 							detections[detections_count - 1].unique_truth_index = truth_index;
 						}
+						else {
+							// if object is difficult then remove detection
+							for (j = 0; j < num_labels_dif; ++j) {
+								box t = { truth_dif[j].x, truth_dif[j].y, truth_dif[j].w, truth_dif[j].h };
+								float current_iou = box_iou(boxes[i], t);
+								if (current_iou > iou_thresh && class_id == truth_dif[j].id) {
+									--detections_count;
+									break;
+								}
+							}
+						}
+
+						// calc avg IoU, true-positives, false-positives for required Threshold
+						if (prob > thresh_calc_avg_iou) {
+							if (truth_index > -1) {
+								avg_iou += max_iou;
+								++tp_for_thresh;
+							}
+							else
+								fp_for_thresh++;
+						}
 					}
 				}
 			}
 			
-			unique_truth_index += num_labels;
+			unique_truth_count += num_labels;
 
 			free(id);
 			free_image(val[t]);
@@ -639,6 +690,8 @@
 		}
 	}
 
+	avg_iou = avg_iou / (tp_for_thresh + fp_for_thresh);
+
 	
 	// SORT(detections)
 	qsort(detections, detections_count, sizeof(box_prob), detections_comparator);
@@ -654,10 +707,10 @@
 	for (i = 0; i < classes; ++i) {
 		pr[i] = calloc(detections_count, sizeof(pr_t));
 	}
-	printf("detections_count = %d, unique_truth_index = %d  \n", detections_count, unique_truth_index);
+	printf("detections_count = %d, unique_truth_count = %d  \n", detections_count, unique_truth_count);
 
 
-	int *truth_flags = calloc(unique_truth_index, sizeof(int));
+	int *truth_flags = calloc(unique_truth_count, sizeof(int));
 
 	int rank;
 	for (rank = 0; rank < detections_count; ++rank) {
@@ -685,7 +738,6 @@
 			pr[d.class_id][rank].fp++;	// false-positive
 		}
 
-
 		for (i = 0; i < classes; ++i) 
 		{
 			const int tp = pr[i][rank].tp;
@@ -729,6 +781,9 @@
 		mean_average_precision += avg_precision;
 	}
 	
+	printf(" for thresh = %0.2f, TP = %d, FP = %d, FN = %d, average IoU = %2.2f %% \n", 
+		thresh_calc_avg_iou, tp_for_thresh, fp_for_thresh, unique_truth_count - tp_for_thresh, avg_iou * 100);
+
 	mean_average_precision = mean_average_precision / classes;
 	printf("\n mean average precision (mAP) = %f, or %2.2f %% \n", mean_average_precision, mean_average_precision*100);
 
@@ -804,6 +859,7 @@
 
 void run_detector(int argc, char **argv)
 {
+	int http_stream_port = find_int_arg(argc, argv, "-http_port", -1);
 	char *out_filename = find_char_arg(argc, argv, "-out_filename", 0);
     char *prefix = find_char_arg(argc, argv, "-prefix", 0);
     float thresh = find_float_arg(argc, argv, "-thresh", .24);
@@ -856,6 +912,6 @@
         char **names = get_labels(name_list);
 		if(filename)
 			if (filename[strlen(filename) - 1] == 0x0d) filename[strlen(filename) - 1] = 0;
-        demo(cfg, weights, thresh, cam_index, filename, names, classes, frame_skip, prefix, out_filename);
+        demo(cfg, weights, thresh, cam_index, filename, names, classes, frame_skip, prefix, out_filename, http_stream_port);
     }
 }

--
Gitblit v1.10.0