From 65bff2683bdffe7ec82eacd8144c70c09d19c88d Mon Sep 17 00:00:00 2001
From: AlexeyAB <alexeyab84@gmail.com>
Date: Fri, 16 Feb 2018 20:55:37 +0000
Subject: [PATCH] Take 'difficult' objects into account when calculating mAP for PascalVOC

---
 src/detector.c |   46 ++++++++++++++++++++++++++++++++++++++++------
 1 files changed, 40 insertions(+), 6 deletions(-)

diff --git a/src/detector.c b/src/detector.c
index 3111c19..ce259fd 100644
--- a/src/detector.c
+++ b/src/detector.c
@@ -499,9 +499,9 @@
 {
 	int j;
 	list *options = read_data_cfg(datacfg);
-	char *valid_images = option_find_str(options, "valid", "data/train.list");
+	char *valid_images = option_find_str(options, "valid", "data/train.txt");
+	char *difficult_valid_images = option_find_str(options, "difficult", NULL);
 	char *name_list = option_find_str(options, "names", "data/names.list");
-	//char *prefix = option_find_str(options, "results", "results");
 	char **names = get_labels(name_list);
 	char *mapf = option_find_str(options, "map", 0);
 	int *map = 0;
@@ -515,10 +515,16 @@
 	fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
 	srand(time(0));
 
-	char *base = "comp4_det_test_";
 	list *plist = get_paths(valid_images);
 	char **paths = (char **)list_to_array(plist);
 
+	char **paths_dif = NULL;
+	if (difficult_valid_images) {
+		list *plist_dif = get_paths(difficult_valid_images);
+		paths_dif = (char **)list_to_array(plist_dif);
+	}
+	
+
 	layer l = net.layers[net.n - 1];
 	int classes = l.classes;
 
@@ -574,7 +580,7 @@
 		}
 		for (t = 0; t < nthreads && i + t - nthreads < m; ++t) {
 			const int image_index = i + t - nthreads;
-			char *path = paths[i + t - nthreads];
+			char *path = paths[image_index];
 			char *id = basecfg(path);
 			float *X = val_resized[t].data;
 			network_predict(net, X);
@@ -594,6 +600,22 @@
 				truth_classes_count[truth[j].id]++;
 			}
 
+			// load the "difficult" ground-truth boxes for this image (only when a "difficult" list is configured)
+			box_label *truth_dif = NULL;
+			int num_labels_dif = 0;
+			if (paths_dif)
+			{
+				char *path_dif = paths_dif[image_index];
+
+				char labelpath_dif[4096];
+				find_replace(path_dif, "images", "labels", labelpath_dif);
+				find_replace(labelpath_dif, "JPEGImages", "labels", labelpath_dif);
+				find_replace(labelpath_dif, ".jpg", ".txt", labelpath_dif);
+				find_replace(labelpath_dif, ".JPEG", ".txt", labelpath_dif);
+				find_replace(labelpath_dif, ".png", ".txt", labelpath_dif);				
+				truth_dif = read_boxes(labelpath_dif, &num_labels_dif);
+			}
+
 			for (i = 0; i < (l.w*l.h*l.n); ++i) {
 
 				int class_id;
@@ -606,6 +628,8 @@
 						detections[detections_count - 1].p = prob;
 						detections[detections_count - 1].image_index = image_index;
 						detections[detections_count - 1].class_id = class_id;
+						detections[detections_count - 1].truth_flag = 0;
+						detections[detections_count - 1].unique_truth_index = -1;
 
 						int truth_index = -1;
 						float max_iou = 0;
@@ -617,16 +641,27 @@
 							float current_iou = box_iou(boxes[i], t);
 							if (current_iou > iou_thresh && class_id == truth[j].id) {
 								if (current_iou > max_iou) {
-									current_iou = max_iou;
+									max_iou = current_iou;
 									truth_index = unique_truth_index + j;
 								}
 							}
 						}
+
 						// best IoU
 						if (truth_index > -1) {
 							detections[detections_count - 1].truth_flag = 1;
 							detections[detections_count - 1].unique_truth_index = truth_index;
 						}
+						else {
+							// no regular truth box matched: if the detection overlaps a "difficult" box of the same class (IoU > iou_thresh), drop the detection
+							for (j = 0; j < num_labels_dif; ++j) {
+								box t = { truth_dif[j].x, truth_dif[j].y, truth_dif[j].w, truth_dif[j].h };
+								float current_iou = box_iou(boxes[i], t);
+								if (current_iou > iou_thresh && class_id == truth_dif[j].id) {
+									--detections_count;
+								}
+							}
+						}
 					}
 				}
 			}
@@ -685,7 +720,6 @@
 			pr[d.class_id][rank].fp++;	// false-positive
 		}
 
-
 		for (i = 0; i < classes; ++i) 
 		{
 			const int tp = pr[i][rank].tp;

--
Gitblit v1.10.0