AlexeyAB
2018-02-16 baf3fcb491ee1a5e083fbdfdf2c28aaf36488e92
src/detector.c
@@ -499,9 +499,9 @@
{
   int j;
   list *options = read_data_cfg(datacfg);
   char *valid_images = option_find_str(options, "valid", "data/train.list");
   char *valid_images = option_find_str(options, "valid", "data/train.txt");
   char *difficult_valid_images = option_find_str(options, "difficult", NULL);
   char *name_list = option_find_str(options, "names", "data/names.list");
   //char *prefix = option_find_str(options, "results", "results");
   char **names = get_labels(name_list);
   char *mapf = option_find_str(options, "map", 0);
   int *map = 0;
@@ -515,10 +515,16 @@
   fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
   srand(time(0));
   char *base = "comp4_det_test_";
   list *plist = get_paths(valid_images);
   char **paths = (char **)list_to_array(plist);
   char **paths_dif = NULL;
   if (difficult_valid_images) {
      list *plist_dif = get_paths(difficult_valid_images);
      paths_dif = (char **)list_to_array(plist_dif);
   }
   layer l = net.layers[net.n - 1];
   int classes = l.classes;
@@ -546,9 +552,14 @@
   args.h = net.h;
   args.type = IMAGE_DATA;
   const float thresh_calc_avg_iou = 0.24;
   float avg_iou = 0;
   int tp_for_thresh = 0;
   int fp_for_thresh = 0;
   box_prob *detections = calloc(1, sizeof(box_prob));
   int detections_count = 0;
   int unique_truth_index = 0;
   int unique_truth_count = 0;
   int *truth_classes_count = calloc(classes, sizeof(int));
@@ -574,7 +585,7 @@
      }
      for (t = 0; t < nthreads && i + t - nthreads < m; ++t) {
         const int image_index = i + t - nthreads;
         char *path = paths[i + t - nthreads];
         char *path = paths[image_index];
         char *id = basecfg(path);
         float *X = val_resized[t].data;
         network_predict(net, X);
@@ -594,6 +605,22 @@
            truth_classes_count[truth[j].id]++;
         }
         // difficult
         box_label *truth_dif = NULL;
         int num_labels_dif = 0;
         if (paths_dif)
         {
            char *path_dif = paths_dif[image_index];
            char labelpath_dif[4096];
            find_replace(path_dif, "images", "labels", labelpath_dif);
            find_replace(labelpath_dif, "JPEGImages", "labels", labelpath_dif);
            find_replace(labelpath_dif, ".jpg", ".txt", labelpath_dif);
            find_replace(labelpath_dif, ".JPEG", ".txt", labelpath_dif);
            find_replace(labelpath_dif, ".png", ".txt", labelpath_dif);
            truth_dif = read_boxes(labelpath_dif, &num_labels_dif);
         }
         for (i = 0; i < (l.w*l.h*l.n); ++i) {
            int class_id;
@@ -606,6 +633,8 @@
                  detections[detections_count - 1].p = prob;
                  detections[detections_count - 1].image_index = image_index;
                  detections[detections_count - 1].class_id = class_id;
                  detections[detections_count - 1].truth_flag = 0;
                  detections[detections_count - 1].unique_truth_index = -1;
                  int truth_index = -1;
                  float max_iou = 0;
@@ -617,21 +646,43 @@
                     float current_iou = box_iou(boxes[i], t);
                     if (current_iou > iou_thresh && class_id == truth[j].id) {
                        if (current_iou > max_iou) {
                           current_iou = max_iou;
                           truth_index = unique_truth_index + j;
                           max_iou = current_iou;
                           truth_index = unique_truth_count + j;
                        }
                     }
                  }
                  // best IoU
                  if (truth_index > -1) {
                     detections[detections_count - 1].truth_flag = 1;
                     detections[detections_count - 1].unique_truth_index = truth_index;
                  }
                  else {
                     // if object is difficult then remove detection
                     for (j = 0; j < num_labels_dif; ++j) {
                        box t = { truth_dif[j].x, truth_dif[j].y, truth_dif[j].w, truth_dif[j].h };
                        float current_iou = box_iou(boxes[i], t);
                        if (current_iou > iou_thresh && class_id == truth_dif[j].id) {
                           --detections_count;
                           break;
                        }
                     }
                  }
                  // calc avg IoU, true-positives, false-positives for required Threshold
                  if (prob > thresh_calc_avg_iou) {
                     if (truth_index > -1) {
                        avg_iou += max_iou;
                        ++tp_for_thresh;
                     }
                     else
                        fp_for_thresh++;
                  }
               }
            }
         }
         
         unique_truth_index += num_labels;
         unique_truth_count += num_labels;
         free(id);
         free_image(val[t]);
@@ -639,6 +690,8 @@
      }
   }
   avg_iou = avg_iou / (tp_for_thresh + fp_for_thresh);
   
   // SORT(detections)
   qsort(detections, detections_count, sizeof(box_prob), detections_comparator);
@@ -654,10 +707,10 @@
   for (i = 0; i < classes; ++i) {
      pr[i] = calloc(detections_count, sizeof(pr_t));
   }
   printf("detections_count = %d, unique_truth_index = %d  \n", detections_count, unique_truth_index);
   printf("detections_count = %d, unique_truth_count = %d  \n", detections_count, unique_truth_count);
   int *truth_flags = calloc(unique_truth_index, sizeof(int));
   int *truth_flags = calloc(unique_truth_count, sizeof(int));
   int rank;
   for (rank = 0; rank < detections_count; ++rank) {
@@ -685,7 +738,6 @@
         pr[d.class_id][rank].fp++; // false-positive
      }
      for (i = 0; i < classes; ++i) 
      {
         const int tp = pr[i][rank].tp;
@@ -729,6 +781,9 @@
      mean_average_precision += avg_precision;
   }
   
   printf(" for thresh = %0.2f, TP = %d, FP = %d, FN = %d, average IoU = %2.2f %% \n",
      thresh_calc_avg_iou, tp_for_thresh, fp_for_thresh, unique_truth_count - tp_for_thresh, avg_iou * 100);
   mean_average_precision = mean_average_precision / classes;
   printf("\n mean average precision (mAP) = %f, or %2.2f %% \n", mean_average_precision, mean_average_precision*100);