From b8e6e80c6d411d05a9e09f1e3676eb9a7f3ea0e8 Mon Sep 17 00:00:00 2001
From: AlexeyAB <alexeyab84@gmail.com>
Date: Fri, 03 Aug 2018 11:35:03 +0000
Subject: [PATCH] Added spatial Yolo v3 yolov3-spp.cfg

---
 src/box.c |   96 +++++++++++++++++++++++++++++++++++++++++++++--
 1 files changed, 91 insertions(+), 5 deletions(-)

diff --git a/src/box.c b/src/box.c
index 9568599..718215f 100644
--- a/src/box.c
+++ b/src/box.c
@@ -232,7 +232,7 @@
 
 typedef struct{
     int index;
-    int class;
+    int class_id;
     float **probs;
 } sortable_bbox;
 
@@ -240,26 +240,26 @@
 {
     sortable_bbox a = *(sortable_bbox *)pa;
     sortable_bbox b = *(sortable_bbox *)pb;
-    float diff = a.probs[a.index][b.class] - b.probs[b.index][b.class];
+    float diff = a.probs[a.index][b.class_id] - b.probs[b.index][b.class_id];
     if(diff < 0) return 1;
     else if(diff > 0) return -1;
     return 0;
 }
 
-void do_nms_sort(box *boxes, float **probs, int total, int classes, float thresh)
+void do_nms_sort_v2(box *boxes, float **probs, int total, int classes, float thresh)
 {
     int i, j, k;
     sortable_bbox *s = calloc(total, sizeof(sortable_bbox));
 
     for(i = 0; i < total; ++i){
         s[i].index = i;       
-        s[i].class = 0;
+        s[i].class_id = 0;
         s[i].probs = probs;
     }
 
     for(k = 0; k < classes; ++k){
         for(i = 0; i < total; ++i){
-            s[i].class = k;
+            s[i].class_id = k;
         }
         qsort(s, total, sizeof(sortable_bbox), nms_comparator);
         for(i = 0; i < total; ++i){
@@ -276,6 +276,92 @@
     free(s);
 }
 
+int nms_comparator_v3(const void *pa, const void *pb)
+{
+    detection a = *(detection *)pa;
+    detection b = *(detection *)pb;
+    float diff = 0;
+    if (b.sort_class >= 0) {
+        diff = a.prob[b.sort_class] - b.prob[b.sort_class];
+    }
+    else {
+        diff = a.objectness - b.objectness;
+    }
+    if (diff < 0) return 1;
+    else if (diff > 0) return -1;
+    return 0;
+}
+
+void do_nms_obj(detection *dets, int total, int classes, float thresh)
+{
+    int i, j, k;
+    k = total - 1;
+    for (i = 0; i <= k; ++i) {
+        if (dets[i].objectness == 0) {
+            detection swap = dets[i];
+            dets[i] = dets[k];
+            dets[k] = swap;
+            --k;
+            --i;
+        }
+    }
+    total = k + 1;
+
+    for (i = 0; i < total; ++i) {
+        dets[i].sort_class = -1;
+    }
+
+    qsort(dets, total, sizeof(detection), nms_comparator_v3);
+    for (i = 0; i < total; ++i) {
+        if (dets[i].objectness == 0) continue;
+        box a = dets[i].bbox;
+        for (j = i + 1; j < total; ++j) {
+            if (dets[j].objectness == 0) continue;
+            box b = dets[j].bbox;
+            if (box_iou(a, b) > thresh) {
+                dets[j].objectness = 0;
+                for (k = 0; k < classes; ++k) {
+                    dets[j].prob[k] = 0;
+                }
+            }
+        }
+    }
+}
+
+void do_nms_sort(detection *dets, int total, int classes, float thresh)
+{
+    int i, j, k;
+    k = total - 1;
+    for (i = 0; i <= k; ++i) {
+        if (dets[i].objectness == 0) {
+            detection swap = dets[i];
+            dets[i] = dets[k];
+            dets[k] = swap;
+            --k;
+            --i;
+        }
+    }
+    total = k + 1;
+
+    for (k = 0; k < classes; ++k) {
+        for (i = 0; i < total; ++i) {
+            dets[i].sort_class = k;
+        }
+        qsort(dets, total, sizeof(detection), nms_comparator_v3);
+        for (i = 0; i < total; ++i) {
+            //printf("  k = %d, \t i = %d \n", k, i);
+            if (dets[i].prob[k] == 0) continue;
+            box a = dets[i].bbox;
+            for (j = i + 1; j < total; ++j) {
+                box b = dets[j].bbox;
+                if (box_iou(a, b) > thresh) {
+                    dets[j].prob[k] = 0;
+                }
+            }
+        }
+    }
+}
+
 void do_nms(box *boxes, float **probs, int total, int classes, float thresh)
 {
     int i, j, k;

--
Gitblit v1.10.0