Joseph Redmon
2016-11-19 62235e9aa3d0c15d87d49bf340625d075cba3e65
cpu batch norm works
12 files modified
174 ■■■■■ changed files
Makefile 4 ●●● patch | view | raw | blame | history
cfg/darknet.cfg 1 ●●●● patch | view | raw | blame | history
cfg/yolo.cfg 2 ●●● patch | view | raw | blame | history
src/batchnorm_layer.c 18 ●●●●● patch | view | raw | blame | history
src/cifar.c 23 ●●●●● patch | view | raw | blame | history
src/convolutional_layer.c 65 ●●●● patch | view | raw | blame | history
src/detector.c 12 ●●●●● patch | view | raw | blame | history
src/image.c 29 ●●●● patch | view | raw | blame | history
src/image.h 2 ●●●●● patch | view | raw | blame | history
src/maxpool_layer.c 8 ●●●● patch | view | raw | blame | history
src/maxpool_layer_kernels.cu 8 ●●●● patch | view | raw | blame | history
src/reorg_layer.c 2 ●●● patch | view | raw | blame | history
Makefile
@@ -50,7 +50,7 @@
OBJS = $(addprefix $(OBJDIR), $(OBJ))
DEPS = $(wildcard src/*.h) Makefile
all: obj results $(EXEC)
all: obj backup results $(EXEC)
$(EXEC): $(OBJS)
    $(CC) $(COMMON) $(CFLAGS) $^ -o $@ $(LDFLAGS)
@@ -63,6 +63,8 @@
obj:
    mkdir -p obj
backup:
    mkdir -p backup
results:
    mkdir -p results
cfg/darknet.cfg
@@ -84,6 +84,7 @@
[maxpool]
size=2
stride=2
padding=1
[convolutional]
batch_normalize=1
cfg/yolo.cfg
@@ -1,8 +1,8 @@
[net]
batch=64
subdivisions=8
height=416
width=416
height=416
channels=3
momentum=0.9
decay=0.0005
src/batchnorm_layer.c
@@ -129,15 +129,31 @@
    if(state.train){
        mean_cpu(l.output, l.batch, l.out_c, l.out_h*l.out_w, l.mean);   
        variance_cpu(l.output, l.mean, l.batch, l.out_c, l.out_h*l.out_w, l.variance);   
        scal_cpu(l.out_c, .99, l.rolling_mean, 1);
        axpy_cpu(l.out_c, .01, l.mean, 1, l.rolling_mean, 1);
        scal_cpu(l.out_c, .99, l.rolling_variance, 1);
        axpy_cpu(l.out_c, .01, l.variance, 1, l.rolling_variance, 1);
        copy_cpu(l.outputs*l.batch, l.output, 1, l.x, 1);
        normalize_cpu(l.output, l.mean, l.variance, l.batch, l.out_c, l.out_h*l.out_w);   
        copy_cpu(l.outputs*l.batch, l.output, 1, l.x_norm, 1);
    } else {
        normalize_cpu(l.output, l.rolling_mean, l.rolling_variance, l.batch, l.out_c, l.out_h*l.out_w);
    }
    scale_bias(l.output, l.scales, l.batch, l.out_c, l.out_h*l.out_w);
}
void backward_batchnorm_layer(const layer layer, network_state state)
void backward_batchnorm_layer(const layer l, network_state state)
{
    backward_scale_cpu(l.x_norm, l.delta, l.batch, l.out_c, l.out_w*l.out_h, l.scale_updates);
    scale_bias(l.delta, l.scales, l.batch, l.out_c, l.out_h*l.out_w);
    mean_delta_cpu(l.delta, l.variance, l.batch, l.out_c, l.out_w*l.out_h, l.mean_delta);
    variance_delta_cpu(l.x, l.delta, l.mean, l.variance, l.batch, l.out_c, l.out_w*l.out_h, l.variance_delta);
    normalize_delta_cpu(l.x, l.mean, l.variance, l.mean_delta, l.variance_delta, l.batch, l.out_c, l.out_w*l.out_h, l.delta);
    if(l.type == BATCHNORM) copy_cpu(l.outputs*l.batch, l.delta, 1, state.delta, 1);
}
#ifdef GPU
src/cifar.c
@@ -166,6 +166,28 @@
    free_data(test);
}
void extract_cifar()
{
char *labels[] = {"airplane","automobile","bird","cat","deer","dog","frog","horse","ship","truck"};
    int i;
    data train = load_all_cifar10();
    data test = load_cifar10_data("data/cifar/cifar-10-batches-bin/test_batch.bin");
    for(i = 0; i < train.X.rows; ++i){
        image im = float_to_image(32, 32, 3, train.X.vals[i]);
        int class = max_index(train.y.vals[i], 10);
        char buff[256];
        sprintf(buff, "data/cifar/train/%d_%s",i,labels[class]);
        save_image_png(im, buff);
    }
    for(i = 0; i < test.X.rows; ++i){
        image im = float_to_image(32, 32, 3, test.X.vals[i]);
        int class = max_index(test.y.vals[i], 10);
        char buff[256];
        sprintf(buff, "data/cifar/test/%d_%s",i,labels[class]);
        save_image_png(im, buff);
    }
}
void test_cifar_csv(char *filename, char *weightfile)
{
    network net = parse_network_cfg(filename);
@@ -243,6 +265,7 @@
    char *cfg = argv[3];
    char *weights = (argc > 4) ? argv[4] : 0;
    if(0==strcmp(argv[2], "train")) train_cifar(cfg, weights);
    else if(0==strcmp(argv[2], "extract")) extract_cifar();
    else if(0==strcmp(argv[2], "distill")) train_cifar_distill(cfg, weights);
    else if(0==strcmp(argv[2], "test")) test_cifar(cfg, weights);
    else if(0==strcmp(argv[2], "multi")) test_cifar_multi(cfg, weights);
src/convolutional_layer.c
@@ -206,8 +206,8 @@
    l.outputs = l.out_h * l.out_w * l.out_c;
    l.inputs = l.w * l.h * l.c;
    l.output = calloc(l.batch*out_h * out_w * n, sizeof(float));
    l.delta  = calloc(l.batch*out_h * out_w * n, sizeof(float));
    l.output = calloc(l.batch*l.outputs, sizeof(float));
    l.delta  = calloc(l.batch*l.outputs, sizeof(float));
    l.forward = forward_convolutional_layer;
    l.backward = backward_convolutional_layer;
@@ -232,8 +232,13 @@
        l.mean = calloc(n, sizeof(float));
        l.variance = calloc(n, sizeof(float));
        l.mean_delta = calloc(n, sizeof(float));
        l.variance_delta = calloc(n, sizeof(float));
        l.rolling_mean = calloc(n, sizeof(float));
        l.rolling_variance = calloc(n, sizeof(float));
        l.x = calloc(l.batch*l.outputs, sizeof(float));
        l.x_norm = calloc(l.batch*l.outputs, sizeof(float));
    }
    if(adam){
        l.adam = 1;
@@ -357,17 +362,19 @@
    l->outputs = l->out_h * l->out_w * l->out_c;
    l->inputs = l->w * l->h * l->c;
    l->output = realloc(l->output,
            l->batch*out_h * out_w * l->n*sizeof(float));
    l->delta  = realloc(l->delta,
            l->batch*out_h * out_w * l->n*sizeof(float));
    l->output = realloc(l->output, l->batch*l->outputs*sizeof(float));
    l->delta  = realloc(l->delta,  l->batch*l->outputs*sizeof(float));
    if(l->batch_normalize){
        l->x = realloc(l->x, l->batch*l->outputs*sizeof(float));
        l->x_norm  = realloc(l->x_norm, l->batch*l->outputs*sizeof(float));
    }
#ifdef GPU
    cuda_free(l->delta_gpu);
    cuda_free(l->output_gpu);
    l->delta_gpu =     cuda_make_array(l->delta, l->batch*out_h*out_w*l->n);
    l->output_gpu =    cuda_make_array(l->output, l->batch*out_h*out_w*l->n);
    l->delta_gpu =  cuda_make_array(l->delta,  l->batch*l->outputs);
    l->output_gpu = cuda_make_array(l->output, l->batch*l->outputs);
    if(l->batch_normalize){
        cuda_free(l->x_gpu);
@@ -423,41 +430,8 @@
    int out_w = convolutional_out_width(l);
    int i;
    fill_cpu(l.outputs*l.batch, 0, l.output, 1);
    /*
       if(l.binary){
       binarize_weights(l.weights, l.n, l.c*l.size*l.size, l.binary_weights);
       binarize_weights2(l.weights, l.n, l.c*l.size*l.size, l.cweights, l.scales);
       swap_binary(&l);
       }
     */
    /*
       if(l.binary){
       int m = l.n;
       int k = l.size*l.size*l.c;
       int n = out_h*out_w;
       char  *a = l.cweights;
       float *b = state.workspace;
       float *c = l.output;
       for(i = 0; i < l.batch; ++i){
       im2col_cpu(state.input, l.c, l.h, l.w,
       l.size, l.stride, l.pad, b);
       gemm_bin(m,n,k,1,a,k,b,n,c,n);
       c += n*m;
       state.input += l.c*l.h*l.w;
       }
       scale_bias(l.output, l.scales, l.batch, l.n, out_h*out_w);
       add_bias(l.output, l.biases, l.batch, l.n, out_h*out_w);
       activate_array(l.output, m*n*l.batch, l.activation);
       return;
       }
     */
    if(l.xnor){
        binarize_weights(l.weights, l.n, l.c*l.size*l.size, l.binary_weights);
        swap_binary(&l);
@@ -469,10 +443,6 @@
    int k = l.size*l.size*l.c;
    int n = out_h*out_w;
    if (l.xnor && l.c%32 == 0 && AI2) {
        forward_xnor_layer(l, state);
        printf("xnor\n");
    } else {
        float *a = l.weights;
        float *b = state.workspace;
@@ -485,7 +455,6 @@
            c += n*m;
            state.input += l.c*l.h*l.w;
        }
    }
    if(l.batch_normalize){
        forward_batchnorm_layer(l, state);
@@ -507,6 +476,10 @@
    gradient_array(l.output, m*k*l.batch, l.activation, l.delta);
    backward_bias(l.bias_updates, l.delta, l.batch, l.n, k);
    if(l.batch_normalize){
        backward_batchnorm_layer(l, state);
    }
    for(i = 0; i < l.batch; ++i){
        float *a = l.delta + i*m*k;
        float *b = state.workspace;
src/detector.c
@@ -444,7 +444,6 @@
    if(weightfile){
        load_weights(&net, weightfile);
    }
    layer l = net.layers[net.n-1];
    set_batch_network(&net, 1);
    srand(2222222);
    clock_t time;
@@ -452,9 +451,6 @@
    char *input = buff;
    int j;
    float nms=.4;
    box *boxes = calloc(l.w*l.h*l.n, sizeof(box));
    float **probs = calloc(l.w*l.h*l.n, sizeof(float *));
    for(j = 0; j < l.w*l.h*l.n; ++j) probs[j] = calloc(l.classes, sizeof(float *));
    while(1){
        if(filename){
            strncpy(input, filename, 256);
@@ -467,6 +463,12 @@
        }
        image im = load_image_color(input,0,0);
        image sized = resize_image(im, net.w, net.h);
        layer l = net.layers[net.n-1];
        box *boxes = calloc(l.w*l.h*l.n, sizeof(box));
        float **probs = calloc(l.w*l.h*l.n, sizeof(float *));
        for(j = 0; j < l.w*l.h*l.n; ++j) probs[j] = calloc(l.classes, sizeof(float *));
        float *X = sized.data;
        time=clock();
        network_predict(net, X);
@@ -479,6 +481,8 @@
        free_image(im);
        free_image(sized);
        free(boxes);
        free_ptrs((void **)probs, l.w*l.h*l.n);
#ifdef OPENCV
        cvWaitKey(0);
        cvDestroyAllWindows();
src/image.c
@@ -532,11 +532,8 @@
}
#endif
void save_image(image im, const char *name)
void save_image_png(image im, const char *name)
{
#ifdef OPENCV
    save_image_jpg(im, name);
#else
    char buff[256];
    //sprintf(buff, "%s (%d)", name, windows);
    sprintf(buff, "%s.png", name);
@@ -550,6 +547,14 @@
    int success = stbi_write_png(buff, im.w, im.h, im.c, data, im.w*im.c);
    free(data);
    if(!success) fprintf(stderr, "Failed to write image %s\n", buff);
}
void save_image(image im, const char *name)
{
#ifdef OPENCV
    save_image_jpg(im, name);
#else
    save_image_png(im, name);
#endif
}
@@ -748,6 +753,22 @@
#endif
}
image resize_max(image im, int max)
{
    int w = im.w;
    int h = im.h;
    if(w > h){
        h = (h * max) / w;
        w = max;
    } else {
        w = (w * max) / h;
        h = max;
    }
    if(w == im.w && h == im.h) return im;
    image resized = resize_image(im, w, h);
    return resized;
}
image resize_min(image im, int min)
{
    int w = im.w;
src/image.h
@@ -31,6 +31,7 @@
void random_distort_image(image im, float hue, float saturation, float exposure);
image resize_image(image im, int w, int h);
image resize_min(image im, int min);
image resize_max(image im, int max);
void translate_image(image m, float s);
void normalize_image(image p);
image rotate_image(image m, float rad);
@@ -55,6 +56,7 @@
void show_image(image p, const char *name);
void show_image_normalized(image im, const char *name);
void save_image_png(image im, const char *name);
void save_image(image p, const char *name);
void show_images(image *ims, int n, char *window);
void show_image_layers(image p, char *name);
src/maxpool_layer.c
@@ -27,8 +27,8 @@
    l.w = w;
    l.c = c;
    l.pad = padding;
    l.out_w = (w + 2*padding - size + 1)/stride + 1;
    l.out_h = (h + 2*padding - size + 1)/stride + 1;
    l.out_w = (w + 2*padding)/stride;
    l.out_h = (h + 2*padding)/stride;
    l.out_c = c;
    l.outputs = l.out_h * l.out_w * l.out_c;
    l.inputs = h*w*c;
@@ -57,8 +57,8 @@
    l->w = w;
    l->inputs = h*w*l->c;
    l->out_w = (w + 2*l->pad - l->size + 1)/l->stride + 1;
    l->out_h = (h + 2*l->pad - l->size + 1)/l->stride + 1;
    l->out_w = (w + 2*l->pad)/l->stride;
    l->out_h = (h + 2*l->pad)/l->stride;
    l->outputs = l->out_w * l->out_h * l->c;
    int output_size = l->outputs * l->batch;
src/maxpool_layer_kernels.cu
@@ -9,8 +9,8 @@
__global__ void forward_maxpool_layer_kernel(int n, int in_h, int in_w, int in_c, int stride, int size, int pad, float *input, float *output, int *indexes)
{
    int h = (in_h + 2*pad - size + 1)/stride + 1;
    int w = (in_w + 2*pad - size + 1)/stride + 1;
    int h = (in_h + 2*pad)/stride;
    int w = (in_w + 2*pad)/stride;
    int c = in_c;
    int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
@@ -49,8 +49,8 @@
__global__ void backward_maxpool_layer_kernel(int n, int in_h, int in_w, int in_c, int stride, int size, int pad, float *delta, float *prev_delta, int *indexes)
{
    int h = (in_h + 2*pad - size + 1)/stride + 1;
    int w = (in_w + 2*pad - size + 1)/stride + 1;
    int h = (in_h + 2*pad)/stride;
    int w = (in_w + 2*pad)/stride;
    int c = in_c;
    int area = (size-1)/stride;
src/reorg_layer.c
@@ -4,7 +4,7 @@
#include <stdio.h>
layer make_reorg_layer(int batch, int h, int w, int c, int stride, int reverse)
layer make_reorg_layer(int batch, int w, int h, int c, int stride, int reverse)
{
    layer l = {0};
    l.type = REORG;