From 37d7c1e79f65a75caf87e29a562d30c51cd654e5 Mon Sep 17 00:00:00 2001
From: Joe Redmon <pjreddie@gmail.com>
Date: Thu, 26 Nov 2015 21:52:56 +0000
Subject: [PATCH] fixed label linking
---
src/network.c | 21 +++++++++++++++++++++
1 files changed, 21 insertions(+), 0 deletions(-)
diff --git a/src/network.c b/src/network.c
index 9bcb264..d9585c4 100644
--- a/src/network.c
+++ b/src/network.c
@@ -8,6 +8,7 @@
#include "crop_layer.h"
#include "connected_layer.h"
+#include "local_layer.h"
#include "convolutional_layer.h"
#include "deconvolutional_layer.h"
#include "detection_layer.h"
@@ -25,6 +26,17 @@
return batch_num;
}
+void reset_momentum(network net)
+{
+ if (net.momentum == 0) return;
+ net.learning_rate = 0;
+ net.momentum = 0;
+ net.decay = 0;
+ #ifdef GPU
+ if(gpu_index >= 0) update_network_gpu(net);
+ #endif
+}
+
float get_current_rate(network net)
{
int batch_num = get_current_batch(net);
@@ -40,6 +52,7 @@
for(i = 0; i < net.num_steps; ++i){
if(net.steps[i] > batch_num) return rate;
rate *= net.scales[i];
+ if(net.steps[i] > batch_num - 1) reset_momentum(net);
}
return rate;
case EXP:
@@ -59,6 +72,8 @@
switch(a){
case CONVOLUTIONAL:
return "convolutional";
+ case LOCAL:
+ return "local";
case DECONVOLUTIONAL:
return "deconvolutional";
case CONNECTED:
@@ -112,6 +127,8 @@
forward_convolutional_layer(l, state);
} else if(l.type == DECONVOLUTIONAL){
forward_deconvolutional_layer(l, state);
+ } else if(l.type == LOCAL){
+ forward_local_layer(l, state);
} else if(l.type == NORMALIZATION){
forward_normalization_layer(l, state);
} else if(l.type == DETECTION){
@@ -150,6 +167,8 @@
update_deconvolutional_layer(l, rate, net.momentum, net.decay);
} else if(l.type == CONNECTED){
update_connected_layer(l, update_batch, rate, net.momentum, net.decay);
+ } else if(l.type == LOCAL){
+ update_local_layer(l, update_batch, rate, net.momentum, net.decay);
}
}
}
@@ -219,6 +238,8 @@
if(i != 0) backward_softmax_layer(l, state);
} else if(l.type == CONNECTED){
backward_connected_layer(l, state);
+ } else if(l.type == LOCAL){
+ backward_local_layer(l, state);
} else if(l.type == COST){
backward_cost_layer(l, state);
} else if(l.type == ROUTE){
--
Gitblit v1.10.0