From f98bf6bbdb5ed81f2ea2071ad8e705130f7ba596 Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Sat, 28 Mar 2015 23:11:37 +0000
Subject: [PATCH] We do our OWN resizing!
---
src/network.c | 27 ++++++---------------------
1 files changed, 6 insertions(+), 21 deletions(-)
diff --git a/src/network.c b/src/network.c
index 89c5621..75c9454 100644
--- a/src/network.c
+++ b/src/network.c
@@ -106,10 +106,11 @@
void update_network(network net)
{
int i;
+ int update_batch = net.batch*net.subdivisions;
for(i = 0; i < net.n; ++i){
if(net.types[i] == CONVOLUTIONAL){
convolutional_layer layer = *(convolutional_layer *)net.layers[i];
- update_convolutional_layer(layer, net.learning_rate, net.momentum, net.decay);
+ update_convolutional_layer(layer, update_batch, net.learning_rate, net.momentum, net.decay);
}
else if(net.types[i] == DECONVOLUTIONAL){
deconvolutional_layer layer = *(deconvolutional_layer *)net.layers[i];
@@ -117,7 +118,7 @@
}
else if(net.types[i] == CONNECTED){
connected_layer layer = *(connected_layer *)net.layers[i];
- update_connected_layer(layer, net.learning_rate, net.momentum, net.decay);
+ update_connected_layer(layer, update_batch, net.learning_rate, net.momentum, net.decay);
}
}
}
@@ -193,24 +194,6 @@
return get_network_delta_layer(net, net.n-1);
}
-float calculate_error_network(network net, float *truth)
-{
- float sum = 0;
- float *delta = get_network_delta(net);
- float *out = get_network_output(net);
- int i;
- for(i = 0; i < get_network_output_size(net)*net.batch; ++i){
- //if(i %get_network_output_size(net) == 0) printf("\n");
- //printf("%5.2f %5.2f, ", out[i], truth[i]);
- //if(i == get_network_output_size(net)) printf("\n");
- delta[i] = truth[i] - out[i];
- //printf("%.10f, ", out[i]);
- sum += delta[i]*delta[i];
- }
- //printf("\n");
- return sum;
-}
-
int get_predicted_class_network(network net)
{
float *out = get_network_output(net);
@@ -281,7 +264,7 @@
forward_network(net, state);
backward_network(net, state);
float error = get_network_cost(net);
- update_network(net);
+ if((net.seen/net.batch)%net.subdivisions == 0) update_network(net);
return error;
}
@@ -294,6 +277,7 @@
int i;
float sum = 0;
for(i = 0; i < n; ++i){
+ net.seen += batch;
get_random_batch(d, batch, X, y);
float err = train_network_datum(net, X, y);
sum += err;
@@ -314,6 +298,7 @@
float sum = 0;
for(i = 0; i < n; ++i){
get_next_batch(d, batch, i*batch, X, y);
+ net.seen += batch;
float err = train_network_datum(net, X, y);
sum += err;
}
--
Gitblit v1.10.0