From 0cbfa4646128206300b9a30586615c3698abfb76 Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Fri, 08 May 2015 17:33:47 +0000
Subject: [PATCH] stuff

---
 src/network.c |   63 +++++++++++++++++++++----------
 1 files changed, 42 insertions(+), 21 deletions(-)

diff --git a/src/network.c b/src/network.c
index 61200d3..01a6128 100644
--- a/src/network.c
+++ b/src/network.c
@@ -4,7 +4,6 @@
 #include "image.h"
 #include "data.h"
 #include "utils.h"
-#include "params.h"
 
 #include "crop_layer.h"
 #include "connected_layer.h"
@@ -16,6 +15,7 @@
 #include "normalization_layer.h"
 #include "softmax_layer.h"
 #include "dropout_layer.h"
+#include "route_layer.h"
 
 char *get_layer_string(LAYER_TYPE a)
 {
@@ -40,6 +40,8 @@
             return "crop";
         case COST:
             return "cost";
+        case ROUTE:
+            return "route";
         default:
             break;
     }
@@ -99,6 +101,9 @@
         else if(net.types[i] == DROPOUT){
             forward_dropout_layer(*(dropout_layer *)net.layers[i], state);
         }
+        else if(net.types[i] == ROUTE){
+            forward_route_layer(*(route_layer *)net.layers[i], net);
+        }
         state.input = get_network_output_layer(net, i);
     }
 }
@@ -143,6 +148,8 @@
         return ((crop_layer *)net.layers[i]) -> output;
     } else if(net.types[i] == NORMALIZATION){
         return ((normalization_layer *)net.layers[i]) -> output;
+    } else if(net.types[i] == ROUTE){
+        return ((route_layer *)net.layers[i]) -> output;
     }
     return 0;
 }
@@ -177,6 +184,8 @@
     } else if(net.types[i] == CONNECTED){
         connected_layer layer = *(connected_layer *)net.layers[i];
         return layer.delta;
+    } else if(net.types[i] == ROUTE){
+        return ((route_layer *)net.layers[i]) -> delta;
     }
     return 0;
 }
@@ -186,6 +195,9 @@
     if(net.types[net.n-1] == COST){
         return ((cost_layer *)net.layers[net.n-1])->output[0];
     }
+    if(net.types[net.n-1] == DETECTION){
+        return ((detection_layer *)net.layers[net.n-1])->cost[0];
+    }
     return 0;
 }
 
@@ -194,24 +206,6 @@
     return get_network_delta_layer(net, net.n-1);
 }
 
-float calculate_error_network(network net, float *truth)
-{
-    float sum = 0;
-    float *delta = get_network_delta(net);
-    float *out = get_network_output(net);
-    int i;
-    for(i = 0; i < get_network_output_size(net)*net.batch; ++i){
-        //if(i %get_network_output_size(net) == 0) printf("\n");
-        //printf("%5.2f %5.2f, ", out[i], truth[i]);
-        //if(i == get_network_output_size(net)) printf("\n");
-        delta[i] = truth[i] - out[i];
-        //printf("%.10f, ", out[i]);
-        sum += delta[i]*delta[i];
-    }
-    //printf("\n");
-    return sum;
-}
-
 int get_predicted_class_network(network net)
 {
     float *out = get_network_output(net);
@@ -262,10 +256,12 @@
         else if(net.types[i] == CONNECTED){
             connected_layer layer = *(connected_layer *)net.layers[i];
             backward_connected_layer(layer, state);
-        }
-        else if(net.types[i] == COST){
+        } else if(net.types[i] == COST){
             cost_layer layer = *(cost_layer *)net.layers[i];
             backward_cost_layer(layer, state);
+        } else if(net.types[i] == ROUTE){
+            route_layer layer = *(route_layer *)net.layers[i];
+            backward_route_layer(layer, net);
         }
     }
 }
@@ -384,6 +380,10 @@
             crop_layer *layer = (crop_layer *)net->layers[i];
             layer->batch = b;
         }
+        else if(net->types[i] == ROUTE){
+            route_layer *layer = (route_layer *)net->layers[i];
+            layer->batch = b;
+        }
     }
 }
 
@@ -460,12 +460,17 @@
         softmax_layer layer = *(softmax_layer *)net.layers[i];
         return layer.inputs;
     }
+    else if(net.types[i] == ROUTE){
+        route_layer layer = *(route_layer *)net.layers[i];
+        return layer.outputs;
+    }
     fprintf(stderr, "Can't find output size\n");
     return 0;
 }
 
 int resize_network(network net, int h, int w, int c)
 {
+    fprintf(stderr, "Might be broken, careful!!");
     int i;
     for (i = 0; i < net.n; ++i){
         if(net.types[i] == CONVOLUTIONAL){
@@ -518,6 +523,18 @@
     return get_network_input_size_layer(net, 0);
 }
 
+detection_layer *get_network_detection_layer(network net)
+{
+    int i;
+    for(i = 0; i < net.n; ++i){
+        if(net.types[i] == DETECTION){
+            detection_layer *layer = (detection_layer *)net.layers[i];
+            return layer;
+        }
+    }
+    return 0;
+}
+
 image get_network_image_layer(network net, int i)
 {
     if(net.types[i] == CONVOLUTIONAL){
@@ -543,6 +560,10 @@
         crop_layer layer = *(crop_layer *)net.layers[i];
         return get_crop_image(layer);
     }
+    else if(net.types[i] == ROUTE){
+        route_layer layer = *(route_layer *)net.layers[i];
+        return get_network_image_layer(net, layer.input_layers[0]);
+    }
     return make_empty_image(0,0,0);
 }
 

--
Gitblit v1.10.0