src/network.c
@@ -26,6 +26,17 @@
    return batch_num;
}

/*
 * reset_momentum: when momentum is in use, issue one GPU network update
 * using a copy of `net` whose learning_rate, momentum and decay have all
 * been forced to 0.
 *
 * `net` is received by value (its fields are accessed with `.`), so the
 * three assignments below change only this local copy; the caller's
 * hyper-parameters are left untouched. The only externally visible effect
 * is the update_network_gpu(net) call, and that call exists only in GPU
 * builds and only when gpu_index >= 0.
 *
 * NOTE(review): in a build without GPU, or when gpu_index < 0, this
 * function has no observable effect at all -- momentum is NOT reset on
 * the CPU path. Confirm whether that is intended.
 * NOTE(review): presumably an update step with rate, momentum and decay all
 * equal to 0 clears the accumulated update (momentum) buffers without
 * moving the weights -- confirm against update_network_gpu.
 * NOTE(review): if update_network_gpu calls back into get_current_rate,
 * the momentum == 0 early return below is what stops the nested call from
 * resetting again -- confirm.
 */
void reset_momentum(network net)
{
    /* Nothing to reset when momentum is disabled. */
    if (net.momentum == 0) return;
    /* Zero the hyper-parameters on the local copy that is handed to the
       GPU update below. */
    net.learning_rate = 0;
    net.momentum = 0;
    net.decay = 0;
#ifdef GPU
    if(gpu_index >= 0) update_network_gpu(net);
#endif
}

float get_current_rate(network net)
{
    int batch_num = get_current_batch(net);
@@ -41,6 +52,7 @@
        for(i = 0; i < net.num_steps; ++i){
            if(net.steps[i] > batch_num) return rate;
            rate *= net.scales[i];
            /* Reached only when steps[i] <= batch_num (see the return above).
               Assuming steps[] holds integers, this test is therefore the same
               as steps[i] == batch_num: momentum is reset on the batch where
               step i takes effect.
               NOTE(review): this gives get_current_rate a side effect (a GPU
               update call); if it is called more than once for the same batch,
               the reset runs each time -- confirm callers tolerate that. */
            if(net.steps[i] > batch_num - 1) reset_momentum(net);
        }
        return rate;
    case EXP: