AlexeyAB
2018-02-15 a1af57d8d60b50e8188f36b7f74752c8cc124177
src/network.c
@@ -50,6 +50,7 @@
    int batch_num = get_current_batch(net);
    int i;
    float rate;
   if (batch_num < net.burn_in) return net.learning_rate * pow((float)batch_num / net.burn_in, net.power);
    switch (net.policy) {
        case CONSTANT:
            return net.learning_rate;
@@ -66,8 +67,9 @@
        case EXP:
            return net.learning_rate * pow(net.gamma, batch_num);
        case POLY:
            if (batch_num < net.burn_in) return net.learning_rate * pow((float)batch_num / net.burn_in, net.power);
            return net.learning_rate * pow(1 - (float)batch_num / net.max_batches, net.power);
         return net.learning_rate * pow(1 - (float)batch_num / net.max_batches, net.power);
            //if (batch_num < net.burn_in) return net.learning_rate * pow((float)batch_num / net.burn_in, net.power);
            //return net.learning_rate * pow(1 - (float)batch_num / net.max_batches, net.power);
        case RANDOM:
            return net.learning_rate * pow(rand_uniform(0,1), net.power);
        case SIG: