From 481b57a96a9ef29b112caec1bb3e17ffb043ceae Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Sun, 25 Sep 2016 06:12:54 +0000
Subject: [PATCH] So I have this new programming paradigm.......

---
 src/box.c |   71 ++++++++++++++++++++++++++++++++++-
 1 files changed, 68 insertions(+), 3 deletions(-)

diff --git a/src/box.c b/src/box.c
index d49be41..9568599 100644
--- a/src/box.c
+++ b/src/box.c
@@ -1,6 +1,17 @@
 #include "box.h"
 #include <stdio.h>
 #include <math.h>
+#include <stdlib.h>
+
+box float_to_box(float *f)
+{
+    box b;
+    b.x = f[0];
+    b.y = f[1];
+    b.w = f[2];
+    b.h = f[3];
+    return b;
+}
 
 dbox derivative(box a, box b)
 {
@@ -85,6 +96,14 @@
     return box_intersection(a, b)/box_union(a, b);
 }
 
+float box_rmse(box a, box b)
+{
+    return sqrt(pow(a.x-b.x, 2) + 
+                pow(a.y-b.y, 2) + 
+                pow(a.w-b.w, 2) + 
+                pow(a.h-b.h, 2));
+}
+
 dbox dintersect(box a, box b)
 {
     float w = overlap(a.x, a.w, b.x, b.w);
@@ -211,16 +230,62 @@
     return dd;
 }
 
-void do_nms(box *boxes, float **probs, int num_boxes, int classes, float thresh)
+typedef struct{
+    int index;
+    int class;
+    float **probs;
+} sortable_bbox;
+
+int nms_comparator(const void *pa, const void *pb)
+{
+    sortable_bbox a = *(sortable_bbox *)pa;
+    sortable_bbox b = *(sortable_bbox *)pb;
+    float diff = a.probs[a.index][b.class] - b.probs[b.index][b.class];
+    if(diff < 0) return 1;
+    else if(diff > 0) return -1;
+    return 0;
+}
+
+void do_nms_sort(box *boxes, float **probs, int total, int classes, float thresh)
 {
     int i, j, k;
-    for(i = 0; i < num_boxes*num_boxes; ++i){
+    sortable_bbox *s = calloc(total, sizeof(sortable_bbox));
+
+    for(i = 0; i < total; ++i){
+        s[i].index = i;       
+        s[i].class = 0;
+        s[i].probs = probs;
+    }
+
+    for(k = 0; k < classes; ++k){
+        for(i = 0; i < total; ++i){
+            s[i].class = k;
+        }
+        qsort(s, total, sizeof(sortable_bbox), nms_comparator);
+        for(i = 0; i < total; ++i){
+            if(probs[s[i].index][k] == 0) continue;
+            box a = boxes[s[i].index];
+            for(j = i+1; j < total; ++j){
+                box b = boxes[s[j].index];
+                if (box_iou(a, b) > thresh){
+                    probs[s[j].index][k] = 0;
+                }
+            }
+        }
+    }
+    free(s);
+}
+
+void do_nms(box *boxes, float **probs, int total, int classes, float thresh)
+{
+    int i, j, k;
+    for(i = 0; i < total; ++i){
         int any = 0;
         for(k = 0; k < classes; ++k) any = any || (probs[i][k] > 0);
         if(!any) {
             continue;
         }
-        for(j = i+1; j < num_boxes*num_boxes; ++j){
+        for(j = i+1; j < total; ++j){
             if (box_iou(boxes[i], boxes[j]) > thresh){
                 for(k = 0; k < classes; ++k){
                     if (probs[i][k] < probs[j][k]) probs[i][k] = 0;

--
Gitblit v1.10.0