Joseph Redmon
2016-05-13 13209df7bb53de19aa3f82e870db11eb5b7587f1
art, cudnn
10 files modified
1 files added
316 ■■■■■ changed files
Makefile 9 ●●●● patch | view | raw | blame | history
src/art.c 76 ●●●●● patch | view | raw | blame | history
src/convolutional_kernels.cu 67 ●●●● patch | view | raw | blame | history
src/convolutional_layer.c 92 ●●●●● patch | view | raw | blame | history
src/cuda.c 25 ●●●● patch | view | raw | blame | history
src/cuda.h 20 ●●●● patch | view | raw | blame | history
src/darknet.c 3 ●●●●● patch | view | raw | blame | history
src/layer.h 13 ●●●●● patch | view | raw | blame | history
src/network.h 2 ●●●●● patch | view | raw | blame | history
src/network_kernels.cu 2 ●●●●● patch | view | raw | blame | history
src/parser.c 7 ●●●●● patch | view | raw | blame | history
Makefile
@@ -1,4 +1,5 @@
GPU=0
CUDNN=0
OPENCV=0
DEBUG=0
@@ -34,7 +35,13 @@
LDFLAGS+= -L/usr/local/cuda/lib64 -lcuda -lcudart -lcublas -lcurand
endif
OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o darknet.o detection_layer.o imagenet.o captcha.o route_layer.o writing.o box.o nightmare.o normalization_layer.o avgpool_layer.o coco.o dice.o yolo.o layer.o compare.o classifier.o local_layer.o swag.o shortcut_layer.o activation_layer.o rnn_layer.o gru_layer.o rnn.o rnn_vid.o crnn_layer.o coco_demo.o tag.o cifar.o yolo_demo.o go.o batchnorm_layer.o
ifeq ($(CUDNN), 1)
COMMON+= -DCUDNN
CFLAGS+= -DCUDNN
LDFLAGS+= -lcudnn
endif
OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o darknet.o detection_layer.o imagenet.o captcha.o route_layer.o writing.o box.o nightmare.o normalization_layer.o avgpool_layer.o coco.o dice.o yolo.o layer.o compare.o classifier.o local_layer.o swag.o shortcut_layer.o activation_layer.o rnn_layer.o gru_layer.o rnn.o rnn_vid.o crnn_layer.o coco_demo.o tag.o cifar.o yolo_demo.o go.o batchnorm_layer.o art.o
ifeq ($(GPU), 1) 
LDFLAGS+= -lstdc++ 
OBJ+=convolutional_kernels.o deconvolutional_kernels.o activation_kernels.o im2col_kernels.o col2im_kernels.o blas_kernels.o crop_layer_kernels.o dropout_layer_kernels.o maxpool_layer_kernels.o softmax_layer_kernels.o network_kernels.o avgpool_layer_kernels.o
src/art.c
New file
@@ -0,0 +1,76 @@
#include "network.h"
#include "utils.h"
#include "parser.h"
#include "option_list.h"
#include "blas.h"
#include "classifier.h"
#include <sys/time.h>
#ifdef OPENCV
#include "opencv2/highgui/highgui_c.h"
#endif
void demo_art(char *cfgfile, char *weightfile, int cam_index)
{
#ifdef OPENCV
    network net = parse_network_cfg(cfgfile);
    if(weightfile){
        load_weights(&net, weightfile);
    }
    set_batch_network(&net, 1);
    srand(2222222);
    CvCapture * cap;
    cap = cvCaptureFromCAM(cam_index);
    char *window = "ArtJudgementBot9000!!!";
    if(!cap) error("Couldn't connect to webcam.\n");
    cvNamedWindow(window, CV_WINDOW_NORMAL);
    cvResizeWindow(window, 512, 512);
    int i;
    int idx[] = {37, 401, 434};
    int n = sizeof(idx)/sizeof(idx[0]);
    while(1){
        image in = get_image_from_stream(cap);
        image in_s = resize_image(in, net.w, net.h);
        show_image(in, window);
        float *p = network_predict(net, in_s.data);
        printf("\033[2J");
        printf("\033[1;1H");
        float score = 0;
        for(i = 0; i < n; ++i){
            float s = p[idx[i]];
            if (s > score) score = s;
        }
        score = score;
        printf("I APPRECIATE THIS ARTWORK: %10.7f%%\n", score*100);
        printf("[");
    int upper = 30;
        for(i = 0; i < upper; ++i){
            printf("%s", ((i+.5) < score*upper) ? "\u2588" : " ");
        }
        printf("]\n");
        free_image(in_s);
        free_image(in);
        cvWaitKey(1);
    }
#endif
}
void run_art(int argc, char **argv)
{
    int cam_index = find_int_arg(argc, argv, "-c", 0);
    char *cfg = argv[2];
    char *weights = argv[3];
    demo_art(cfg, weights, cam_index);
}
src/convolutional_kernels.cu
@@ -85,7 +85,6 @@
    if(l.xnor){
        binarize_filters_gpu(l.filters_gpu, l.n, l.c*l.size*l.size, l.binary_filters_gpu);
        //binarize_gpu(l.filters_gpu, l.n*l.c*l.size*l.size, l.binary_filters_gpu);
        swap_binary(&l);
        for(i = 0; i < l.batch; ++i){
            binarize_input_gpu(state.input + i*l.inputs, l.c, l.h*l.w, l.binary_input_gpu + i*l.inputs);
@@ -93,13 +92,31 @@
        state.input = l.binary_input_gpu;
    }
#ifdef CUDNN
    float one = 1;
    cudnnConvolutionForward(cudnn_handle(),
                &one,
                l.srcTensorDesc,
                state.input,
                l.filterDesc,
                l.filters_gpu,
                l.convDesc,
                l.fw_algo,
                state.workspace,
                l.workspace_size,
                &one,
                l.dstTensorDesc,
                l.output_gpu);
#else
    for(i = 0; i < l.batch; ++i){
        im2col_ongpu(state.input + i*l.c*l.h*l.w, l.c,  l.h,  l.w,  l.size,  l.stride, l.pad, l.col_image_gpu);
        im2col_ongpu(state.input + i*l.c*l.h*l.w, l.c,  l.h,  l.w,  l.size,  l.stride, l.pad, state.workspace);
        float * a = l.filters_gpu;
        float * b = l.col_image_gpu;
        float * b = state.workspace;
        float * c = l.output_gpu;
        gemm_ongpu(0,0,m,n,k,1.,a,k,b,n,1.,c+i*m*n,n);
    }
#endif
    if (l.batch_normalize) {
        forward_batchnorm_layer_gpu(l, state);
@@ -113,7 +130,6 @@
void backward_convolutional_layer_gpu(convolutional_layer l, network_state state)
{
    int i;
    int m = l.n;
    int n = l.size*l.size*l.c;
    int k = convolutional_out_height(l)*
@@ -128,26 +144,61 @@
    }
    if(l.xnor) state.input = l.binary_input_gpu;
#ifdef CUDNN
    float one = 1;
    cudnnConvolutionBackwardFilter(cudnn_handle(),
            &one,
            l.srcTensorDesc,
            state.input,
            l.ddstTensorDesc,
            l.delta_gpu,
            l.convDesc,
            l.bf_algo,
            state.workspace,
            l.workspace_size,
            &one,
            l.dfilterDesc,
            l.filter_updates_gpu);
    if(state.delta){
        cudnnConvolutionBackwardData(cudnn_handle(),
                &one,
                l.filterDesc,
                l.filters_gpu,
                l.ddstTensorDesc,
                l.delta_gpu,
                l.convDesc,
                l.bd_algo,
                state.workspace,
                l.workspace_size,
                &one,
                l.dsrcTensorDesc,
                state.delta);
    }
#else
    int i;
    for(i = 0; i < l.batch; ++i){
        float * a = l.delta_gpu;
        float * b = l.col_image_gpu;
        float * b = state.workspace;
        float * c = l.filter_updates_gpu;
        im2col_ongpu(state.input + i*l.c*l.h*l.w, l.c,  l.h,  l.w,  l.size,  l.stride, l.pad, l.col_image_gpu);
        im2col_ongpu(state.input + i*l.c*l.h*l.w, l.c,  l.h,  l.w,  l.size,  l.stride, l.pad, state.workspace);
        gemm_ongpu(0,1,m,n,k,1,a + i*m*k,k,b,k,1,c,n);
        if(state.delta){
            if(l.binary || l.xnor) swap_binary(&l);
            float * a = l.filters_gpu;
            float * b = l.delta_gpu;
            float * c = l.col_image_gpu;
            float * c = state.workspace;
            gemm_ongpu(1,0,n,k,m,1,a,n,b + i*k*m,k,0,c,k);
            col2im_ongpu(l.col_image_gpu, l.c,  l.h,  l.w,  l.size,  l.stride, l.pad, state.delta + i*l.c*l.h*l.w);
            col2im_ongpu(state.workspace, l.c,  l.h,  l.w,  l.size,  l.stride, l.pad, state.delta + i*l.c*l.h*l.w);
            if(l.binary || l.xnor) swap_binary(&l);
        }
    }
#endif
}
void pull_convolutional_layer(convolutional_layer layer)
src/convolutional_layer.c
@@ -88,6 +88,38 @@
    return float_to_image(w,h,c,l.delta);
}
#ifdef CUDNN
size_t get_workspace_size(layer l){
    size_t most = 0;
    size_t s = 0;
    cudnnGetConvolutionForwardWorkspaceSize(cudnn_handle(),
            l.srcTensorDesc,
            l.filterDesc,
            l.convDesc,
            l.dstTensorDesc,
            l.fw_algo,
            &s);
    if (s > most) most = s;
    cudnnGetConvolutionBackwardFilterWorkspaceSize(cudnn_handle(),
            l.srcTensorDesc,
            l.ddstTensorDesc,
            l.convDesc,
            l.dfilterDesc,
            l.bf_algo,
            &s);
    if (s > most) most = s;
    cudnnGetConvolutionBackwardDataWorkspaceSize(cudnn_handle(),
            l.filterDesc,
            l.ddstTensorDesc,
            l.convDesc,
            l.dsrcTensorDesc,
            l.bd_algo,
            &s);
    if (s > most) most = s;
    return most;
}
#endif
convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, int pad, ACTIVATION activation, int batch_normalize, int binary, int xnor)
{
    int i;
@@ -156,7 +188,7 @@
    l.scales_gpu = cuda_make_array(l.scales, n);
    l.scale_updates_gpu = cuda_make_array(l.scale_updates, n);
    l.col_image_gpu = cuda_make_array(l.col_image, out_h*out_w*size*size*c);
    l.workspace_size = out_h*out_w*size*size*c;
    l.delta_gpu = cuda_make_array(l.delta, l.batch*out_h*out_w*n);
    l.output_gpu = cuda_make_array(l.output, l.batch*out_h*out_w*n);
@@ -182,6 +214,50 @@
        l.x_gpu = cuda_make_array(l.output, l.batch*out_h*out_w*n);
        l.x_norm_gpu = cuda_make_array(l.output, l.batch*out_h*out_w*n);
    }
#ifdef CUDNN
    cudnnCreateTensorDescriptor(&l.srcTensorDesc);
    cudnnCreateTensorDescriptor(&l.dstTensorDesc);
    cudnnCreateFilterDescriptor(&l.filterDesc);
    cudnnCreateTensorDescriptor(&l.dsrcTensorDesc);
    cudnnCreateTensorDescriptor(&l.ddstTensorDesc);
    cudnnCreateFilterDescriptor(&l.dfilterDesc);
    cudnnCreateConvolutionDescriptor(&l.convDesc);
    cudnnSetTensor4dDescriptor(l.dsrcTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l.batch, l.c, l.h, l.w);
    cudnnSetTensor4dDescriptor(l.ddstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l.batch, l.out_c, l.out_h, l.out_w);
    cudnnSetFilter4dDescriptor(l.dfilterDesc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, l.n, l.c, l.size, l.size);
    cudnnSetTensor4dDescriptor(l.srcTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l.batch, l.c, l.h, l.w);
    cudnnSetTensor4dDescriptor(l.dstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l.batch, l.out_c, l.out_h, l.out_w);
    cudnnSetFilter4dDescriptor(l.filterDesc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, l.n, l.c, l.size, l.size);
    int padding = l.pad ? l.size/2 : 0;
    cudnnSetConvolution2dDescriptor(l.convDesc, padding, padding, l.stride, l.stride, 1, 1, CUDNN_CROSS_CORRELATION);
    cudnnGetConvolutionForwardAlgorithm(cudnn_handle(),
            l.srcTensorDesc,
            l.filterDesc,
            l.convDesc,
            l.dstTensorDesc,
            CUDNN_CONVOLUTION_FWD_PREFER_FASTEST,
            0,
            &l.fw_algo);
    cudnnGetConvolutionBackwardDataAlgorithm(cudnn_handle(),
            l.filterDesc,
            l.ddstTensorDesc,
            l.convDesc,
            l.dsrcTensorDesc,
            CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST,
            0,
            &l.bd_algo);
    cudnnGetConvolutionBackwardFilterAlgorithm(cudnn_handle(),
            l.srcTensorDesc,
            l.ddstTensorDesc,
            l.convDesc,
            l.dfilterDesc,
            CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST,
            0,
            &l.bf_algo);
    l.workspace_size = get_workspace_size(l);
#endif
#endif
    l.activation = activation;
@@ -247,11 +323,9 @@
            l->batch*out_h * out_w * l->n*sizeof(float));
#ifdef GPU
    cuda_free(l->col_image_gpu);
    cuda_free(l->delta_gpu);
    cuda_free(l->output_gpu);
    l->col_image_gpu = cuda_make_array(l->col_image, out_h*out_w*l->size*l->size*l->c);
    l->delta_gpu =     cuda_make_array(l->delta, l->batch*out_h*out_w*l->n);
    l->output_gpu =    cuda_make_array(l->output, l->batch*out_h*out_w*l->n);
#endif
@@ -299,12 +373,12 @@
    fill_cpu(l.outputs*l.batch, 0, l.output, 1);
    /*
    if(l.binary){
        binarize_filters(l.filters, l.n, l.c*l.size*l.size, l.binary_filters);
        binarize_filters2(l.filters, l.n, l.c*l.size*l.size, l.cfilters, l.scales);
        swap_binary(&l);
    }
    */
       if(l.binary){
       binarize_filters(l.filters, l.n, l.c*l.size*l.size, l.binary_filters);
       binarize_filters2(l.filters, l.n, l.c*l.size*l.size, l.cfilters, l.scales);
       swap_binary(&l);
       }
     */
    if(l.binary){
        int m = l.n;
src/cuda.c
@@ -46,6 +46,19 @@
    return d;
}
#ifdef CUDNN
cudnnHandle_t cudnn_handle()
{
    static int init = 0;
    static cudnnHandle_t handle;
    if(!init) {
        cudnnCreate(&handle);
        init = 1;
    }
    return handle;
}
#endif
cublasHandle_t blas_handle()
{
    static int init = 0;
@@ -57,7 +70,7 @@
    return handle;
}
float *cuda_make_array(float *x, int n)
float *cuda_make_array(float *x, size_t n)
{
    float *x_gpu;
    size_t size = sizeof(float)*n;
@@ -71,7 +84,7 @@
    return x_gpu;
}
void cuda_random(float *x_gpu, int n)
void cuda_random(float *x_gpu, size_t n)
{
    static curandGenerator_t gen;
    static int init = 0;
@@ -84,7 +97,7 @@
    check_error(cudaPeekAtLastError());
}
float cuda_compare(float *x_gpu, float *x, int n, char *s)
float cuda_compare(float *x_gpu, float *x, size_t n, char *s)
{
    float *tmp = calloc(n, sizeof(float));
    cuda_pull_array(x_gpu, tmp, n);
@@ -97,7 +110,7 @@
    return err;
}
int *cuda_make_int_array(int n)
int *cuda_make_int_array(size_t n)
{
    int *x_gpu;
    size_t size = sizeof(int)*n;
@@ -112,14 +125,14 @@
    check_error(status);
}
void cuda_push_array(float *x_gpu, float *x, int n)
void cuda_push_array(float *x_gpu, float *x, size_t n)
{
    size_t size = sizeof(float)*n;
    cudaError_t status = cudaMemcpy(x_gpu, x, size, cudaMemcpyHostToDevice);
    check_error(status);
}
void cuda_pull_array(float *x_gpu, float *x, int n)
void cuda_pull_array(float *x_gpu, float *x, size_t n)
{
    size_t size = sizeof(float)*n;
    cudaError_t status = cudaMemcpy(x, x_gpu, size, cudaMemcpyDeviceToHost);
src/cuda.h
@@ -11,16 +11,24 @@
#include "curand.h"
#include "cublas_v2.h"
#ifdef CUDNN
#include "cudnn.h"
#endif
void check_error(cudaError_t status);
cublasHandle_t blas_handle();
float *cuda_make_array(float *x, int n);
int *cuda_make_int_array(int n);
void cuda_push_array(float *x_gpu, float *x, int n);
void cuda_pull_array(float *x_gpu, float *x, int n);
float *cuda_make_array(float *x, size_t n);
int *cuda_make_int_array(size_t n);
void cuda_push_array(float *x_gpu, float *x, size_t n);
void cuda_pull_array(float *x_gpu, float *x, size_t n);
void cuda_free(float *x_gpu);
void cuda_random(float *x_gpu, int n);
float cuda_compare(float *x_gpu, float *x, int n, char *s);
void cuda_random(float *x_gpu, size_t n);
float cuda_compare(float *x_gpu, float *x, size_t n, char *s);
dim3 cuda_gridsize(size_t n);
#ifdef CUDNN
cudnnHandle_t cudnn_handle();
#endif
#endif
#endif
src/darknet.c
@@ -26,6 +26,7 @@
extern void run_tag(int argc, char **argv);
extern void run_cifar(int argc, char **argv);
extern void run_go(int argc, char **argv);
extern void run_art(int argc, char **argv);
void change_rate(char *filename, float scale, float add)
{
@@ -259,6 +260,8 @@
        run_coco(argc, argv);
    } else if (0 == strcmp(argv[1], "classifier")){
        run_classifier(argc, argv);
    } else if (0 == strcmp(argv[1], "art")){
        run_art(argc, argv);
    } else if (0 == strcmp(argv[1], "tag")){
        run_tag(argc, argv);
    } else if (0 == strcmp(argv[1], "compare")){
src/layer.h
@@ -2,6 +2,7 @@
#define BASE_LAYER_H
#include "activations.h"
#include "stddef.h"
struct layer;
typedef struct layer layer;
@@ -157,6 +158,8 @@
    struct layer *input_h_layer;
    struct layer *state_h_layer;
    size_t workspace_size;
    #ifdef GPU
    float *z_gpu;
    float *r_gpu;
@@ -207,6 +210,16 @@
    float * rand_gpu;
    float * squared_gpu;
    float * norms_gpu;
    #ifdef CUDNN
    cudnnTensorDescriptor_t srcTensorDesc, dstTensorDesc;
    cudnnTensorDescriptor_t dsrcTensorDesc, ddstTensorDesc;
    cudnnFilterDescriptor_t filterDesc;
    cudnnFilterDescriptor_t dfilterDesc;
    cudnnConvolutionDescriptor_t convDesc;
    cudnnConvolutionFwdAlgo_t fw_algo;
    cudnnConvolutionBwdDataAlgo_t bd_algo;
    cudnnConvolutionBwdFilterAlgo_t bf_algo;
    #endif
    #endif
};
src/network.h
@@ -11,6 +11,7 @@
} learning_rate_policy;
typedef struct network{
    float *workspace;
    int n;
    int batch;
    int *seen;
@@ -49,6 +50,7 @@
    float *truth;
    float *input;
    float *delta;
    float *workspace;
    int train;
    int index;
    network net;
src/network_kernels.cu
@@ -41,6 +41,7 @@
void forward_network_gpu(network net, network_state state)
{
    state.workspace = net.workspace;
    int i;
    for(i = 0; i < net.n; ++i){
        state.index = i;
@@ -93,6 +94,7 @@
void backward_network_gpu(network net, network_state state)
{
    state.workspace = net.workspace;
    int i;
    float * original_input = state.input;
    float * original_delta = state.delta;
src/parser.c
@@ -524,6 +524,7 @@
    params.batch = net.batch;
    params.time_steps = net.time_steps;
    size_t workspace_size = 0;
    n = n->next;
    int count = 0;
    free_section(s);
@@ -584,6 +585,7 @@
        l.dontloadscales = option_find_int_quiet(options, "dontloadscales", 0);
        option_unused(options);
        net.layers[count] = l;
        if (l.workspace_size > workspace_size) workspace_size = l.workspace_size;
        free_section(s);
        n = n->next;
        ++count;
@@ -597,6 +599,11 @@
    free_list(sections);
    net.outputs = get_network_output_size(net);
    net.output = get_network_output(net);
    if(workspace_size){
#ifdef GPU
        net.workspace = cuda_make_array(0, (workspace_size-1)/sizeof(float)+1);
#endif
    }
    return net;
}