Joseph Redmon
2015-01-13 aa5996d58e68edfbefe51061856aecd549dd09c4
Faster
13 files modified
185 ■■■■ changed files
Makefile 4 ●●●● patch | view | raw | blame | history
src/cnn.c 88 ●●●● patch | view | raw | blame | history
src/connected_layer.c 2 ●●● patch | view | raw | blame | history
src/connected_layer.h 1 ●●●● patch | view | raw | blame | history
src/convolutional_layer.c 18 ●●●● patch | view | raw | blame | history
src/convolutional_layer.cl 16 ●●●●● patch | view | raw | blame | history
src/convolutional_layer.h 1 ●●●● patch | view | raw | blame | history
src/data.c 5 ●●●● patch | view | raw | blame | history
src/network.c 4 ●●●● patch | view | raw | blame | history
src/network_gpu.c 6 ●●●●● patch | view | raw | blame | history
src/opencl.c 4 ●●●● patch | view | raw | blame | history
src/parser.c 12 ●●●● patch | view | raw | blame | history
src/utils.c 24 ●●●●● patch | view | raw | blame | history
Makefile
@@ -13,8 +13,8 @@
endif
UNAME = $(shell uname)
OPTS=-Ofast -flto
#OPTS=-O3
#OPTS=-Ofast -flto
OPTS=-O3 -flto
ifeq ($(UNAME), Darwin)
COMMON+= -isystem /usr/local/Cellar/opencv/2.4.6.1/include/opencv -isystem /usr/local/Cellar/opencv/2.4.6.1/include
ifeq ($(GPU), 1)
src/cnn.c
@@ -71,11 +71,11 @@
}
void train_detection_net()
void train_detection_net(char *cfgfile)
{
    float avg_loss = 1;
    //network net = parse_network_cfg("/home/pjreddie/imagenet_backup/alexnet_1270.cfg");
    network net = parse_network_cfg("cfg/detnet.cfg");
    network net = parse_network_cfg(cfgfile);
    printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
    int imgs = 1024;
    srand(time(0));
@@ -115,6 +115,57 @@
    }
}
void validate_detection_net(char *cfgfile)
{
    network net = parse_network_cfg(cfgfile);
    fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
    srand(time(0));
    list *plist = get_paths("/home/pjreddie/data/imagenet/detection.val");
    char **paths = (char **)list_to_array(plist);
    int m = plist->size;
    int i = 0;
    int splits = 50;
    int num = (i+1)*m/splits - i*m/splits;
    fprintf(stderr, "%d\n", m);
    data val, buffer;
    pthread_t load_thread = load_data_thread(paths, num, 0, 0, 245, 224, 224, &buffer);
    clock_t time;
    for(i = 1; i <= splits; ++i){
        time=clock();
        pthread_join(load_thread, 0);
        val = buffer;
        normalize_data_rows(val);
        num = (i+1)*m/splits - i*m/splits;
        char **part = paths+(i*m/splits);
        if(i != splits) load_thread = load_data_thread(part, num, 0, 0, 245, 224, 224, &buffer);
        fprintf(stderr, "Loaded: %lf seconds\n", sec(clock()-time));
        matrix pred = network_predict_data(net, val);
        int j, k;
        for(j = 0; j < pred.rows; ++j){
            for(k = 0; k < pred.cols; k += 5){
                if (pred.vals[j][k] > .005){
                    int index = k/5;
                    int r = index/7;
                    int c = index%7;
                    float y = (32.*(r + pred.vals[j][k+1]))/224.;
                    float x = (32.*(c + pred.vals[j][k+2]))/224.;
                    float h = (256.*(pred.vals[j][k+3]))/224.;
                    float w = (256.*(pred.vals[j][k+4]))/224.;
                    printf("%d %f %f %f %f %f\n", (i-1)*m/splits + j + 1, pred.vals[j][k], y, x, h, w);
                }
            }
        }
        time=clock();
        free_data(val);
    }
}
void train_imagenet_distributed(char *address)
{
    float avg_loss = 1;
@@ -159,10 +210,10 @@
    //network net = parse_network_cfg("/home/pjreddie/imagenet_backup/alexnet_1270.cfg");
    srand(time(0));
    network net = parse_network_cfg(cfgfile);
    //set_learning_network(&net, net.learning_rate, 0, .0005);
    set_learning_network(&net, net.learning_rate, 0, net.decay);
    printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
    int imgs = 1024;
    int i = 47900;
    int i = 77700;
    char **labels = get_labels("/home/pjreddie/data/imagenet/cls.labels.list");
    list *plist = get_paths("/data/imagenet/cls.train.list");
    char **paths = (char **)list_to_array(plist);
@@ -177,7 +228,9 @@
        time=clock();
        pthread_join(load_thread, 0);
        train = buffer;
        normalize_data_rows(train);
        //normalize_data_rows(train);
        translate_data_rows(train, -128);
        scale_data_rows(train, 1./128);
        load_thread = load_data_thread(paths, imgs, plist->size, labels, 1000, 256, 256, &buffer);
        printf("Loaded: %lf seconds\n", sec(clock()-time));
        time=clock();
@@ -265,8 +318,10 @@
    int i = 0;
    char *filename = "data/test.jpg";
    image im = load_image_color(filename, 224, 224);
    z_normalize_image(im);
    image im = load_image_color(filename, 256, 256);
    //z_normalize_image(im);
    translate_image(im, -128);
    scale_image(im, 1/128.);
    float *X = im.data;
    forward_network(net, X, 0, 1);
    for(i = 0; i < net.n; ++i){
@@ -352,9 +407,9 @@
        if(count%10 == 0){
            float test_acc = network_accuracy(net, test);
            printf("%d: Loss: %f, Test Acc: %f, Time: %lf seconds\n", count, loss, test_acc,sec(clock()-time));
            char buff[256];
            sprintf(buff, "unikitty/cifar10_%d.cfg", count);
            save_network(net, buff);
            //char buff[256];
            //sprintf(buff, "unikitty/cifar10_%d.cfg", count);
            //save_network(net, buff);
        }else{
            printf("%d: Loss: %f, Time: %lf seconds\n", count, loss, sec(clock()-time));
        }
@@ -482,7 +537,7 @@
    cvWaitKey(0);
}
void test_gpu_net()
void test_correct_nist()
{
    srand(222222);
    network net = parse_network_cfg("cfg/nist.cfg");
@@ -523,11 +578,12 @@
    clock_t time;
    int count = 0;
    network net;
    srand(222222);
    net = parse_network_cfg("cfg/net.cfg");
    int imgs = net.batch;
    count = 0;
    srand(222222);
    net = parse_network_cfg("cfg/net.cfg");
    while(++count <= 5){
        time=clock();
        data train = load_data(paths, imgs, plist->size, labels, 1000, 256, 256);
@@ -624,9 +680,9 @@
    }
#endif
    if(0==strcmp(argv[1], "detection")) train_detection_net();
    else if(0==strcmp(argv[1], "cifar")) train_cifar10();
    if(0==strcmp(argv[1], "cifar")) train_cifar10();
    else if(0==strcmp(argv[1], "test_correct")) test_correct_alexnet();
    else if(0==strcmp(argv[1], "test_correct_nist")) test_correct_nist();
    else if(0==strcmp(argv[1], "test")) test_imagenet();
    else if(0==strcmp(argv[1], "server")) run_server();
@@ -638,6 +694,7 @@
        fprintf(stderr, "usage: %s <function> <filename>\n", argv[0]);
        return 0;
    }
    else if(0==strcmp(argv[1], "detection")) train_detection_net(argv[2]);
    else if(0==strcmp(argv[1], "nist")) train_nist(argv[2]);
    else if(0==strcmp(argv[1], "train")) train_imagenet(argv[2]);
    else if(0==strcmp(argv[1], "client")) train_imagenet_distributed(argv[2]);
@@ -646,6 +703,7 @@
    else if(0==strcmp(argv[1], "visualize")) test_visualize(argv[2]);
    else if(0==strcmp(argv[1], "valid")) validate_imagenet(argv[2]);
    else if(0==strcmp(argv[1], "testnist")) test_nist(argv[2]);
    else if(0==strcmp(argv[1], "validetect")) validate_detection_net(argv[2]);
    else if(argc < 4){
        fprintf(stderr, "usage: %s <function> <filename> <filename>\n", argv[0]);
        return 0;
src/connected_layer.c
@@ -162,7 +162,7 @@
    axpy_ongpu(layer.inputs*layer.outputs, -layer.decay, layer.weights_cl, 1, layer.weight_updates_cl, 1);
    axpy_ongpu(layer.inputs*layer.outputs, layer.learning_rate, layer.weight_updates_cl, 1, layer.weights_cl, 1);
    scal_ongpu(layer.inputs*layer.outputs, layer.momentum, layer.weight_updates_cl, 1);
    pull_connected_layer(layer);
    //pull_connected_layer(layer);
}
void forward_connected_layer_gpu(connected_layer layer, cl_mem input)
src/connected_layer.h
@@ -50,6 +50,7 @@
void backward_connected_layer_gpu(connected_layer layer, cl_mem input, cl_mem delta);
void update_connected_layer_gpu(connected_layer layer);
void push_connected_layer(connected_layer layer);
void pull_connected_layer(connected_layer layer);
#endif
#endif
src/convolutional_layer.c
@@ -170,7 +170,9 @@
    int n = layer.size*layer.size*layer.c;
    int k = convolutional_out_height(layer)*
        convolutional_out_width(layer);
    gradient_array(layer.output, m*k*layer.batch, layer.activation, layer.delta);
    learn_bias_convolutional_layer(layer);
    if(delta) memset(delta, 0, layer.batch*layer.h*layer.w*layer.c*sizeof(float));
@@ -264,13 +266,18 @@
}
#ifdef GPU
#define BLOCK 32
#define STR_HELPER(x) #x
#define STR(x) STR_HELPER(x)
cl_kernel get_convolutional_learn_bias_kernel()
{
    static int init = 0;
    static cl_kernel kernel;
    if(!init){
        kernel = get_kernel("src/convolutional_layer.cl", "learn_bias", 0);
        kernel = get_kernel("src/convolutional_layer.cl", "learn_bias", "-D BLOCK=" STR(BLOCK));
        init = 1;
    }
    return kernel;
@@ -291,9 +298,10 @@
    cl.error = clSetKernelArg(kernel, i++, sizeof(layer.bias_updates_cl), (void*) &layer.bias_updates_cl);
    check_error(cl);
    const size_t global_size[] = {layer.n};
    const size_t global_size[] = {layer.n*BLOCK};
    const size_t local_size[] = {BLOCK};
    cl.error = clEnqueueNDRangeKernel(queue, kernel, 1, 0, global_size, 0, 0, 0, 0);
    cl.error = clEnqueueNDRangeKernel(queue, kernel, 1, 0, global_size, local_size, 0, 0, 0);
    check_error(cl);
}
@@ -302,7 +310,7 @@
    static int init = 0;
    static cl_kernel kernel;
    if(!init){
        kernel = get_kernel("src/convolutional_layer.cl", "bias", 0);
        kernel = get_kernel("src/convolutional_layer.cl", "bias", "-D BLOCK=" STR(BLOCK));
        init = 1;
    }
    return kernel;
@@ -410,7 +418,7 @@
    axpy_ongpu(size, -layer.decay, layer.filters_cl, 1, layer.filter_updates_cl, 1);
    axpy_ongpu(size, layer.learning_rate, layer.filter_updates_cl, 1, layer.filters_cl, 1);
    scal_ongpu(size, layer.momentum, layer.filter_updates_cl, 1);
    pull_convolutional_layer(layer);
    //pull_convolutional_layer(layer);
}
src/convolutional_layer.cl
@@ -11,15 +11,21 @@
__kernel void learn_bias(int batch, int n, int size, __global float *delta, __global float *bias_updates)
{
    __local float part[BLOCK];
    int i,b;
    int filter = get_global_id(0);
    int filter = get_group_id(0);
    int p = get_local_id(0);
    float sum = 0;
    for(b = 0; b < batch; ++b){
        for(i = 0; i < size; ++i){
            int index = i + size*(filter + n*b);
            sum += delta[index];
        for(i = 0; i < size; i += BLOCK){
            int index = p + i + size*(filter + n*b);
            sum += (index < size) ? delta[index] : 0;
        }
    }
    bias_updates[filter] += sum;
    part[p] = sum;
    barrier(CLK_LOCAL_MEM_FENCE);
    if(p == 0){
        for(i = 0; i < BLOCK; ++i) bias_updates[filter] += part[i];
    }
}
src/convolutional_layer.h
@@ -46,6 +46,7 @@
void backward_convolutional_layer_gpu(convolutional_layer layer, cl_mem in, cl_mem delta_cl);
void update_convolutional_layer_gpu(convolutional_layer layer);
void push_convolutional_layer(convolutional_layer layer);
void pull_convolutional_layer(convolutional_layer layer);
#endif
convolutional_layer *make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, int pad, ACTIVATION activation, float learning_rate, float momentum, float decay);
src/data.c
@@ -72,11 +72,14 @@
{
    int i;
    memset(truth, 0, k*sizeof(float));
    int count = 0;
    for(i = 0; i < k; ++i){
        if(strstr(path, labels[i])){
            truth[i] = 1;
            ++count;
        }
    }
    if(count != 1) printf("%d, %s\n", count, path);
}
matrix load_image_paths(char **paths, int n, int h, int w)
@@ -111,7 +114,7 @@
{
    matrix y = make_matrix(n, k);
    int i;
    for(i = 0; i < n; ++i){
    for(i = 0; i < n && labels; ++i){
        fill_truth(paths[i], labels, k, y.vals[i]);
    }
    return y;
src/network.c
@@ -372,6 +372,10 @@
            cost_layer *layer = (cost_layer *)net->layers[i];
            layer->batch = b;
        }
        else if(net->types[i] == CROP){
            crop_layer *layer = (crop_layer *)net->layers[i];
            layer->batch = b;
        }
    }
}
src/network_gpu.c
@@ -24,6 +24,7 @@
{
    int i;
    for(i = 0; i < net.n; ++i){
        clock_t time = clock();
        if(net.types[i] == CONVOLUTIONAL){
            convolutional_layer layer = *(convolutional_layer *)net.layers[i];
            forward_convolutional_layer_gpu(layer, input);
@@ -59,6 +60,8 @@
            forward_crop_layer_gpu(layer, input);
            input = layer.output_cl;
        }
        check_error(cl);
        //printf("Forw %d %f\n", i, sec(clock() - time));
    }
}
@@ -68,6 +71,7 @@
    cl_mem prev_input;
    cl_mem prev_delta;
    for(i = net.n-1; i >= 0; --i){
        clock_t time = clock();
        if(i == 0){
            prev_input = input;
            prev_delta = 0;
@@ -99,6 +103,8 @@
            softmax_layer layer = *(softmax_layer *)net.layers[i];
            backward_softmax_layer_gpu(layer, prev_delta);
        }
        check_error(cl);
        //printf("Back %d %f\n", i, sec(clock() - time));
    }
}
src/opencl.c
@@ -18,7 +18,7 @@
void check_error(cl_info info)
{
   // clFinish(cl.queue);
    clFinish(cl.queue);
    if (info.error != CL_SUCCESS) {
        printf("\n Error number %d", info.error);
        abort();
@@ -144,7 +144,7 @@
void cl_setup()
{
    if(!cl.initialized){
        printf("initializing\n");
        fprintf(stderr, "Initializing OpenCL\n");
        cl = cl_init(gpu_index);
    }
}
src/parser.c
@@ -16,6 +16,7 @@
#include "list.h"
#include "option_list.h"
#include "utils.h"
#include "opencl.h"
typedef struct{
    char *type;
@@ -387,8 +388,8 @@
int read_option(char *s, list *options)
{
    int i;
    int len = strlen(s);
    size_t i;
    size_t len = strlen(s);
    char *val = 0;
    for(i = 0; i < len; ++i){
        if(s[i] == '='){
@@ -416,7 +417,6 @@
        strip(line);
        switch(line[0]){
            case '[':
                printf("%s\n", line);
                current = malloc(sizeof(section));
                list_insert(sections, current);
                current->options = make_list();
@@ -441,6 +441,9 @@
void print_convolutional_cfg(FILE *fp, convolutional_layer *l, network net, int count)
{
    #ifdef GPU
    if(gpu_index >= 0) pull_convolutional_layer(*l);
    #endif
    int i;
    fprintf(fp, "[convolutional]\n");
    if(count == 0) {
@@ -495,6 +498,9 @@
void print_connected_cfg(FILE *fp, connected_layer *l, network net, int count)
{
    #ifdef GPU
    if(gpu_index >= 0) pull_connected_layer(*l);
    #endif
    int i;
    fprintf(fp, "[connected]\n");
    if(count == 0){
src/utils.c
@@ -3,6 +3,7 @@
#include <string.h>
#include <math.h>
#include <float.h>
#include <limits.h>
#include "utils.h"
@@ -64,8 +65,8 @@
list *split_str(char *s, char delim)
{
    int i;
    int len = strlen(s);
    size_t i;
    size_t len = strlen(s);
    list *l = make_list();
    list_insert(l, s);
    for(i = 0; i < len; ++i){
@@ -79,9 +80,9 @@
void strip(char *s)
{
    int i;
    int len = strlen(s);
    int offset = 0;
    size_t i;
    size_t len = strlen(s);
    size_t offset = 0;
    for(i = 0; i < len; ++i){
        char c = s[i];
        if(c==' '||c=='\t'||c=='\n') ++offset;
@@ -92,9 +93,9 @@
void strip_char(char *s, char bad)
{
    int i;
    int len = strlen(s);
    int offset = 0;
    size_t i;
    size_t len = strlen(s);
    size_t offset = 0;
    for(i = 0; i < len; ++i){
        char c = s[i];
        if(c==bad) ++offset;
@@ -116,14 +117,17 @@
    size_t curr = strlen(line);
    while((line[curr-1] != '\n') && !feof(fp)){
        printf("%ld %ld\n", curr, size);
        if(curr == size-1){
        size *= 2;
        line = realloc(line, size*sizeof(char));
        if(!line) {
            printf("%ld\n", size);
            malloc_error();
        }
        fgets(&line[curr], size-curr, fp);
        }
        size_t readsize = size-curr;
        if(readsize > INT_MAX) readsize = INT_MAX-1;
        fgets(&line[curr], readsize, fp);
        curr = strlen(line);
    }
    if(line[curr-1] == '\n') line[curr-1] = '\0';