From d8adaf8ea6a31a380f6bf1fe65e88b661d3bb51e Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Fri, 21 Oct 2016 20:16:43 +0000
Subject: [PATCH] tree stuff

---
 src/classifier.c |   44 +++++++++++++++++++++++++++++++++-----------
 1 files changed, 33 insertions(+), 11 deletions(-)

diff --git a/src/classifier.c b/src/classifier.c
index 208b7ed..e588af5 100644
--- a/src/classifier.c
+++ b/src/classifier.c
@@ -41,6 +41,20 @@
     return options;
 }
 
+void hierarchy_predictions(float *predictions, int n, tree *hier)
+{
+    int j;
+    for(j = 0; j < n; ++j){
+        int parent = hier->parent[j];
+        if(parent >= 0){
+            predictions[j] *= predictions[parent]; 
+        }
+    }
+    for(j = 0; j < n; ++j){
+        if(!hier->leaf[j]) predictions[j] = 0;
+    }
+}
+
 float *get_regression_values(char **labels, int n)
 {
     float *v = calloc(n, sizeof(float));
@@ -99,7 +113,8 @@
     load_args args = {0};
     args.w = net.w;
     args.h = net.h;
-    args.threads = 16;
+    args.threads = 32;
+    args.hierarchy = net.hierarchy;
 
     args.min = net.min_crop;
     args.max = net.max_crop;
@@ -206,6 +221,7 @@
     args.saturation = net.saturation;
     args.hue = net.hue;
     args.size = net.w;
+    args.hierarchy = net.hierarchy;
 
     args.paths = paths;
     args.classes = classes;
@@ -394,6 +410,7 @@
         float *pred = calloc(classes, sizeof(float));
         for(j = 0; j < 10; ++j){
             float *p = network_predict(net, images[j].data);
+            if(net.hierarchy) hierarchy_predictions(p, net.outputs, net.hierarchy);
             axpy_cpu(classes, 1, p, 1, pred, 1);
             free_image(images[j]);
         }
@@ -454,6 +471,7 @@
         //show_image(crop, "cropped");
         //cvWaitKey(0);
         float *pred = network_predict(net, resized.data);
+        if(net.hierarchy) hierarchy_predictions(pred, net.outputs, net.hierarchy);
 
         free_image(im);
         free_image(resized);
@@ -513,6 +531,7 @@
         //show_image(crop, "cropped");
         //cvWaitKey(0);
         float *pred = network_predict(net, crop.data);
+        if(net.hierarchy) hierarchy_predictions(pred, net.outputs, net.hierarchy);
 
         if(resized.data != im.data) free_image(resized);
         free_image(im);
@@ -573,6 +592,7 @@
             image r = resize_min(im, scales[j]);
             resize_network(&net, r.w, r.h);
             float *p = network_predict(net, r.data);
+            if(net.hierarchy) hierarchy_predictions(p, net.outputs, net.hierarchy);
             axpy_cpu(classes, 1, p, 1, pred, 1);
             flip_image(r);
             p = network_predict(net, r.data);
@@ -672,7 +692,6 @@
     }
 }
 
-
 void predict_classifier(char *datacfg, char *cfgfile, char *weightfile, char *filename)
 {
     network net = parse_network_cfg(cfgfile);
@@ -713,11 +732,13 @@
         float *X = r.data;
         time=clock();
         float *predictions = network_predict(net, X);
-        top_predictions(net, top, indexes);
+        if(net.hierarchy) hierarchy_predictions(predictions, net.outputs, net.hierarchy);
+        top_k(predictions, net.outputs, top, indexes);
         printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time));
         for(i = 0; i < top; ++i){
             int index = indexes[i];
-            printf("%s: %f\n", names[index], predictions[index]);
+            if(net.hierarchy) printf("%d, %s: %f, parent: %s \n",index, names[index], predictions[index], (net.hierarchy->parent[index] >= 0) ? names[net.hierarchy->parent[index]] : "Root");
+            else printf("%s: %f\n",names[index], predictions[index]);
         }
         if(r.data != im.data) free_image(r);
         free_image(im);
@@ -899,15 +920,15 @@
         float curr_threat = 0;
         if(1){
             curr_threat = predictions[0] * 0 + 
-                            predictions[1] * .6 + 
-                            predictions[2];
+                predictions[1] * .6 + 
+                predictions[2];
         } else {
             curr_threat = predictions[218] +
-                        predictions[539] + 
-                        predictions[540] + 
-                        predictions[368] + 
-                        predictions[369] + 
-                        predictions[370];
+                predictions[539] + 
+                predictions[540] + 
+                predictions[368] + 
+                predictions[369] + 
+                predictions[370];
         }
         threat = roll * curr_threat + (1-roll) * threat;
 
@@ -1092,6 +1113,7 @@
         show_image(in, "Classifier");
 
         float *predictions = network_predict(net, in_s.data);
+        if(net.hierarchy) hierarchy_predictions(predictions, net.outputs, net.hierarchy);
         top_predictions(net, top, indexes);
 
         printf("\033[2J");

--
Gitblit v1.10.0