From 0cbfa4646128206300b9a30586615c3698abfb76 Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Fri, 08 May 2015 17:33:47 +0000
Subject: [PATCH] stuff
---
src/network.c | 30 +++++++++++++++++++++++++++---
1 files changed, 27 insertions(+), 3 deletions(-)
diff --git a/src/network.c b/src/network.c
index 3247a31..01a6128 100644
--- a/src/network.c
+++ b/src/network.c
@@ -4,7 +4,6 @@
#include "image.h"
#include "data.h"
#include "utils.h"
-#include "params.h"
#include "crop_layer.h"
#include "connected_layer.h"
@@ -16,6 +15,7 @@
#include "normalization_layer.h"
#include "softmax_layer.h"
#include "dropout_layer.h"
+#include "route_layer.h"
char *get_layer_string(LAYER_TYPE a)
{
@@ -40,6 +40,8 @@
return "crop";
case COST:
return "cost";
+ case ROUTE:
+ return "route";
default:
break;
}
@@ -99,6 +101,9 @@
else if(net.types[i] == DROPOUT){
forward_dropout_layer(*(dropout_layer *)net.layers[i], state);
}
+ else if(net.types[i] == ROUTE){
+ forward_route_layer(*(route_layer *)net.layers[i], net);
+ }
state.input = get_network_output_layer(net, i);
}
}
@@ -143,6 +148,8 @@
return ((crop_layer *)net.layers[i]) -> output;
} else if(net.types[i] == NORMALIZATION){
return ((normalization_layer *)net.layers[i]) -> output;
+ } else if(net.types[i] == ROUTE){
+ return ((route_layer *)net.layers[i]) -> output;
}
return 0;
}
@@ -177,6 +184,8 @@
} else if(net.types[i] == CONNECTED){
connected_layer layer = *(connected_layer *)net.layers[i];
return layer.delta;
+ } else if(net.types[i] == ROUTE){
+ return ((route_layer *)net.layers[i]) -> delta;
}
return 0;
}
@@ -247,10 +256,12 @@
else if(net.types[i] == CONNECTED){
connected_layer layer = *(connected_layer *)net.layers[i];
backward_connected_layer(layer, state);
- }
- else if(net.types[i] == COST){
+ } else if(net.types[i] == COST){
cost_layer layer = *(cost_layer *)net.layers[i];
backward_cost_layer(layer, state);
+ } else if(net.types[i] == ROUTE){
+ route_layer layer = *(route_layer *)net.layers[i];
+ backward_route_layer(layer, net);
}
}
}
@@ -369,6 +380,10 @@
crop_layer *layer = (crop_layer *)net->layers[i];
layer->batch = b;
}
+ else if(net->types[i] == ROUTE){
+ route_layer *layer = (route_layer *)net->layers[i];
+ layer->batch = b;
+ }
}
}
@@ -445,12 +460,17 @@
softmax_layer layer = *(softmax_layer *)net.layers[i];
return layer.inputs;
}
+ else if(net.types[i] == ROUTE){
+ route_layer layer = *(route_layer *)net.layers[i];
+ return layer.outputs;
+ }
fprintf(stderr, "Can't find output size\n");
return 0;
}
int resize_network(network net, int h, int w, int c)
{
+ fprintf(stderr, "Might be broken, careful!!");
int i;
for (i = 0; i < net.n; ++i){
if(net.types[i] == CONVOLUTIONAL){
@@ -540,6 +560,10 @@
crop_layer layer = *(crop_layer *)net.layers[i];
return get_crop_image(layer);
}
+ else if(net.types[i] == ROUTE){
+ route_layer layer = *(route_layer *)net.layers[i];
+ return get_network_image_layer(net, layer.input_layers[0]);
+ }
return make_empty_image(0,0,0);
}
--
Gitblit v1.10.0