Joseph Redmon
2015-03-25 e92f7d301c971b4d27aa3dcd1e4047e94f04b3fc
src/network.c
@@ -106,10 +106,11 @@
void update_network(network net)
{
    int i;
    int update_batch = net.batch*net.subdivisions;
    for(i = 0; i < net.n; ++i){
        if(net.types[i] == CONVOLUTIONAL){
            convolutional_layer layer = *(convolutional_layer *)net.layers[i];
            update_convolutional_layer(layer, net.learning_rate, net.momentum, net.decay);
            update_convolutional_layer(layer, update_batch, net.learning_rate, net.momentum, net.decay);
        }
        else if(net.types[i] == DECONVOLUTIONAL){
            deconvolutional_layer layer = *(deconvolutional_layer *)net.layers[i];
@@ -117,7 +118,7 @@
        }
        else if(net.types[i] == CONNECTED){
            connected_layer layer = *(connected_layer *)net.layers[i];
            update_connected_layer(layer, net.learning_rate, net.momentum, net.decay);
            update_connected_layer(layer, update_batch, net.learning_rate, net.momentum, net.decay);
        }
    }
}
@@ -193,24 +194,6 @@
    return get_network_delta_layer(net, net.n-1);
}
float calculate_error_network(network net, float *truth)
{
    float sum = 0;
    float *delta = get_network_delta(net);
    float *out = get_network_output(net);
    int i;
    for(i = 0; i < get_network_output_size(net)*net.batch; ++i){
        //if(i %get_network_output_size(net) == 0) printf("\n");
        //printf("%5.2f %5.2f, ", out[i], truth[i]);
        //if(i == get_network_output_size(net)) printf("\n");
        delta[i] = truth[i] - out[i];
        //printf("%.10f, ", out[i]);
        sum += delta[i]*delta[i];
    }
    //printf("\n");
    return sum;
}
int get_predicted_class_network(network net)
{
    float *out = get_network_output(net);
@@ -281,7 +264,7 @@
    forward_network(net, state);
    backward_network(net, state);
    float error = get_network_cost(net);
    update_network(net);
    if((net.seen/net.batch)%net.subdivisions == 0) update_network(net);
    return error;
}
@@ -294,6 +277,7 @@
    int i;
    float sum = 0;
    for(i = 0; i < n; ++i){
        net.seen += batch;
        get_random_batch(d, batch, X, y);
        float err = train_network_datum(net, X, y);
        sum += err;
@@ -314,6 +298,7 @@
    float sum = 0;
    for(i = 0; i < n; ++i){
        get_next_batch(d, batch, i*batch, X, y);
        net.seen += batch;
        float err = train_network_datum(net, X, y);
        sum += err;
    }