From b5936b499abc94c0efffbcc99b5698574b59d860 Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Sat, 05 Sep 2015 00:52:44 +0000
Subject: [PATCH] lots of stuff
---
src/network_kernels.cu | 9 +++++----
1 files changed, 5 insertions(+), 4 deletions(-)
diff --git a/src/network_kernels.cu b/src/network_kernels.cu
index a73ddd9..1f0a654 100644
--- a/src/network_kernels.cu
+++ b/src/network_kernels.cu
@@ -116,14 +116,15 @@
{
int i;
int update_batch = net.batch*net.subdivisions;
+ float rate = get_current_rate(net);
for(i = 0; i < net.n; ++i){
layer l = net.layers[i];
if(l.type == CONVOLUTIONAL){
- update_convolutional_layer_gpu(l, update_batch, net.learning_rate, net.momentum, net.decay);
+ update_convolutional_layer_gpu(l, update_batch, rate, net.momentum, net.decay);
} else if(l.type == DECONVOLUTIONAL){
- update_deconvolutional_layer_gpu(l, net.learning_rate, net.momentum, net.decay);
+ update_deconvolutional_layer_gpu(l, rate, net.momentum, net.decay);
} else if(l.type == CONNECTED){
- update_connected_layer_gpu(l, update_batch, net.learning_rate, net.momentum, net.decay);
+ update_connected_layer_gpu(l, update_batch, rate, net.momentum, net.decay);
}
}
}
@@ -147,7 +148,7 @@
forward_network_gpu(net, state);
backward_network_gpu(net, state);
float error = get_network_cost(net);
- if ((net.seen / net.batch) % net.subdivisions == 0) update_network_gpu(net);
+ if (((*net.seen) / net.batch) % net.subdivisions == 0) update_network_gpu(net);
return error;
}
--
Gitblit v1.10.0