From 0cbfa4646128206300b9a30586615c3698abfb76 Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Fri, 08 May 2015 17:33:47 +0000
Subject: [PATCH] stuff

---
 src/network.c |   30 +++++++++++++++++++++++++++---
 1 files changed, 27 insertions(+), 3 deletions(-)

diff --git a/src/network.c b/src/network.c
index 3247a31..01a6128 100644
--- a/src/network.c
+++ b/src/network.c
@@ -4,7 +4,6 @@
 #include "image.h"
 #include "data.h"
 #include "utils.h"
-#include "params.h"
 
 #include "crop_layer.h"
 #include "connected_layer.h"
@@ -16,6 +15,7 @@
 #include "normalization_layer.h"
 #include "softmax_layer.h"
 #include "dropout_layer.h"
+#include "route_layer.h"
 
 char *get_layer_string(LAYER_TYPE a)
 {
@@ -40,6 +40,8 @@
             return "crop";
         case COST:
             return "cost";
+        case ROUTE:
+            return "route";
         default:
             break;
     }
@@ -99,6 +101,9 @@
         else if(net.types[i] == DROPOUT){
             forward_dropout_layer(*(dropout_layer *)net.layers[i], state);
         }
+        else if(net.types[i] == ROUTE){
+            forward_route_layer(*(route_layer *)net.layers[i], net);
+        }
         state.input = get_network_output_layer(net, i);
     }
 }
@@ -143,6 +148,8 @@
         return ((crop_layer *)net.layers[i]) -> output;
     } else if(net.types[i] == NORMALIZATION){
         return ((normalization_layer *)net.layers[i]) -> output;
+    } else if(net.types[i] == ROUTE){
+        return ((route_layer *)net.layers[i]) -> output;
     }
     return 0;
 }
@@ -177,6 +184,8 @@
     } else if(net.types[i] == CONNECTED){
         connected_layer layer = *(connected_layer *)net.layers[i];
         return layer.delta;
+    } else if(net.types[i] == ROUTE){
+        return ((route_layer *)net.layers[i]) -> delta;
     }
     return 0;
 }
@@ -247,10 +256,12 @@
         else if(net.types[i] == CONNECTED){
             connected_layer layer = *(connected_layer *)net.layers[i];
             backward_connected_layer(layer, state);
-        }
-        else if(net.types[i] == COST){
+        } else if(net.types[i] == COST){
             cost_layer layer = *(cost_layer *)net.layers[i];
             backward_cost_layer(layer, state);
+        } else if(net.types[i] == ROUTE){
+            route_layer layer = *(route_layer *)net.layers[i];
+            backward_route_layer(layer, net);
         }
     }
 }
@@ -369,6 +380,10 @@
             crop_layer *layer = (crop_layer *)net->layers[i];
             layer->batch = b;
         }
+        else if(net->types[i] == ROUTE){
+            route_layer *layer = (route_layer *)net->layers[i];
+            layer->batch = b;
+        }
     }
 }
 
@@ -445,12 +460,17 @@
         softmax_layer layer = *(softmax_layer *)net.layers[i];
         return layer.inputs;
     }
+    else if(net.types[i] == ROUTE){
+        route_layer layer = *(route_layer *)net.layers[i];
+        return layer.outputs;
+    }
     fprintf(stderr, "Can't find output size\n");
     return 0;
 }
 
 int resize_network(network net, int h, int w, int c)
 {
+    fprintf(stderr, "Might be broken, careful!!");
     int i;
     for (i = 0; i < net.n; ++i){
         if(net.types[i] == CONVOLUTIONAL){
@@ -540,6 +560,10 @@
         crop_layer layer = *(crop_layer *)net.layers[i];
         return get_crop_image(layer);
     }
+    else if(net.types[i] == ROUTE){
+        route_layer layer = *(route_layer *)net.layers[i];
+        return get_network_image_layer(net, layer.input_layers[0]);
+    }
     return make_empty_image(0,0,0);
 }
 

--
Gitblit v1.10.0