From db0397cfaaf488364e3d2e1669dfefae2ee6ea73 Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Mon, 14 Dec 2015 19:57:10 +0000
Subject: [PATCH] shortcut layers, msr networks
---
src/network.c | 27 +++++++++++++++++++++++++++
1 files changed, 27 insertions(+), 0 deletions(-)
diff --git a/src/network.c b/src/network.c
index 6c7461d..8dee8cc 100644
--- a/src/network.c
+++ b/src/network.c
@@ -19,6 +19,7 @@
#include "softmax_layer.h"
#include "dropout_layer.h"
#include "route_layer.h"
+#include "shortcut_layer.h"
int get_current_batch(network net)
{
@@ -26,6 +27,17 @@
return batch_num;
}
+void reset_momentum(network net)  /* Runs one GPU update pass with learning rate, momentum and decay all zeroed. `net` is received by value, so the assignments below change only this function's copy. */
+{
+ if (net.momentum == 0) return;  /* momentum already zero: skip the pass entirely */
+ net.learning_rate = 0;  /* zeroed on the local copy only */
+ net.momentum = 0;  /* presumably lets the update pass clear accumulated momentum terms -- TODO confirm against update_network_gpu */
+ net.decay = 0;
+ #ifdef GPU
+ if(gpu_index >= 0) update_network_gpu(net);  /* the zeroed copy is what gets passed to the update */
+ #endif
+}  /* NOTE(review): when GPU is not defined, or gpu_index < 0, nothing here has an effect outside this function -- confirm that is intended. */
+
float get_current_rate(network net)
{
int batch_num = get_current_batch(net);
@@ -41,6 +53,7 @@
for(i = 0; i < net.num_steps; ++i){
if(net.steps[i] > batch_num) return rate;
rate *= net.scales[i];
+ if(net.steps[i] > batch_num - 1) reset_momentum(net);
}
return rate;
case EXP:
@@ -82,6 +95,8 @@
return "cost";
case ROUTE:
return "route";
+ case SHORTCUT:
+ return "shortcut";
case NORMALIZATION:
return "normalization";
default:
@@ -107,6 +122,7 @@
{
int i;
for(i = 0; i < net.n; ++i){
+ state.index = i;
layer l = net.layers[i];
if(l.delta){
scal_cpu(l.outputs * l.batch, 0, l.delta, 1);
@@ -137,6 +153,8 @@
forward_dropout_layer(l, state);
} else if(l.type == ROUTE){
forward_route_layer(l, net);
+ } else if(l.type == SHORTCUT){
+ forward_shortcut_layer(l, state);
}
state.input = l.output;
}
@@ -199,6 +217,7 @@
float *original_input = state.input;
float *original_delta = state.delta;
for(i = net.n-1; i >= 0; --i){
+ state.index = i;
if(i == 0){
state.input = original_input;
state.delta = original_delta;
@@ -232,6 +251,8 @@
backward_cost_layer(l, state);
} else if(l.type == ROUTE){
backward_route_layer(l, net);
+ } else if(l.type == SHORTCUT){
+ backward_shortcut_layer(l, state);
}
}
}
@@ -243,6 +264,8 @@
if(gpu_index >= 0) return train_network_datum_gpu(net, x, y);
#endif
network_state state;
+ state.index = 0;
+ state.net = net;
state.input = x;
state.delta = 0;
state.truth = y;
@@ -295,6 +318,8 @@
{
int i,j;
network_state state;
+ state.index = 0;
+ state.net = net;
state.train = 1;
state.delta = 0;
float sum = 0;
@@ -431,6 +456,8 @@
#endif
network_state state;
+ state.net = net;
+ state.index = 0;
state.input = input;
state.truth = 0;
state.train = 0;
--
Gitblit v1.10.0