From 46e1b263e1f9a37da4df224b11937d2480eb27d9 Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Fri, 15 May 2015 17:25:05 +0000
Subject: [PATCH] testing other losses

---
 src/detection.c       |    4 ++--
 src/data.c            |    2 +-
 src/detection_layer.c |   39 ++++++++++++++++++++-------------------
 3 files changed, 23 insertions(+), 22 deletions(-)

diff --git a/src/data.c b/src/data.c
index 902f30c..8e290c4 100644
--- a/src/data.c
+++ b/src/data.c
@@ -165,7 +165,7 @@
 
         w = constrain(0, 1, w);
         h = constrain(0, 1, h);
-        if (w == 0 || h == 0) continue;
+        if (w < .01 || h < .01) continue;
         if(1){
             //w = sqrt(w);
             //h = sqrt(h);
diff --git a/src/detection.c b/src/detection.c
index a1ba888..160fa60 100644
--- a/src/detection.c
+++ b/src/detection.c
@@ -309,8 +309,8 @@
                 float y = (pred.vals[j][ci + 1] + row)/num_boxes;
                 float w = pred.vals[j][ci + 2]; //* distance_from_edge(row, num_boxes);
                 float h = pred.vals[j][ci + 3]; //* distance_from_edge(col, num_boxes);
-                w = pow(w, 2);
-                h = pow(h, 2);
+                w = pow(w, 1);
+                h = pow(h, 1);
                 float prob = scale*pred.vals[j][k+class+background+nuisance];
                 if(prob < threshold) continue;
                 printf("%d %d %f %f %f %f %f\n", offset +  j, class, prob, x, y, w, h);
diff --git a/src/detection_layer.c b/src/detection_layer.c
index 395146b..dd68244 100644
--- a/src/detection_layer.c
+++ b/src/detection_layer.c
@@ -330,9 +330,8 @@
             l.output[out_i++] = mask*state.input[in_i++];
         }
     }
-    if(l.does_cost && state.train && 0){
+    if(l.does_cost && state.train){
         int count = 0;
-        float avg = 0;
         *(l.cost) = 0;
         int size = get_detection_layer_output_size(l) * l.batch;
         memset(l.delta, 0, size * sizeof(float));
@@ -354,26 +353,28 @@
             out.w = l.output[j+2];
             out.h = l.output[j+3];
             if(!(truth.w*truth.h)) continue;
-            //printf("iou: %f\n", iou);
-            dbox d = diou(out, truth);
-            l.delta[j+0] = d.dx;
-            l.delta[j+1] = d.dy;
-            l.delta[j+2] = d.dw;
-            l.delta[j+3] = d.dh;
+            l.delta[j+0] = (truth.x - out.x);
+            l.delta[j+1] = (truth.y - out.y);
+            l.delta[j+2] = (truth.w - out.w);
+            l.delta[j+3] = (truth.h - out.h);
+            *(l.cost) += pow((out.x - truth.x), 2);
+            *(l.cost) += pow((out.y - truth.y), 2);
+            *(l.cost) += pow((out.w - truth.w), 2);
+            *(l.cost) += pow((out.h - truth.h), 2);
 
-            int sqr = 1;
-            if(sqr){
-                truth.w *= truth.w;
-                truth.h *= truth.h;
-                out.w *= out.w;
-                out.h *= out.h;
-            }
-            float iou = box_iou(truth, out);
-            *(l.cost) += pow((1-iou), 2);
-            avg += iou;
+/*
+            l.delta[j+0] = .1 * (truth.x - out.x) / (49 * truth.w * truth.w);
+            l.delta[j+1] = .1 * (truth.y - out.y) / (49 * truth.h * truth.h);
+            l.delta[j+2] = .1 * (truth.w - out.w) / (     truth.w * truth.w);
+            l.delta[j+3] = .1 * (truth.h - out.h) / (     truth.h * truth.h);
+
+            *(l.cost) += pow((out.x - truth.x)/truth.w/7., 2);
+            *(l.cost) += pow((out.y - truth.y)/truth.h/7., 2);
+            *(l.cost) += pow((out.w - truth.w)/truth.w, 2);
+            *(l.cost) += pow((out.h - truth.h)/truth.h, 2);
+            */
             ++count;
         }
-        fprintf(stderr, "Avg IOU: %f\n", avg/count);
     }
     /*
        int count = 0;

--
Gitblit v1.10.0