Joseph Redmon
2016-06-14 8322a58cf69e79ff1847634ede1a2a1b2e71e6d6
hate warnings
9 files modified
183 ■■■■ changed files
src/art.c 2 ●●● patch | view | raw | blame | history
src/classifier.c 4 ●●●● patch | view | raw | blame | history
src/convolutional_kernels.cu 24 ●●●● patch | view | raw | blame | history
src/convolutional_layer.c 110 ●●●●● patch | view | raw | blame | history
src/convolutional_layer.h 3 ●●●●● patch | view | raw | blame | history
src/detection_layer.c 7 ●●●●● patch | view | raw | blame | history
src/go.c 21 ●●●● patch | view | raw | blame | history
src/network.c 5 ●●●●● patch | view | raw | blame | history
src/rnn.c 7 ●●●●● patch | view | raw | blame | history
src/art.c
@@ -53,7 +53,7 @@
        printf("[");
    int upper = 30;
        for(i = 0; i < upper; ++i){
            printf("%s", ((i+.5) < score*upper) ? "\u2588" : " ");
            printf("%c", ((i+.5) < score*upper) ? 219 : ' ');
        }
        printf("]\n");
src/classifier.c
@@ -51,7 +51,7 @@
    }
    if(clear) *net.seen = 0;
    printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
    int imgs = net.batch;
    int imgs = net.batch*net.subdivisions;
    list *options = read_data_cfg(datacfg);
@@ -338,10 +338,10 @@
{
    int i, j;
    network net = parse_network_cfg(filename);
    set_batch_network(&net, 1);
    if(weightfile){
        load_weights(&net, weightfile);
    }
    set_batch_network(&net, 1);
    srand(time(0));
    list *options = read_data_cfg(datacfg);
src/convolutional_kernels.cu
@@ -72,10 +72,6 @@
void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
{
    int i;
    int m = l.n;
    int k = l.size*l.size*l.c;
    int n = convolutional_out_height(l)*
        convolutional_out_width(l);
    fill_ongpu(l.outputs*l.batch, 0, l.output_gpu, 1);
    if(l.binary){
@@ -109,6 +105,9 @@
                l.output_gpu);
#else
    int m = l.n;
    int k = l.size*l.size*l.c;
    int n = l.out_w*l.out_h;
    for(i = 0; i < l.batch; ++i){
        im2col_ongpu(state.input + i*l.c*l.h*l.w, l.c,  l.h,  l.w,  l.size,  l.stride, l.pad, state.workspace);
        float * a = l.filters_gpu;
@@ -121,23 +120,18 @@
    if (l.batch_normalize) {
        forward_batchnorm_layer_gpu(l, state);
    }
    add_bias_gpu(l.output_gpu, l.biases_gpu, l.batch, l.n, n);
    add_bias_gpu(l.output_gpu, l.biases_gpu, l.batch, l.n, l.out_w*l.out_h);
    activate_array_ongpu(l.output_gpu, m*n*l.batch, l.activation);
    activate_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation);
    //if(l.dot > 0) dot_error_gpu(l);
    if(l.binary || l.xnor) swap_binary(&l);
}
void backward_convolutional_layer_gpu(convolutional_layer l, network_state state)
{
    int m = l.n;
    int n = l.size*l.size*l.c;
    int k = convolutional_out_height(l)*
        convolutional_out_width(l);
    gradient_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation, l.delta_gpu);
    gradient_array_ongpu(l.output_gpu, m*k*l.batch, l.activation, l.delta_gpu);
    backward_bias_gpu(l.bias_updates_gpu, l.delta_gpu, l.batch, l.n, k);
    backward_bias_gpu(l.bias_updates_gpu, l.delta_gpu, l.batch, l.n, l.out_w*l.out_h);
    if(l.batch_normalize){
        backward_batchnorm_layer_gpu(l, state);
@@ -181,6 +175,10 @@
    }
#else
    int m = l.n;
    int n = l.size*l.size*l.c;
    int k = l.out_w*l.out_h;
    int i;
    for(i = 0; i < l.batch; ++i){
        float * a = l.delta_gpu;
src/convolutional_layer.c
@@ -14,6 +14,7 @@
#ifndef AI2
#define AI2 0
void forward_xnor_layer(layer l, network_state state);
#endif
void swap_binary(convolutional_layer *l)
@@ -127,6 +128,47 @@
#endif
}
#ifdef GPU
#ifdef CUDNN
void cudnn_convolutional_setup(layer *l)
{
    cudnnSetTensor4dDescriptor(l->dsrcTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->c, l->h, l->w);
    cudnnSetTensor4dDescriptor(l->ddstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->out_c, l->out_h, l->out_w);
    cudnnSetFilter4dDescriptor(l->dfilterDesc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size);
    cudnnSetTensor4dDescriptor(l->srcTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->c, l->h, l->w);
    cudnnSetTensor4dDescriptor(l->dstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->out_c, l->out_h, l->out_w);
    cudnnSetFilter4dDescriptor(l->filterDesc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size);
    int padding = l->pad ? l->size/2 : 0;
    cudnnSetConvolution2dDescriptor(l->convDesc, padding, padding, l->stride, l->stride, 1, 1, CUDNN_CROSS_CORRELATION);
    cudnnGetConvolutionForwardAlgorithm(cudnn_handle(),
            l->srcTensorDesc,
            l->filterDesc,
            l->convDesc,
            l->dstTensorDesc,
            CUDNN_CONVOLUTION_FWD_PREFER_FASTEST,
            0,
            &l->fw_algo);
    cudnnGetConvolutionBackwardDataAlgorithm(cudnn_handle(),
            l->filterDesc,
            l->ddstTensorDesc,
            l->convDesc,
            l->dsrcTensorDesc,
            CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST,
            0,
            &l->bd_algo);
    cudnnGetConvolutionBackwardFilterAlgorithm(cudnn_handle(),
            l->srcTensorDesc,
            l->ddstTensorDesc,
            l->convDesc,
            l->dfilterDesc,
            CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST,
            0,
            &l->bf_algo);
}
#endif
#endif
convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, int pad, ACTIVATION activation, int batch_normalize, int binary, int xnor)
{
    int i;
@@ -231,39 +273,7 @@
    cudnnCreateTensorDescriptor(&l.ddstTensorDesc);
    cudnnCreateFilterDescriptor(&l.dfilterDesc);
    cudnnCreateConvolutionDescriptor(&l.convDesc);
    cudnnSetTensor4dDescriptor(l.dsrcTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l.batch, l.c, l.h, l.w);
    cudnnSetTensor4dDescriptor(l.ddstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l.batch, l.out_c, l.out_h, l.out_w);
    cudnnSetFilter4dDescriptor(l.dfilterDesc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, l.n, l.c, l.size, l.size);
    cudnnSetTensor4dDescriptor(l.srcTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l.batch, l.c, l.h, l.w);
    cudnnSetTensor4dDescriptor(l.dstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l.batch, l.out_c, l.out_h, l.out_w);
    cudnnSetFilter4dDescriptor(l.filterDesc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, l.n, l.c, l.size, l.size);
    int padding = l.pad ? l.size/2 : 0;
    cudnnSetConvolution2dDescriptor(l.convDesc, padding, padding, l.stride, l.stride, 1, 1, CUDNN_CROSS_CORRELATION);
    cudnnGetConvolutionForwardAlgorithm(cudnn_handle(),
            l.srcTensorDesc,
            l.filterDesc,
            l.convDesc,
            l.dstTensorDesc,
            CUDNN_CONVOLUTION_FWD_PREFER_FASTEST,
            0,
            &l.fw_algo);
    cudnnGetConvolutionBackwardDataAlgorithm(cudnn_handle(),
            l.filterDesc,
            l.ddstTensorDesc,
            l.convDesc,
            l.dsrcTensorDesc,
            CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST,
            0,
            &l.bd_algo);
    cudnnGetConvolutionBackwardFilterAlgorithm(cudnn_handle(),
            l.srcTensorDesc,
            l.ddstTensorDesc,
            l.convDesc,
            l.dfilterDesc,
            CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST,
            0,
            &l.bf_algo);
    cudnn_convolutional_setup(&l);
#endif
#endif
    l.workspace_size = get_workspace_size(l);
@@ -335,39 +345,7 @@
    l->delta_gpu =     cuda_make_array(l->delta, l->batch*out_h*out_w*l->n);
    l->output_gpu =    cuda_make_array(l->output, l->batch*out_h*out_w*l->n);
#ifdef CUDNN
    cudnnSetTensor4dDescriptor(l->dsrcTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->c, l->h, l->w);
    cudnnSetTensor4dDescriptor(l->ddstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->out_c, l->out_h, l->out_w);
    cudnnSetFilter4dDescriptor(l->dfilterDesc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size);
    cudnnSetTensor4dDescriptor(l->srcTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->c, l->h, l->w);
    cudnnSetTensor4dDescriptor(l->dstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->out_c, l->out_h, l->out_w);
    cudnnSetFilter4dDescriptor(l->filterDesc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size);
    int padding = l->pad ? l->size/2 : 0;
    cudnnSetConvolution2dDescriptor(l->convDesc, padding, padding, l->stride, l->stride, 1, 1, CUDNN_CROSS_CORRELATION);
    cudnnGetConvolutionForwardAlgorithm(cudnn_handle(),
            l->srcTensorDesc,
            l->filterDesc,
            l->convDesc,
            l->dstTensorDesc,
            CUDNN_CONVOLUTION_FWD_PREFER_FASTEST,
            0,
            &l->fw_algo);
    cudnnGetConvolutionBackwardDataAlgorithm(cudnn_handle(),
            l->filterDesc,
            l->ddstTensorDesc,
            l->convDesc,
            l->dsrcTensorDesc,
            CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST,
            0,
            &l->bd_algo);
    cudnnGetConvolutionBackwardFilterAlgorithm(cudnn_handle(),
            l->srcTensorDesc,
            l->ddstTensorDesc,
            l->convDesc,
            l->dfilterDesc,
            CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST,
            0,
            &l->bf_algo);
    cudnn_convolutional_setup(l);
#endif
#endif
    l->workspace_size = get_workspace_size(*l);
src/convolutional_layer.h
@@ -19,6 +19,9 @@
void add_bias_gpu(float *output, float *biases, int batch, int n, int size);
void backward_bias_gpu(float *bias_updates, float *delta, int batch, int n, int size);
#ifdef CUDNN
void cudnn_convolutional_setup(layer *l);
#endif
#endif
convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, int pad, ACTIVATION activation, int batch_normalization, int binary, int xnor);
src/detection_layer.c
@@ -133,6 +133,9 @@
                        best_index = 0;
                    }
                }
                if(1 && *(state.net.seen) < 100000){
                    best_index = rand()%l.n;
                }
                int box_index = index + locations*(l.classes + l.n) + (i*l.n + best_index) * l.coords;
                int tbox_index = truth_index + 1 + l.classes;
@@ -181,7 +184,6 @@
            for (b = 0; b < l.batch; ++b) {
                int index = b*l.inputs;
                for (i = 0; i < locations; ++i) {
                    int truth_index = (b*locations + i)*(1+l.coords+l.classes);
                    for (j = 0; j < l.n; ++j) {
                        int p_index = index + locations*l.classes + i*l.n + j;
                        costs[b*locations*l.n + i*l.n + j] = l.delta[p_index]*l.delta[p_index];
@@ -194,7 +196,6 @@
            for (b = 0; b < l.batch; ++b) {
                int index = b*l.inputs;
                for (i = 0; i < locations; ++i) {
                    int truth_index = (b*locations + i)*(1+l.coords+l.classes);
                    for (j = 0; j < l.n; ++j) {
                        int p_index = index + locations*l.classes + i*l.n + j;
                        if (l.delta[p_index]*l.delta[p_index] < cutoff) l.delta[p_index] = 0;
@@ -233,7 +234,7 @@
        cuda_pull_array(state.truth, truth_cpu, num_truth);
    }
    cuda_pull_array(state.input, in_cpu, l.batch*l.inputs);
    network_state cpu_state;
    network_state cpu_state = state;
    cpu_state.train = state.train;
    cpu_state.truth = truth_cpu;
    cpu_state.input = in_cpu;
src/go.c
@@ -217,7 +217,7 @@
    }
    fprintf(stream, "\n");
    for(j = 0; j < 19; ++j){
        fprintf(stream, "%2d ", (inverted) ? 19-j : j+1);
        fprintf(stream, "%2d", (inverted) ? 19-j : j+1);
        for(i = 0; i < 19; ++i){
            int index = j*19 + i;
            if(indexes){
@@ -225,17 +225,26 @@
                for(n = 0; n < nind; ++n){
                    if(index == indexes[n]){
                        found = 1;
                        /*
                        if(n == 0) fprintf(stream, "\uff11");
                        else if(n == 1) fprintf(stream, "\uff12");
                        else if(n == 2) fprintf(stream, "\uff13");
                        else if(n == 3) fprintf(stream, "\uff14");
                        else if(n == 4) fprintf(stream, "\uff15");
                        */
                        if(n == 0) fprintf(stream, " 1");
                        else if(n == 1) fprintf(stream, " 2");
                        else if(n == 2) fprintf(stream, " 3");
                        else if(n == 3) fprintf(stream, " 4");
                        else if(n == 4) fprintf(stream, " 5");
                    }
                }
                if(found) continue;
            }
            if(board[index]*-swap > 0) fprintf(stream, "\u25C9 ");
            else if(board[index]*-swap < 0) fprintf(stream, "\u25EF ");
            //if(board[index]*-swap > 0) fprintf(stream, "\u25C9 ");
            //else if(board[index]*-swap < 0) fprintf(stream, "\u25EF ");
            if(board[index]*-swap > 0) fprintf(stream, " O");
            else if(board[index]*-swap < 0) fprintf(stream, " X");
            else fprintf(stream, "  ");
        }
        fprintf(stream, "\n");
@@ -640,8 +649,10 @@
            col = index % 19;
            printf("%d: %c %d, %.2f%%\n", i+1, col + 'A' + 1*(col > 7 && noi), (inverted)?19 - row : row+1, move[index]*100);
        }
        if(color == 1) printf("\u25EF Enter move: ");
        else printf("\u25C9 Enter move: ");
        //if(color == 1) printf("\u25EF Enter move: ");
        //else printf("\u25C9 Enter move: ");
        if(color == 1) printf("X Enter move: ");
        else printf("O Enter move: ");
        char c;
        char *line = fgetl(stdin);
src/network.c
@@ -392,6 +392,11 @@
    int i;
    for(i = 0; i < net->n; ++i){
        net->layers[i].batch = b;
        #ifdef CUDNN
        if(net->layers[i].type == CONVOLUTIONAL){
            cudnn_convolutional_setup(net->layers + i);
        }
        #endif
    }
}
src/rnn.c
@@ -280,7 +280,7 @@
    printf("\n");
}
void test_tactic_rnn(char *cfgfile, char *weightfile, int num, char *seed, float temp, int rseed, char *token_file)
void test_tactic_rnn(char *cfgfile, char *weightfile, int num, float temp, int rseed, char *token_file)
{
    char **tokens = 0;
    if(token_file){
@@ -301,9 +301,8 @@
    int i, j;
    for(i = 0; i < net.n; ++i) net.layers[i].temperature = temp;
    int c = 0;
    int len = strlen(seed);
    float *input = calloc(inputs, sizeof(float));
    float *out;
    float *out = 0;
    while((c = getc(stdin)) != EOF){
        input[c] = 1;
@@ -490,5 +489,5 @@
    else if(0==strcmp(argv[2], "validtactic")) valid_tactic_rnn(cfg, weights, seed);
    else if(0==strcmp(argv[2], "vec")) vec_char_rnn(cfg, weights, seed);
    else if(0==strcmp(argv[2], "generate")) test_char_rnn(cfg, weights, len, seed, temp, rseed, tokens);
    else if(0==strcmp(argv[2], "generatetactic")) test_tactic_rnn(cfg, weights, len, seed, temp, rseed, tokens);
    else if(0==strcmp(argv[2], "generatetactic")) test_tactic_rnn(cfg, weights, len, temp, rseed, tokens);
}