Joseph Redmon
2015-03-22 664c5dd2f2d1c4ad177d5122df6ce3e2900c6648
Subdivisions for batches
10 files modified
127 ■■■■■ changed files
src/connected_layer.c 26 ●●●● patch | view | raw | blame | history
src/connected_layer.h 4 ●●●● patch | view | raw | blame | history
src/convolutional_kernels.cu 33 ●●●● patch | view | raw | blame | history
src/convolutional_layer.c 14 ●●●●● patch | view | raw | blame | history
src/convolutional_layer.h 4 ●●●● patch | view | raw | blame | history
src/cuda.c 1 ●●●● patch | view | raw | blame | history
src/network.c 9 ●●●●● patch | view | raw | blame | history
src/network.h 1 ●●●● patch | view | raw | blame | history
src/network_kernels.cu 31 ●●●● patch | view | raw | blame | history
src/parser.c 4 ●●●● patch | view | raw | blame | history
src/connected_layer.c
@@ -55,13 +55,13 @@
    return layer;
}
void update_connected_layer(connected_layer layer, float learning_rate, float momentum, float decay)
void update_connected_layer(connected_layer layer, int batch, float learning_rate, float momentum, float decay)
{
    axpy_cpu(layer.outputs, learning_rate, layer.bias_updates, 1, layer.biases, 1);
    axpy_cpu(layer.outputs, learning_rate/batch, layer.bias_updates, 1, layer.biases, 1);
    scal_cpu(layer.outputs, momentum, layer.bias_updates, 1);
    axpy_cpu(layer.inputs*layer.outputs, -decay, layer.weights, 1, layer.weight_updates, 1);
    axpy_cpu(layer.inputs*layer.outputs, learning_rate, layer.weight_updates, 1, layer.weights, 1);
    axpy_cpu(layer.inputs*layer.outputs, -decay*batch, layer.weights, 1, layer.weight_updates, 1);
    axpy_cpu(layer.inputs*layer.outputs, learning_rate/batch, layer.weight_updates, 1, layer.weights, 1);
    scal_cpu(layer.inputs*layer.outputs, momentum, layer.weight_updates, 1);
}
@@ -84,10 +84,9 @@
void backward_connected_layer(connected_layer layer, network_state state)
{
    int i;
    float alpha = 1./layer.batch;
    gradient_array(layer.output, layer.outputs*layer.batch, layer.activation, layer.delta);
    for(i = 0; i < layer.batch; ++i){
        axpy_cpu(layer.outputs, alpha, layer.delta + i*layer.outputs, 1, layer.bias_updates, 1);
        axpy_cpu(layer.outputs, 1, layer.delta + i*layer.outputs, 1, layer.bias_updates, 1);
    }
    int m = layer.inputs;
    int k = layer.batch;
@@ -95,7 +94,7 @@
    float *a = state.input;
    float *b = layer.delta;
    float *c = layer.weight_updates;
    gemm(1,0,m,n,k,alpha,a,m,b,n,1,c,n);
    gemm(1,0,m,n,k,1,a,m,b,n,1,c,n);
    m = layer.batch;
    k = layer.outputs;
@@ -126,13 +125,13 @@
    cuda_push_array(layer.bias_updates_gpu, layer.bias_updates, layer.outputs);
}
void update_connected_layer_gpu(connected_layer layer, float learning_rate, float momentum, float decay)
void update_connected_layer_gpu(connected_layer layer, int batch, float learning_rate, float momentum, float decay)
{
    axpy_ongpu(layer.outputs, learning_rate, layer.bias_updates_gpu, 1, layer.biases_gpu, 1);
    axpy_ongpu(layer.outputs, learning_rate/batch, layer.bias_updates_gpu, 1, layer.biases_gpu, 1);
    scal_ongpu(layer.outputs, momentum, layer.bias_updates_gpu, 1);
    axpy_ongpu(layer.inputs*layer.outputs, -decay, layer.weights_gpu, 1, layer.weight_updates_gpu, 1);
    axpy_ongpu(layer.inputs*layer.outputs, learning_rate, layer.weight_updates_gpu, 1, layer.weights_gpu, 1);
    axpy_ongpu(layer.inputs*layer.outputs, -decay*batch, layer.weights_gpu, 1, layer.weight_updates_gpu, 1);
    axpy_ongpu(layer.inputs*layer.outputs, learning_rate/batch, layer.weight_updates_gpu, 1, layer.weights_gpu, 1);
    scal_ongpu(layer.inputs*layer.outputs, momentum, layer.weight_updates_gpu, 1);
}
@@ -154,11 +153,10 @@
void backward_connected_layer_gpu(connected_layer layer, network_state state)
{
    float alpha = 1./layer.batch;
    int i;
    gradient_array_ongpu(layer.output_gpu, layer.outputs*layer.batch, layer.activation, layer.delta_gpu);
    for(i = 0; i < layer.batch; ++i){
        axpy_ongpu_offset(layer.outputs, alpha, layer.delta_gpu, i*layer.outputs, 1, layer.bias_updates_gpu, 0, 1);
        axpy_ongpu_offset(layer.outputs, 1, layer.delta_gpu, i*layer.outputs, 1, layer.bias_updates_gpu, 0, 1);
    }
    int m = layer.inputs;
    int k = layer.batch;
@@ -166,7 +164,7 @@
    float * a = state.input;
    float * b = layer.delta_gpu;
    float * c = layer.weight_updates_gpu;
    gemm_ongpu(1,0,m,n,k,alpha,a,m,b,n,1,c,n);
    gemm_ongpu(1,0,m,n,k,1,a,m,b,n,1,c,n);
    m = layer.batch;
    k = layer.outputs;
src/connected_layer.h
@@ -38,12 +38,12 @@
void forward_connected_layer(connected_layer layer, network_state state);
void backward_connected_layer(connected_layer layer, network_state state);
void update_connected_layer(connected_layer layer, float learning_rate, float momentum, float decay);
void update_connected_layer(connected_layer layer, int batch, float learning_rate, float momentum, float decay);
#ifdef GPU
void forward_connected_layer_gpu(connected_layer layer, network_state state);
void backward_connected_layer_gpu(connected_layer layer, network_state state);
void update_connected_layer_gpu(connected_layer layer, float learning_rate, float momentum, float decay);
void update_connected_layer_gpu(connected_layer layer, int batch, float learning_rate, float momentum, float decay);
void push_connected_layer(connected_layer layer);
void pull_connected_layer(connected_layer layer);
#endif
src/convolutional_kernels.cu
@@ -48,15 +48,12 @@
extern "C" void backward_bias_gpu(float *bias_updates, float *delta, int batch, int n, int size)
{
    float alpha = 1./batch;
    backward_bias_kernel<<<n, BLOCK>>>(bias_updates, delta, batch, n, size, alpha);
    backward_bias_kernel<<<n, BLOCK>>>(bias_updates, delta, batch, n, size, 1);
    check_error(cudaPeekAtLastError());
}
extern "C" void forward_convolutional_layer_gpu(convolutional_layer layer, network_state state)
{
//clock_t time = clock();
    int i;
    int m = layer.n;
    int k = layer.size*layer.size*layer.c;
@@ -64,36 +61,18 @@
        convolutional_out_width(layer);
    bias_output_gpu(layer.output_gpu, layer.biases_gpu, layer.batch, layer.n, n);
//cudaDeviceSynchronize();
//printf("bias %f\n", sec(clock() - time));
//time = clock();
//float imt=0;
//float gemt = 0;
    for(i = 0; i < layer.batch; ++i){
//time = clock();
        im2col_ongpu(state.input + i*layer.c*layer.h*layer.w, layer.c,  layer.h,  layer.w,  layer.size,  layer.stride, layer.pad, layer.col_image_gpu);
//cudaDeviceSynchronize();
//imt += sec(clock()-time);
//time = clock();
        float * a = layer.filters_gpu;
        float * b = layer.col_image_gpu;
        float * c = layer.output_gpu;
        gemm_ongpu(0,0,m,n,k,1.,a,k,b,n,1.,c+i*m*n,n);
//cudaDeviceSynchronize();
//gemt += sec(clock()-time);
//time = clock();
    }
    activate_array_ongpu(layer.output_gpu, m*n*layer.batch, layer.activation);
//cudaDeviceSynchronize();
//printf("activate %f\n", sec(clock() - time));
//printf("im2col %f\n", imt);
//printf("gemm %f\n", gemt);
}
extern "C" void backward_convolutional_layer_gpu(convolutional_layer layer, network_state state)
{
    float alpha = 1./layer.batch;
    int i;
    int m = layer.n;
    int n = layer.size*layer.size*layer.c;
@@ -111,7 +90,7 @@
        float * c = layer.filter_updates_gpu;
        im2col_ongpu(state.input + i*layer.c*layer.h*layer.w, layer.c,  layer.h,  layer.w,  layer.size,  layer.stride, layer.pad, layer.col_image_gpu);
        gemm_ongpu(0,1,m,n,k,alpha,a + i*m*k,k,b,k,1,c,n);
        gemm_ongpu(0,1,m,n,k,1,a + i*m*k,k,b,k,1,c,n);
        if(state.delta){
@@ -142,15 +121,15 @@
    cuda_push_array(layer.bias_updates_gpu, layer.bias_updates, layer.n);
}
extern "C" void update_convolutional_layer_gpu(convolutional_layer layer, float learning_rate, float momentum, float decay)
extern "C" void update_convolutional_layer_gpu(convolutional_layer layer, int batch, float learning_rate, float momentum, float decay)
{
    int size = layer.size*layer.size*layer.c*layer.n;
    axpy_ongpu(layer.n, learning_rate, layer.bias_updates_gpu, 1, layer.biases_gpu, 1);
    axpy_ongpu(layer.n, learning_rate/batch, layer.bias_updates_gpu, 1, layer.biases_gpu, 1);
    scal_ongpu(layer.n, momentum, layer.bias_updates_gpu, 1);
    axpy_ongpu(size, -decay, layer.filters_gpu, 1, layer.filter_updates_gpu, 1);
    axpy_ongpu(size, learning_rate, layer.filter_updates_gpu, 1, layer.filters_gpu, 1);
    axpy_ongpu(size, -decay*batch, layer.filters_gpu, 1, layer.filter_updates_gpu, 1);
    axpy_ongpu(size, learning_rate/batch, layer.filter_updates_gpu, 1, layer.filters_gpu, 1);
    scal_ongpu(size, momentum, layer.filter_updates_gpu, 1);
}
src/convolutional_layer.c
@@ -129,11 +129,10 @@
void backward_bias(float *bias_updates, float *delta, int batch, int n, int size)
{
    float alpha = 1./batch;
    int i,b;
    for(b = 0; b < batch; ++b){
        for(i = 0; i < n; ++i){
            bias_updates[i] += alpha * sum_array(delta+size*(i+b*n), size);
            bias_updates[i] += sum_array(delta+size*(i+b*n), size);
        }
    }
}
@@ -167,7 +166,6 @@
void backward_convolutional_layer(convolutional_layer layer, network_state state)
{
    float alpha = 1./layer.batch;
    int i;
    int m = layer.n;
    int n = layer.size*layer.size*layer.c;
@@ -188,7 +186,7 @@
        im2col_cpu(im, layer.c, layer.h, layer.w, 
                layer.size, layer.stride, layer.pad, b);
        gemm(0,1,m,n,k,alpha,a,k,b,k,1,c,n);
        gemm(0,1,m,n,k,1,a,k,b,k,1,c,n);
        if(state.delta){
            a = layer.filters;
@@ -202,14 +200,14 @@
    }
}
void update_convolutional_layer(convolutional_layer layer, float learning_rate, float momentum, float decay)
void update_convolutional_layer(convolutional_layer layer, int batch, float learning_rate, float momentum, float decay)
{
    int size = layer.size*layer.size*layer.c*layer.n;
    axpy_cpu(layer.n, learning_rate, layer.bias_updates, 1, layer.biases, 1);
    axpy_cpu(layer.n, learning_rate/batch, layer.bias_updates, 1, layer.biases, 1);
    scal_cpu(layer.n, momentum, layer.bias_updates, 1);
    axpy_cpu(size, -decay, layer.filters, 1, layer.filter_updates, 1);
    axpy_cpu(size, learning_rate, layer.filter_updates, 1, layer.filters, 1);
    axpy_cpu(size, -decay*batch, layer.filters, 1, layer.filter_updates, 1);
    axpy_cpu(size, learning_rate/batch, layer.filter_updates, 1, layer.filters, 1);
    scal_cpu(size, momentum, layer.filter_updates, 1);
}
src/convolutional_layer.h
@@ -41,7 +41,7 @@
#ifdef GPU
void forward_convolutional_layer_gpu(convolutional_layer layer, network_state state);
void backward_convolutional_layer_gpu(convolutional_layer layer, network_state state);
void update_convolutional_layer_gpu(convolutional_layer layer, float learning_rate, float momentum, float decay);
void update_convolutional_layer_gpu(convolutional_layer layer, int batch, float learning_rate, float momentum, float decay);
void push_convolutional_layer(convolutional_layer layer);
void pull_convolutional_layer(convolutional_layer layer);
@@ -53,7 +53,7 @@
convolutional_layer *make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, int pad, ACTIVATION activation);
void resize_convolutional_layer(convolutional_layer *layer, int h, int w);
void forward_convolutional_layer(const convolutional_layer layer, network_state state);
void update_convolutional_layer(convolutional_layer layer, float learning_rate, float momentum, float decay);
void update_convolutional_layer(convolutional_layer layer, int batch, float learning_rate, float momentum, float decay);
image *visualize_convolutional_layer(convolutional_layer layer, char *window, image *prev_filters);
void backward_convolutional_layer(convolutional_layer layer, network_state state);
src/cuda.c
@@ -66,6 +66,7 @@
    if(!init){
        curandCreateGenerator(&gen, CURAND_RNG_PSEUDO_DEFAULT);
        curandSetPseudoRandomGeneratorSeed(gen, 0ULL);
        init = 1;
    }
    curandGenerateUniform(gen, x_gpu, n);
    check_error(cudaPeekAtLastError());
src/network.c
@@ -106,10 +106,11 @@
void update_network(network net)
{
    int i;
    int update_batch = net.batch*net.subdivisions;
    for(i = 0; i < net.n; ++i){
        if(net.types[i] == CONVOLUTIONAL){
            convolutional_layer layer = *(convolutional_layer *)net.layers[i];
            update_convolutional_layer(layer, net.learning_rate, net.momentum, net.decay);
            update_convolutional_layer(layer, update_batch, net.learning_rate, net.momentum, net.decay);
        }
        else if(net.types[i] == DECONVOLUTIONAL){
            deconvolutional_layer layer = *(deconvolutional_layer *)net.layers[i];
@@ -117,7 +118,7 @@
        }
        else if(net.types[i] == CONNECTED){
            connected_layer layer = *(connected_layer *)net.layers[i];
            update_connected_layer(layer, net.learning_rate, net.momentum, net.decay);
            update_connected_layer(layer, update_batch, net.learning_rate, net.momentum, net.decay);
        }
    }
}
@@ -281,7 +282,7 @@
    forward_network(net, state);
    backward_network(net, state);
    float error = get_network_cost(net);
    update_network(net);
    if((net.seen/net.batch)%net.subdivisions == 0) update_network(net);
    return error;
}
@@ -294,6 +295,7 @@
    int i;
    float sum = 0;
    for(i = 0; i < n; ++i){
        net.seen += batch;
        get_random_batch(d, batch, X, y);
        float err = train_network_datum(net, X, y);
        sum += err;
@@ -314,6 +316,7 @@
    float sum = 0;
    for(i = 0; i < n; ++i){
        get_next_batch(d, batch, i*batch, X, y);
        net.seen += batch;
        float err = train_network_datum(net, X, y);
        sum += err;
    }
src/network.h
@@ -23,6 +23,7 @@
    int n;
    int batch;
    int seen;
    int subdivisions;
    float learning_rate;
    float momentum;
    float decay;
src/network_kernels.cu
@@ -28,7 +28,6 @@
{
    int i;
    for(i = 0; i < net.n; ++i){
//clock_t time = clock();
        if(net.types[i] == CONVOLUTIONAL){
            forward_convolutional_layer_gpu(*(convolutional_layer *)net.layers[i], state);
        }
@@ -57,9 +56,6 @@
            forward_crop_layer_gpu(*(crop_layer *)net.layers[i], state);
        }
        state.input = get_network_output_gpu_layer(net, i);
//cudaDeviceSynchronize();
//printf("forw %d: %s %f\n", i, get_layer_string(net.types[i]), sec(clock() - time));
//time = clock();
    }
}
@@ -68,7 +64,6 @@
    int i;
    float * original_input = state.input;
    for(i = net.n-1; i >= 0; --i){
//clock_t time = clock();
        if(i == 0){
            state.input = original_input;
            state.delta = 0;
@@ -100,19 +95,17 @@
        else if(net.types[i] == SOFTMAX){
            backward_softmax_layer_gpu(*(softmax_layer *)net.layers[i], state);
        }
//cudaDeviceSynchronize();
//printf("back %d: %s %f\n", i, get_layer_string(net.types[i]), sec(clock() - time));
//time = clock();
    }
}
void update_network_gpu(network net)
{
    int i;
    int update_batch = net.batch*net.subdivisions;
    for(i = 0; i < net.n; ++i){
        if(net.types[i] == CONVOLUTIONAL){
            convolutional_layer layer = *(convolutional_layer *)net.layers[i];
            update_convolutional_layer_gpu(layer, net.learning_rate, net.momentum, net.decay);
            update_convolutional_layer_gpu(layer, update_batch, net.learning_rate, net.momentum, net.decay);
        }
        else if(net.types[i] == DECONVOLUTIONAL){
            deconvolutional_layer layer = *(deconvolutional_layer *)net.layers[i];
@@ -120,7 +113,7 @@
        }
        else if(net.types[i] == CONNECTED){
            connected_layer layer = *(connected_layer *)net.layers[i];
            update_connected_layer_gpu(layer, net.learning_rate, net.momentum, net.decay);
            update_connected_layer_gpu(layer, update_batch, net.learning_rate, net.momentum, net.decay);
        }
    }
}
@@ -188,7 +181,6 @@
float train_network_datum_gpu(network net, float *x, float *y)
{
 // clock_t time = clock();
    network_state state;
    int x_size = get_network_input_size(net)*net.batch;
    int y_size = get_network_output_size(net)*net.batch;
@@ -202,26 +194,11 @@
    state.input = *net.input_gpu;
    state.truth = *net.truth_gpu;
    state.train = 1;
//cudaDeviceSynchronize();
//printf("trans %f\n", sec(clock() - time));
//time = clock();
    forward_network_gpu(net, state);
//cudaDeviceSynchronize();
//printf("forw %f\n", sec(clock() - time));
//time = clock();
    backward_network_gpu(net, state);
//cudaDeviceSynchronize();
//printf("back %f\n", sec(clock() - time));
//time = clock();
    update_network_gpu(net);
    if ((net.seen / net.batch) % net.subdivisions == 0) update_network_gpu(net);
    float error = get_network_cost(net);
    //print_letters(y, 50);
    //float *out = get_network_output_gpu(net);
    //print_letters(out, 50);
//cudaDeviceSynchronize();
//printf("updt %f\n", sec(clock() - time));
//time = clock();
    return error;
}
src/parser.c
@@ -249,12 +249,16 @@
    net->momentum = option_find_float(options, "momentum", .9);
    net->decay = option_find_float(options, "decay", .0001);
    net->seen = option_find_int(options, "seen",0);
    int subdivs = option_find_int(options, "subdivisions",1);
    net->batch /= subdivs;
    net->subdivisions = subdivs;
    net->h = option_find_int_quiet(options, "height",0);
    net->w = option_find_int_quiet(options, "width",0);
    net->c = option_find_int_quiet(options, "channels",0);
    net->inputs = option_find_int_quiet(options, "inputs", net->h * net->w * net->c);
    if(!net->inputs && !(net->h && net->w && net->c)) error("No input parameters supplied");
    option_unused(options);
}
network parse_network_cfg(char *filename)