Joseph Redmon
2016-03-16 cff59ba1353b79ec3b69059ce1b4f191540616fd
go updates
12 files modified
327 ■■■■ changed files
Makefile 4 ●●●● patch | view | raw | blame | history
cfg/go.test.cfg 46 ●●●●● patch | view | raw | blame | history
src/classifier.c 2 ●●● patch | view | raw | blame | history
src/convolutional_kernels.cu 11 ●●●● patch | view | raw | blame | history
src/convolutional_layer.c 79 ●●●●● patch | view | raw | blame | history
src/convolutional_layer.h 3 ●●●●● patch | view | raw | blame | history
src/gemm.c 22 ●●●●● patch | view | raw | blame | history
src/gemm.h 5 ●●●●● patch | view | raw | blame | history
src/go.c 60 ●●●● patch | view | raw | blame | history
src/image.c 14 ●●●●● patch | view | raw | blame | history
src/layer.h 1 ●●●● patch | view | raw | blame | history
src/parser.c 80 ●●●●● patch | view | raw | blame | history
Makefile
@@ -1,5 +1,5 @@
GPU=0
OPENCV=0
GPU=1
OPENCV=1
DEBUG=0
ARCH= --gpu-architecture=compute_20 --gpu-code=compute_20 
cfg/go.test.cfg
@@ -10,11 +10,11 @@
learning_rate=0.1
max_batches = 0
policy=steps
steps=50000, 90000
scales=.1, .1
steps=50000
scales=.1
[convolutional]
filters=256
filters=512
size=3
stride=1
pad=1
@@ -23,6 +23,14 @@
[convolutional]
filters=256
size=1
stride=1
pad=1
activation=leaky
batch_normalize=1
[convolutional]
filters=512
size=3
stride=1
pad=1
@@ -31,6 +39,14 @@
[convolutional]
filters=256
size=1
stride=1
pad=1
activation=leaky
batch_normalize=1
[convolutional]
filters=512
size=3
stride=1
pad=1
@@ -39,6 +55,14 @@
[convolutional]
filters=256
size=1
stride=1
pad=1
activation=leaky
batch_normalize=1
[convolutional]
filters=512
size=3
stride=1
pad=1
@@ -47,6 +71,14 @@
[convolutional]
filters=256
size=1
stride=1
pad=1
activation=leaky
batch_normalize=1
[convolutional]
filters=512
size=3
stride=1
pad=1
@@ -54,6 +86,14 @@
batch_normalize=1
[convolutional]
filters=256
size=1
stride=1
pad=1
activation=leaky
batch_normalize=1
[convolutional]
filters=1
size=1
stride=1
src/classifier.c
@@ -410,7 +410,7 @@
    char **labels = get_labels(label_list);
    list *plist = get_paths(valid_list);
    int scales[] = {160, 192, 224, 288, 320, 352, 384};
    int scales[] = {192, 224, 288, 320, 352};
    int nscales = sizeof(scales)/sizeof(scales[0]);
    char **paths = (char **)list_to_array(plist);
src/convolutional_kernels.cu
@@ -65,9 +65,9 @@
    }
}
void binarize_filters_gpu(float *filters, int n, int size, float *mean)
void binarize_filters_gpu(float *filters, int n, int size, float *binary)
{
    binarize_filters_kernel<<<cuda_gridsize(n), BLOCK>>>(filters, n, size, mean);
    binarize_filters_kernel<<<cuda_gridsize(n), BLOCK>>>(filters, n, size, binary);
    check_error(cudaPeekAtLastError());
}
@@ -161,13 +161,6 @@
    check_error(cudaPeekAtLastError());
}
void swap_binary(convolutional_layer *l)
{
    float *swap = l->filters_gpu;
    l->filters_gpu = l->binary_filters_gpu;
    l->binary_filters_gpu = swap;
}
void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
{
    int i;
src/convolutional_layer.c
@@ -7,6 +7,52 @@
#include <stdio.h>
#include <time.h>
void swap_binary(convolutional_layer *l)
{
    float *swap = l->filters;
    l->filters = l->binary_filters;
    l->binary_filters = swap;
    #ifdef GPU
    swap = l->filters_gpu;
    l->filters_gpu = l->binary_filters_gpu;
    l->binary_filters_gpu = swap;
    #endif
}
void binarize_filters2(float *filters, int n, int size, char *binary, float *scales)
{
    int i, k, f;
    for(f = 0; f < n; ++f){
        float mean = 0;
        for(i = 0; i < size; ++i){
            mean += fabs(filters[f*size + i]);
        }
        mean = mean / size;
        scales[f] = mean;
        for(i = 0; i < size/8; ++i){
            binary[f*size + i] = (filters[f*size + i] > 0) ? 1 : 0;
            for(k = 0; k < 8; ++k){
            }
        }
    }
}
void binarize_filters(float *filters, int n, int size, float *binary)
{
    int i, f;
    for(f = 0; f < n; ++f){
        float mean = 0;
        for(i = 0; i < size; ++i){
            mean += fabs(filters[f*size + i]);
        }
        mean = mean / size;
        for(i = 0; i < size; ++i){
            binary[f*size + i] = (filters[f*size + i] > 0) ? mean : -mean;
        }
    }
}
int convolutional_out_height(convolutional_layer l)
{
    int h = l.h;
@@ -139,6 +185,8 @@
    if(binary){
        l.binary_filters = calloc(c*n*size*size, sizeof(float));
        l.cfilters = calloc(c*n*size*size, sizeof(char));
        l.scales = calloc(n, sizeof(float));
    }
    if(batch_normalize){
@@ -295,13 +343,42 @@
    }
}
void forward_convolutional_layer(const convolutional_layer l, network_state state)
void forward_convolutional_layer(convolutional_layer l, network_state state)
{
    int out_h = convolutional_out_height(l);
    int out_w = convolutional_out_width(l);
    int i;
    fill_cpu(l.outputs*l.batch, 0, l.output, 1);
    /*
    if(l.binary){
        binarize_filters(l.filters, l.n, l.c*l.size*l.size, l.binary_filters);
        binarize_filters2(l.filters, l.n, l.c*l.size*l.size, l.cfilters, l.scales);
        swap_binary(&l);
    }
    */
    if(l.binary){
        int m = l.n;
        int k = l.size*l.size*l.c;
        int n = out_h*out_w;
        char  *a = l.cfilters;
        float *b = l.col_image;
        float *c = l.output;
        for(i = 0; i < l.batch; ++i){
            im2col_cpu(state.input, l.c, l.h, l.w,
                    l.size, l.stride, l.pad, b);
            gemm_bin(m,n,k,1,a,k,b,n,c,n);
            c += n*m;
            state.input += l.c*l.h*l.w;
        }
        scale_bias(l.output, l.scales, l.batch, l.n, out_h*out_w);
        add_bias(l.output, l.biases, l.batch, l.n, out_h*out_w);
        activate_array(l.output, m*n*l.batch, l.activation);
        return;
    }
    int m = l.n;
    int k = l.size*l.size*l.c;
src/convolutional_layer.h
@@ -27,6 +27,9 @@
void forward_convolutional_layer(const convolutional_layer layer, network_state state);
void update_convolutional_layer(convolutional_layer layer, int batch, float learning_rate, float momentum, float decay);
image *visualize_convolutional_layer(convolutional_layer layer, char *window, image *prev_filters);
void binarize_filters(float *filters, int n, int size, float *binary);
void swap_binary(convolutional_layer *l);
void binarize_filters2(float *filters, int n, int size, char *binary, float *scales);
void backward_convolutional_layer(convolutional_layer layer, network_state state);
src/gemm.c
@@ -5,6 +5,28 @@
#include <stdio.h>
#include <math.h>
void gemm_bin(int M, int N, int K, float ALPHA,
        char  *A, int lda,
        float *B, int ldb,
        float *C, int ldc)
{
    int i,j,k;
    for(i = 0; i < M; ++i){
        for(k = 0; k < K; ++k){
            char A_PART = A[i*lda+k];
            if(A_PART){
                for(j = 0; j < N; ++j){
                    C[i*ldc+j] += B[k*ldb+j];
                }
            } else {
                for(j = 0; j < N; ++j){
                    C[i*ldc+j] -= B[k*ldb+j];
                }
            }
        }
    }
}
float *random_matrix(int rows, int cols)
{
    int i;
src/gemm.h
@@ -1,6 +1,11 @@
#ifndef GEMM_H
#define GEMM_H
void gemm_bin(int M, int N, int K, float ALPHA,
        char  *A, int lda,
        float *B, int ldb,
        float *C, int ldc);
void gemm(int TA, int TB, int M, int N, int K, float ALPHA, 
                    float *A, int lda, 
                    float *B, int ldb,
src/go.c
@@ -10,6 +10,7 @@
int inverted = 1;
int noi = 1;
static const int nind = 5;
void train_go(char *cfgfile, char *weightfile)
{
@@ -147,12 +148,14 @@
            int index = j*19 + i;
            if(indexes){
                int found = 0;
                for(n = 0; n < 3; ++n){
                for(n = 0; n < nind; ++n){
                    if(index == indexes[n]){
                        found = 1;
                        if(n == 0) printf("\uff11");
                        else if(n == 1) printf("\uff12");
                        else if(n == 2) printf("\uff13");
                        else if(n == 3) printf("\uff14");
                        else if(n == 4) printf("\uff15");
                    }
                }
                if(found) continue;
@@ -211,59 +214,56 @@
            if(board[i]) move[i] = 0;
        }
        int indexes[3];
        int indexes[nind];
        int row, col;
        top_k(move, 19*19, 3, indexes);
        top_k(move, 19*19, nind, indexes);
        print_board(board, color, indexes);
        for(i = 0; i < 3; ++i){
        for(i = 0; i < nind; ++i){
            int index = indexes[i];
            row = index / 19;
            col = index % 19;
            printf("Suggested: %c %d, %.2f%%\n", col + 'A' + 1*(col > 7 && noi), (inverted)?19 - row : row+1, move[index]*100);
            printf("%d: %c %d, %.2f%%\n", i+1, col + 'A' + 1*(col > 7 && noi), (inverted)?19 - row : row+1, move[index]*100);
        }
        int index = indexes[0];
        int rec_row = index / 19;
        int rec_col = index % 19;
        if(color == 1) printf("\u25EF Enter move: ");
        else printf("\u25C9 Enter move: ");
        char c;
        char *line = fgetl(stdin);
        int num = sscanf(line, "%c %d", &c, &row);
        if (strlen(line) == 0){
            row = rec_row;
            col = rec_col;
        int picked = 1;
        int dnum = sscanf(line, "%d", &picked);
        int cnum = sscanf(line, "%c", &c);
        if (strlen(line) == 0 || dnum) {
            --picked;
            if (picked < nind){
                int index = indexes[picked];
                row = index / 19;
                col = index % 19;
            board[row*19 + col] = 1;
        }else if (c < 'A' || c > 'T'){
            if (c == 'p'){
                flip_board(board);
                color = -color;
                free(line);
                continue;
            }
        } else if (cnum){
            if (c <= 'T' && c >= 'A'){
                int num = sscanf(line, "%c %d", &c, &row);
                row = (inverted)?19 - row : row-1;
                col = c - 'A';
                if (col > 7 && noi) col -= 1;
                if (num == 2) board[row*19 + col] = 1;
            } else if (c == 'p') {
                // Pass
            } else if(c=='b' || c == 'w'){
                char g;
                num = sscanf(line, "%c %c %d", &g, &c, &row);
                int num = sscanf(line, "%c %c %d", &g, &c, &row);
                row = (inverted)?19 - row : row-1;
                col = c - 'A';
                if (col > 7 && noi) col -= 1;
                if (num == 3) board[row*19 + col] = (g == 'b') ? color : -color;
            }else{
            } else if(c == 'c'){
                char g;
                num = sscanf(line, "%c %c %d", &g, &c, &row);
                int num = sscanf(line, "%c %c %d", &g, &c, &row);
                row = (inverted)?19 - row : row-1;
                col = c - 'A';
                if (col > 7 && noi) col -= 1;
                if (num == 3) board[row*19 + col] = 0;
            }
        } else if(num == 2){
            row = (inverted)?19 - row : row-1;
            col = c - 'A';
            if (col > 7 && noi) col -= 1;
            board[row*19 + col] = 1;
        }else{
            free(line);
            continue;
        }
        free(line);
        update_board(board);
src/image.c
@@ -676,6 +676,17 @@
        }
    }
    image binarize_image(image im)
    {
        image c = copy_image(im);
        int i;
        for(i = 0; i < im.w * im.h * im.c; ++i){
            if(c.data[i] > .5) c.data[i] = 1;
            else c.data[i] = 0;
        }
        return c;
    }
    void saturate_image(image im, float sat)
    {
        rgb_to_hsv(im);
@@ -798,6 +809,8 @@
        image exp5 = copy_image(im);
        exposure_image(exp5, .5);
        image bin = binarize_image(im);
#ifdef GPU
        image r = resize_image(im, im.w, im.h);
        image black = make_image(im.w*2 + 3, im.h*2 + 3, 9);
@@ -818,6 +831,7 @@
#endif
        show_image(im, "Original");
        show_image(bin,  "Binary");
        show_image(gray, "Gray");
        show_image(sat2, "Saturation-2");
        show_image(sat5, "Saturation-.5");
src/layer.h
@@ -92,6 +92,7 @@
    float *rand;
    float *cost;
    float *filters;
    char  *cfilters;
    float *filter_updates;
    float *state;
src/parser.c
@@ -730,8 +730,44 @@
    fclose(fp);
}
void save_convolutional_weights_binary(layer l, FILE *fp)
{
#ifdef GPU
    if(gpu_index >= 0){
        pull_convolutional_layer(l);
    }
#endif
    binarize_filters(l.filters, l.n, l.c*l.size*l.size, l.binary_filters);
    int size = l.c*l.size*l.size;
    int i, j, k;
    fwrite(l.biases, sizeof(float), l.n, fp);
    if (l.batch_normalize){
        fwrite(l.scales, sizeof(float), l.n, fp);
        fwrite(l.rolling_mean, sizeof(float), l.n, fp);
        fwrite(l.rolling_variance, sizeof(float), l.n, fp);
    }
    for(i = 0; i < l.n; ++i){
        float mean = l.binary_filters[i*size];
        if(mean < 0) mean = -mean;
        fwrite(&mean, sizeof(float), 1, fp);
        for(j = 0; j < size/8; ++j){
            int index = i*size + j*8;
            unsigned char c = 0;
            for(k = 0; k < 8; ++k){
                if (j*8 + k >= size) break;
                if (l.binary_filters[index + k] > 0) c = (c | 1<<k);
            }
            fwrite(&c, sizeof(char), 1, fp);
        }
    }
}
void save_convolutional_weights(layer l, FILE *fp)
{
    if(l.binary){
        //save_convolutional_weights_binary(l, fp);
        //return;
    }
#ifdef GPU
    if(gpu_index >= 0){
        pull_convolutional_layer(l);
@@ -843,27 +879,55 @@
#endif
}
void load_convolutional_weights_binary(layer l, FILE *fp)
{
    fread(l.biases, sizeof(float), l.n, fp);
    if (l.batch_normalize && (!l.dontloadscales)){
        fread(l.scales, sizeof(float), l.n, fp);
        fread(l.rolling_mean, sizeof(float), l.n, fp);
        fread(l.rolling_variance, sizeof(float), l.n, fp);
    }
    int size = l.c*l.size*l.size;
    int i, j, k;
    for(i = 0; i < l.n; ++i){
        float mean = 0;
        fread(&mean, sizeof(float), 1, fp);
        for(j = 0; j < size/8; ++j){
            int index = i*size + j*8;
            unsigned char c = 0;
            fread(&c, sizeof(char), 1, fp);
            for(k = 0; k < 8; ++k){
                if (j*8 + k >= size) break;
                l.filters[index + k] = (c & 1<<k) ? mean : -mean;
            }
        }
    }
    binarize_filters2(l.filters, l.n, l.c*l.size*l.size, l.cfilters, l.scales);
#ifdef GPU
    if(gpu_index >= 0){
        push_convolutional_layer(l);
    }
#endif
}
void load_convolutional_weights(layer l, FILE *fp)
{
    if(l.binary){
        //load_convolutional_weights_binary(l, fp);
        //return;
    }
    int num = l.n*l.c*l.size*l.size;
    fread(l.biases, sizeof(float), l.n, fp);
    if (l.batch_normalize && (!l.dontloadscales)){
        fread(l.scales, sizeof(float), l.n, fp);
        fread(l.rolling_mean, sizeof(float), l.n, fp);
        fread(l.rolling_variance, sizeof(float), l.n, fp);
        /*
        int i;
        for(i = 0; i < l.n; ++i){
            if(l.rolling_mean[i] > 1 || l.rolling_mean[i] < -1 || l.rolling_variance[i] > 1 || l.rolling_variance[i] < -1)
            printf("%f %f\n", l.rolling_mean[i], l.rolling_variance[i]);
        }
        */
    }
    fflush(stdout);
    fread(l.filters, sizeof(float), num, fp);
    if (l.flipped) {
        transpose_matrix(l.filters, l.c*l.size*l.size, l.n);
    }
    if (l.binary) binarize_filters(l.filters, l.n, l.c*l.size*l.size, l.filters);
#ifdef GPU
    if(gpu_index >= 0){
        push_convolutional_layer(l);