Joseph Redmon
2015-02-04 bfffadc75502cadb5d05909435a2167db5204325
Stable place to commit
9 files modified
94 ■■■■ changed files
src/connected_layer.c 18 ●●●● patch | view | raw | blame | history
src/convolutional_kernels.cu 15 ●●●● patch | view | raw | blame | history
src/convolutional_layer.c 9 ●●●●● patch | view | raw | blame | history
src/darknet.c 29 ●●●● patch | view | raw | blame | history
src/gemm.c 10 ●●●●● patch | view | raw | blame | history
src/network.c 1 ●●●● patch | view | raw | blame | history
src/network_kernels.cu 1 ●●●● patch | view | raw | blame | history
src/utils.c 10 ●●●●● patch | view | raw | blame | history
src/utils.h 1 ●●●● patch | view | raw | blame | history
src/connected_layer.c
@@ -43,6 +43,7 @@
    for(i = 0; i < outputs; ++i){
        layer->biases[i] = scale;
       // layer->biases[i] = 1;
    }
#ifdef GPU
@@ -113,9 +114,10 @@
void backward_connected_layer(connected_layer layer, float *input, float *delta)
{
    int i;
    float alpha = 1./layer.batch;
    gradient_array(layer.output, layer.outputs*layer.batch, layer.activation, layer.delta);
    for(i = 0; i < layer.batch; ++i){
        axpy_cpu(layer.outputs, 1, layer.delta + i*layer.outputs, 1, layer.bias_updates, 1);
        axpy_cpu(layer.outputs, alpha, layer.delta + i*layer.outputs, 1, layer.bias_updates, 1);
    }
    int m = layer.inputs;
    int k = layer.batch;
@@ -123,7 +125,7 @@
    float *a = input;
    float *b = layer.delta;
    float *c = layer.weight_updates;
    gemm(1,0,m,n,k,1,a,m,b,n,1,c,n);
    gemm(1,0,m,n,k,alpha,a,m,b,n,1,c,n);
    m = layer.batch;
    k = layer.outputs;
@@ -156,13 +158,18 @@
void update_connected_layer_gpu(connected_layer layer)
{
/*
    cuda_pull_array(layer.weights_gpu, layer.weights, layer.inputs*layer.outputs);
    cuda_pull_array(layer.weight_updates_gpu, layer.weight_updates, layer.inputs*layer.outputs);
    printf("Weights: %f updates: %f\n", mag_array(layer.weights, layer.inputs*layer.outputs), layer.learning_rate*mag_array(layer.weight_updates, layer.inputs*layer.outputs));
*/
    axpy_ongpu(layer.outputs, layer.learning_rate, layer.bias_updates_gpu, 1, layer.biases_gpu, 1);
    scal_ongpu(layer.outputs, layer.momentum, layer.bias_updates_gpu, 1);
    axpy_ongpu(layer.inputs*layer.outputs, -layer.decay, layer.weights_gpu, 1, layer.weight_updates_gpu, 1);
    axpy_ongpu(layer.inputs*layer.outputs, layer.learning_rate, layer.weight_updates_gpu, 1, layer.weights_gpu, 1);
    scal_ongpu(layer.inputs*layer.outputs, layer.momentum, layer.weight_updates_gpu, 1);
    //pull_connected_layer(layer);
}
void forward_connected_layer_gpu(connected_layer layer, float * input)
@@ -183,10 +190,11 @@
void backward_connected_layer_gpu(connected_layer layer, float * input, float * delta)
{
    float alpha = 1./layer.batch;
    int i;
    gradient_array_ongpu(layer.output_gpu, layer.outputs*layer.batch, layer.activation, layer.delta_gpu);
    for(i = 0; i < layer.batch; ++i){
        axpy_ongpu_offset(layer.outputs, 1, layer.delta_gpu, i*layer.outputs, 1, layer.bias_updates_gpu, 0, 1);
        axpy_ongpu_offset(layer.outputs, alpha, layer.delta_gpu, i*layer.outputs, 1, layer.bias_updates_gpu, 0, 1);
    }
    int m = layer.inputs;
    int k = layer.batch;
@@ -194,7 +202,7 @@
    float * a = input;
    float * b = layer.delta_gpu;
    float * c = layer.weight_updates_gpu;
    gemm_ongpu(1,0,m,n,k,1,a,m,b,n,1,c,n);
    gemm_ongpu(1,0,m,n,k,alpha,a,m,b,n,1,c,n);
    m = layer.batch;
    k = layer.outputs;
src/convolutional_kernels.cu
@@ -28,7 +28,7 @@
    check_error(cudaPeekAtLastError());
}
__global__ void learn_bias(int batch, int n, int size, float *delta, float *bias_updates)
__global__ void learn_bias(int batch, int n, int size, float *delta, float *bias_updates, float scale)
{
    __shared__ float part[BLOCK];
    int i,b;
@@ -44,15 +44,16 @@
    part[p] = sum;
    __syncthreads();
    if(p == 0){
        for(i = 0; i < BLOCK; ++i) bias_updates[filter] += part[i];
        for(i = 0; i < BLOCK; ++i) bias_updates[filter] += scale * part[i];
    }
}
extern "C" void learn_bias_convolutional_layer_ongpu(convolutional_layer layer)
{
    int size = convolutional_out_height(layer)*convolutional_out_width(layer);
    float alpha = 1./layer.batch;
    learn_bias<<<layer.n, BLOCK>>>(layer.batch, layer.n, size, layer.delta_gpu, layer.bias_updates_gpu);
    learn_bias<<<layer.n, BLOCK>>>(layer.batch, layer.n, size, layer.delta_gpu, layer.bias_updates_gpu, alpha);
    check_error(cudaPeekAtLastError());
}
@@ -99,6 +100,7 @@
extern "C" void backward_convolutional_layer_gpu(convolutional_layer layer, float *in, float *delta_gpu)
{
    float alpha = 1./layer.batch;
    int i;
    int m = layer.n;
    int n = layer.size*layer.size*layer.c;
@@ -115,7 +117,7 @@
        float * c = layer.filter_updates_gpu;
        im2col_ongpu(in, i*layer.c*layer.h*layer.w, layer.c,  layer.h,  layer.w,  layer.size,  layer.stride, layer.pad, layer.col_image_gpu);
        gemm_ongpu(0,1,m,n,k,1,a + i*m*k,k,b,k,1,c,n);
        gemm_ongpu(0,1,m,n,k,alpha,a + i*m*k,k,b,k,1,c,n);
        if(delta_gpu){
@@ -151,12 +153,9 @@
    int size = layer.size*layer.size*layer.c*layer.n;
/*
    cuda_pull_array(layer.bias_updates_gpu, layer.bias_updates, layer.n);
    cuda_pull_array(layer.biases_gpu, layer.biases, layer.n);
    cuda_pull_array(layer.filter_updates_gpu, layer.filter_updates, size);
    cuda_pull_array(layer.filters_gpu, layer.filters, size);
    printf("Bias: %f updates: %f\n", mse_array(layer.biases, layer.n), mse_array(layer.bias_updates, layer.n));
    printf("Filter: %f updates: %f\n", mse_array(layer.filters, layer.n), mse_array(layer.filter_updates, layer.n));
    printf("Filter: %f updates: %f\n", mag_array(layer.filters, size), layer.learning_rate*mag_array(layer.filter_updates, size));
    */
    axpy_ongpu(layer.n, layer.learning_rate, layer.bias_updates_gpu, 1, layer.biases_gpu, 1);
src/convolutional_layer.c
@@ -66,11 +66,12 @@
    layer->biases = calloc(n, sizeof(float));
    layer->bias_updates = calloc(n, sizeof(float));
    float scale = 1./sqrt(size*size*c);
    //scale = .05;
    //scale = .01;
    for(i = 0; i < c*n*size*size; ++i) layer->filters[i] = scale*rand_normal();
    for(i = 0; i < n; ++i){
        //layer->biases[i] = rand_normal()*scale + scale;
        layer->biases[i] = scale;
        //layer->biases[i] = 1;
    }
    int out_h = convolutional_out_height(*layer);
    int out_w = convolutional_out_width(*layer);
@@ -155,18 +156,20 @@
void learn_bias_convolutional_layer(convolutional_layer layer)
{
    float alpha = 1./layer.batch;
    int i,b;
    int size = convolutional_out_height(layer)
        *convolutional_out_width(layer);
    for(b = 0; b < layer.batch; ++b){
        for(i = 0; i < layer.n; ++i){
            layer.bias_updates[i] += sum_array(layer.delta+size*(i+b*layer.n), size);
            layer.bias_updates[i] += alpha * sum_array(layer.delta+size*(i+b*layer.n), size);
        }
    }
}
void backward_convolutional_layer(convolutional_layer layer, float *in, float *delta)
{
    float alpha = 1./layer.batch;
    int i;
    int m = layer.n;
    int n = layer.size*layer.size*layer.c;
@@ -188,7 +191,7 @@
        im2col_cpu(im, layer.c, layer.h, layer.w, 
                layer.size, layer.stride, layer.pad, b);
        gemm(0,1,m,n,k,1,a,k,b,k,1,c,n);
        gemm(0,1,m,n,k,alpha,a,k,b,k,1,c,n);
        if(delta){
            a = layer.filters;
src/darknet.c
@@ -206,10 +206,28 @@
}
*/
char *basename(char *cfgfile)
{
    char *c = cfgfile;
    char *next;
    while((next = strchr(c, '/')))
    {
        c = next+1;
    }
    c = copy_string(c);
    next = strchr(c, '_');
    if (next) *next = 0;
    next = strchr(c, '.');
    if (next) *next = 0;
    return c;
}
void train_imagenet(char *cfgfile)
{
    float avg_loss = 1;
    float avg_loss = -1;
    srand(time(0));
    char *base = basename(cfgfile);
    printf("%s\n", base);
    network net = parse_network_cfg(cfgfile);
    //test_learn_bias(*(convolutional_layer *)net.layers[1]);
    //set_learning_network(&net, net.learning_rate, 0, net.decay);
@@ -235,12 +253,13 @@
        time=clock();
        float loss = train_network(net, train);
        net.seen += imgs;
        if(avg_loss == -1) avg_loss = loss;
        avg_loss = avg_loss*.9 + loss*.1;
        printf("%d: %f, %f avg, %lf seconds, %d images\n", i, loss, avg_loss, sec(clock()-time), net.seen);
        free_data(train);
        if(i%100==0){
            char buff[256];
            sprintf(buff, "/home/pjreddie/imagenet_backup/vgg_%d.cfg", i);
            sprintf(buff, "/home/pjreddie/imagenet_backup/%s_%d.cfg",base, i);
            save_network(net, buff);
        }
    }
@@ -272,7 +291,6 @@
        pthread_join(load_thread, 0);
        val = buffer;
        //normalize_data_rows(val);
        num = (i+1)*m/splits - i*m/splits;
        char **part = paths+(i*m/splits);
@@ -312,6 +330,7 @@
void test_init(char *cfgfile)
{
    gpu_index = -1;
    network net = parse_network_cfg(cfgfile);
    set_batch_network(&net, 1);
    srand(2222222);
@@ -345,7 +364,7 @@
}
void test_dog(char *cfgfile)
{
    image im = load_image_color("data/dog.jpg", 224, 224);
    image im = load_image_color("data/dog.jpg", 256, 256);
    translate_image(im, -128);
    print_image(im);
    float *X = im.data;
@@ -377,7 +396,7 @@
        strtok(filename, "\n");
        image im = load_image_color(filename, 256, 256);
        translate_image(im, -128);
        //scale_image(im, 1/128.);
        scale_image(im, 1/128.);
        printf("%d %d %d\n", im.h, im.w, im.c);
        float *X = im.data;
        time=clock();
src/gemm.c
@@ -276,6 +276,7 @@
int test_gpu_blas()
{
/*
       test_gpu_accuracy(0,0,10,576,75); 
       test_gpu_accuracy(0,0,17,10,10); 
@@ -299,6 +300,15 @@
    time_ongpu(0,0,256,196,2304); 
    time_ongpu(0,0,128,4096,12544); 
    time_ongpu(0,0,128,4096,4096); 
    */
    time_ongpu(0,0,64,75,12544);
    time_ongpu(0,0,64,75,12544);
    time_ongpu(0,0,64,75,12544);
    time_ongpu(0,0,64,576,12544);
    time_ongpu(0,0,256,2304,784);
    time_ongpu(1,1,2304,256,784);
    time_ongpu(0,0,512,4608,196);
    time_ongpu(1,1,4608,512,196);
return 0;
}
src/network.c
@@ -133,7 +133,6 @@
        }
        else if(net.types[i] == CONNECTED){
            connected_layer layer = *(connected_layer *)net.layers[i];
            //secret_update_connected_layer((connected_layer *)net.layers[i]);
            update_connected_layer(layer);
        }
    }
src/network_kernels.cu
@@ -61,6 +61,7 @@
            forward_crop_layer_gpu(layer, train, input);
            input = layer.output_gpu;
        }
        //cudaDeviceSynchronize();
        //printf("Forward %d %s %f\n", i, get_layer_string(net.types[i]), sec(clock() - time));
    }
}
src/utils.c
@@ -262,6 +262,16 @@
    }
}
float mag_array(float *a, int n)
{
    int i;
    float sum = 0;
    for(i = 0; i < n; ++i){
        sum += a[i]*a[i];
    }
    return sqrt(sum);
}
void scale_array(float *a, int n, float s)
{
    int i;
src/utils.h
@@ -28,6 +28,7 @@
float sum_array(float *a, int n);
float mean_array(float *a, int n);
float variance_array(float *a, int n);
float mag_array(float *a, int n);
float **one_hot_encode(float *a, int n, int k);
float sec(clock_t clocks);
#endif