From 46e1b263e1f9a37da4df224b11937d2480eb27d9 Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Fri, 15 May 2015 17:25:05 +0000
Subject: [PATCH] testing other losses

---
 src/detection_layer.c |   39 ++++++++++++++++++++-------------------
 1 files changed, 20 insertions(+), 19 deletions(-)

diff --git a/src/detection_layer.c b/src/detection_layer.c
index 395146b..dd68244 100644
--- a/src/detection_layer.c
+++ b/src/detection_layer.c
@@ -330,9 +330,8 @@
             l.output[out_i++] = mask*state.input[in_i++];
         }
     }
-    if(l.does_cost && state.train && 0){
+    if(l.does_cost && state.train){
         int count = 0;
-        float avg = 0;
         *(l.cost) = 0;
         int size = get_detection_layer_output_size(l) * l.batch;
         memset(l.delta, 0, size * sizeof(float));
@@ -354,26 +353,28 @@
             out.w = l.output[j+2];
             out.h = l.output[j+3];
             if(!(truth.w*truth.h)) continue;
-            //printf("iou: %f\n", iou);
-            dbox d = diou(out, truth);
-            l.delta[j+0] = d.dx;
-            l.delta[j+1] = d.dy;
-            l.delta[j+2] = d.dw;
-            l.delta[j+3] = d.dh;
+            l.delta[j+0] = (truth.x - out.x);
+            l.delta[j+1] = (truth.y - out.y);
+            l.delta[j+2] = (truth.w - out.w);
+            l.delta[j+3] = (truth.h - out.h);
+            *(l.cost) += pow((out.x - truth.x), 2);
+            *(l.cost) += pow((out.y - truth.y), 2);
+            *(l.cost) += pow((out.w - truth.w), 2);
+            *(l.cost) += pow((out.h - truth.h), 2);
 
-            int sqr = 1;
-            if(sqr){
-                truth.w *= truth.w;
-                truth.h *= truth.h;
-                out.w *= out.w;
-                out.h *= out.h;
-            }
-            float iou = box_iou(truth, out);
-            *(l.cost) += pow((1-iou), 2);
-            avg += iou;
+/*
+            l.delta[j+0] = .1 * (truth.x - out.x) / (49 * truth.w * truth.w);
+            l.delta[j+1] = .1 * (truth.y - out.y) / (49 * truth.h * truth.h);
+            l.delta[j+2] = .1 * (truth.w - out.w) / (     truth.w * truth.w);
+            l.delta[j+3] = .1 * (truth.h - out.h) / (     truth.h * truth.h);
+
+            *(l.cost) += pow((out.x - truth.x)/truth.w/7., 2);
+            *(l.cost) += pow((out.y - truth.y)/truth.h/7., 2);
+            *(l.cost) += pow((out.w - truth.w)/truth.w, 2);
+            *(l.cost) += pow((out.h - truth.h)/truth.h, 2);
+            */
             ++count;
         }
-        fprintf(stderr, "Avg IOU: %f\n", avg/count);
     }
     /*
        int count = 0;

--
Gitblit v1.10.0