Joseph Redmon
2016-10-26 352ae7e65b6a74bcd768aa88b866a44c713284c8
ADAM
11 files modified
126 ■■■■ changed files
src/blas.h 1 ●●●● patch | view | raw | blame | history
src/blas_kernels.cu 15 ●●●●● patch | view | raw | blame | history
src/classifier.c 45 ●●●● patch | view | raw | blame | history
src/convolutional_kernels.cu 15 ●●●●● patch | view | raw | blame | history
src/convolutional_layer.c 10 ●●●● patch | view | raw | blame | history
src/convolutional_layer.h 2 ●●● patch | view | raw | blame | history
src/crnn_layer.c 6 ●●●● patch | view | raw | blame | history
src/layer.h 8 ●●●●● patch | view | raw | blame | history
src/network.h 5 ●●●●● patch | view | raw | blame | history
src/network_kernels.cu 3 ●●●● patch | view | raw | blame | history
src/parser.c 16 ●●●●● patch | view | raw | blame | history
src/blas.h
@@ -78,6 +78,7 @@
void reorg_ongpu(float *x, int w, int h, int c, int batch, int stride, int forward, float *out);
void softmax_gpu(float *input, int n, int offset, int groups, float temp, float *output);
void adam_gpu(int n, float *x, float *m, float *v, float B1, float B2, float rate, float eps, int t);
#endif
#endif
src/blas_kernels.cu
@@ -140,6 +140,21 @@
}
__global__ void adam_kernel(int N, float *x, float *m, float *v, float B1, float B2, float rate, float eps, int t)
{
    int index = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    if (index >= N) return;
    x[index] = x[index] - (rate * sqrt(1.-pow(B2, t)) / (1.-pow(B1, t)) * m[index] / (sqrt(v[index]) + eps));
    //if(index == 0) printf("%f %f %f %f\n", m[index], v[index], (rate * sqrt(1.-pow(B2, t)) / (1.-pow(B1, t)) * m[index] / (sqrt(v[index]) + eps)));
}
extern "C" void adam_gpu(int n, float *x, float *m, float *v, float B1, float B2, float rate, float eps, int t)
{
    adam_kernel<<<cuda_gridsize(n), BLOCK>>>(n, x, m, v, B1, B2, rate, eps, t);
    check_error(cudaPeekAtLastError());
}
__global__ void normalize_kernel(int N, float *x, float *mean, float *variance, int batch, int filters, int spatial)
{
    int index = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
src/classifier.c
@@ -41,7 +41,7 @@
    return options;
}
void hierarchy_predictions(float *predictions, int n, tree *hier)
void hierarchy_predictions(float *predictions, int n, tree *hier, int only_leaves)
{
    int j;
    for(j = 0; j < n; ++j){
@@ -50,10 +50,12 @@
            predictions[j] *= predictions[parent]; 
        }
    }
    if(only_leaves){
    for(j = 0; j < n; ++j){
        if(!hier->leaf[j]) predictions[j] = 0;
    }
}
}
float *get_regression_values(char **labels, int n)
{
@@ -410,7 +412,7 @@
        float *pred = calloc(classes, sizeof(float));
        for(j = 0; j < 10; ++j){
            float *p = network_predict(net, images[j].data);
            if(net.hierarchy) hierarchy_predictions(p, net.outputs, net.hierarchy);
            if(net.hierarchy) hierarchy_predictions(p, net.outputs, net.hierarchy, 1);
            axpy_cpu(classes, 1, p, 1, pred, 1);
            free_image(images[j]);
        }
@@ -471,7 +473,7 @@
        //show_image(crop, "cropped");
        //cvWaitKey(0);
        float *pred = network_predict(net, resized.data);
        if(net.hierarchy) hierarchy_predictions(pred, net.outputs, net.hierarchy);
        if(net.hierarchy) hierarchy_predictions(pred, net.outputs, net.hierarchy, 1);
        free_image(im);
        free_image(resized);
@@ -486,6 +488,26 @@
    }
}
void change_leaves(tree *t, char *leaf_list)
{
    list *llist = get_paths(leaf_list);
    char **leaves = (char **)list_to_array(llist);
    int n = llist->size;
    int i,j;
    int found = 0;
    for(i = 0; i < t->n; ++i){
        t->leaf[i] = 0;
        for(j = 0; j < n; ++j){
            if (0==strcmp(t->name[i], leaves[j])){
                t->leaf[i] = 1;
                ++found;
                break;
            }
        }
    }
    fprintf(stderr, "Found %d leaves.\n", found);
}
void validate_classifier_single(char *datacfg, char *filename, char *weightfile)
{
@@ -500,6 +522,8 @@
    list *options = read_data_cfg(datacfg);
    char *label_list = option_find_str(options, "labels", "data/labels.list");
    char *leaf_list = option_find_str(options, "leaves", 0);
    if(leaf_list) change_leaves(net.hierarchy, leaf_list);
    char *valid_list = option_find_str(options, "valid", "data/train.list");
    int classes = option_find_int(options, "classes", 2);
    int topk = option_find_int(options, "top", 1);
@@ -531,7 +555,7 @@
        //show_image(crop, "cropped");
        //cvWaitKey(0);
        float *pred = network_predict(net, crop.data);
        if(net.hierarchy) hierarchy_predictions(pred, net.outputs, net.hierarchy);
        if(net.hierarchy) hierarchy_predictions(pred, net.outputs, net.hierarchy, 1);
        if(resized.data != im.data) free_image(resized);
        free_image(im);
@@ -592,7 +616,7 @@
            image r = resize_min(im, scales[j]);
            resize_network(&net, r.w, r.h);
            float *p = network_predict(net, r.data);
            if(net.hierarchy) hierarchy_predictions(p, net.outputs, net.hierarchy);
            if(net.hierarchy) hierarchy_predictions(p, net.outputs, net.hierarchy, 1);
            axpy_cpu(classes, 1, p, 1, pred, 1);
            flip_image(r);
            p = network_predict(net, r.data);
@@ -692,7 +716,7 @@
    }
}
void predict_classifier(char *datacfg, char *cfgfile, char *weightfile, char *filename)
void predict_classifier(char *datacfg, char *cfgfile, char *weightfile, char *filename, int top)
{
    network net = parse_network_cfg(cfgfile);
    if(weightfile){
@@ -705,7 +729,7 @@
    char *name_list = option_find_str(options, "names", 0);
    if(!name_list) name_list = option_find_str(options, "labels", "data/labels.list");
    int top = option_find_int(options, "top", 1);
    if(top == 0) top = option_find_int(options, "top", 1);
    int i = 0;
    char **names = get_labels(name_list);
@@ -732,7 +756,7 @@
        float *X = r.data;
        time=clock();
        float *predictions = network_predict(net, X);
        if(net.hierarchy) hierarchy_predictions(predictions, net.outputs, net.hierarchy);
        if(net.hierarchy) hierarchy_predictions(predictions, net.outputs, net.hierarchy, 0);
        top_k(predictions, net.outputs, top, indexes);
        printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time));
        for(i = 0; i < top; ++i){
@@ -1113,7 +1137,7 @@
        show_image(in, "Classifier");
        float *predictions = network_predict(net, in_s.data);
        if(net.hierarchy) hierarchy_predictions(predictions, net.outputs, net.hierarchy);
        if(net.hierarchy) hierarchy_predictions(predictions, net.outputs, net.hierarchy, 1);
        top_predictions(net, top, indexes);
        printf("\033[2J");
@@ -1165,6 +1189,7 @@
    }
    int cam_index = find_int_arg(argc, argv, "-c", 0);
    int top = find_int_arg(argc, argv, "-t", 0);
    int clear = find_arg(argc, argv, "-clear");
    char *data = argv[3];
    char *cfg = argv[4];
@@ -1172,7 +1197,7 @@
    char *filename = (argc > 6) ? argv[6]: 0;
    char *layer_s = (argc > 7) ? argv[7]: 0;
    int layer = layer_s ? atoi(layer_s) : -1;
    if(0==strcmp(argv[2], "predict")) predict_classifier(data, cfg, weights, filename);
    if(0==strcmp(argv[2], "predict")) predict_classifier(data, cfg, weights, filename, top);
    else if(0==strcmp(argv[2], "try")) try_classifier(data, cfg, weights, filename, atoi(layer_s));
    else if(0==strcmp(argv[2], "train")) train_classifier(data, cfg, weights, clear);
    else if(0==strcmp(argv[2], "trainm")) train_classifier_multi(data, cfg, weights, gpus, ngpus, clear);
src/convolutional_kernels.cu
@@ -233,7 +233,6 @@
void update_convolutional_layer_gpu(convolutional_layer layer, int batch, float learning_rate, float momentum, float decay)
{
    int size = layer.size*layer.size*layer.c*layer.n;
    axpy_ongpu(layer.n, learning_rate/batch, layer.bias_updates_gpu, 1, layer.biases_gpu, 1);
    scal_ongpu(layer.n, momentum, layer.bias_updates_gpu, 1);
@@ -242,9 +241,23 @@
        scal_ongpu(layer.n, momentum, layer.scale_updates_gpu, 1);
    }
    if(layer.adam){
        scal_ongpu(size, layer.B1, layer.m_gpu, 1);
        scal_ongpu(size, layer.B2, layer.v_gpu, 1);
        axpy_ongpu(size, -decay*batch, layer.weights_gpu, 1, layer.weight_updates_gpu, 1);
        axpy_ongpu(size, -(1-layer.B1), layer.weight_updates_gpu, 1, layer.m_gpu, 1);
        mul_ongpu(size, layer.weight_updates_gpu, 1, layer.weight_updates_gpu, 1);
        axpy_ongpu(size, (1-layer.B2), layer.weight_updates_gpu, 1, layer.v_gpu, 1);
        adam_gpu(size, layer.weights_gpu, layer.m_gpu, layer.v_gpu, layer.B1, layer.B2, learning_rate/batch, layer.eps, layer.t+1);
        fill_ongpu(size, 0, layer.weight_updates_gpu, 1);
    }else{
    axpy_ongpu(size, -decay*batch, layer.weights_gpu, 1, layer.weight_updates_gpu, 1);
    axpy_ongpu(size, learning_rate/batch, layer.weight_updates_gpu, 1, layer.weights_gpu, 1);
    scal_ongpu(size, momentum, layer.weight_updates_gpu, 1);
}
}
src/convolutional_layer.c
@@ -171,7 +171,7 @@
#endif
#endif
convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, int padding, ACTIVATION activation, int batch_normalize, int binary, int xnor)
convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, int padding, ACTIVATION activation, int batch_normalize, int binary, int xnor, int adam)
{
    int i;
    convolutional_layer l = {0};
@@ -242,6 +242,12 @@
    l.update_gpu = update_convolutional_layer_gpu;
    if(gpu_index >= 0){
        if (adam) {
            l.adam = 1;
            l.m_gpu = cuda_make_array(l.weight_updates, c*n*size*size);
            l.v_gpu = cuda_make_array(l.weight_updates, c*n*size*size);
        }
        l.weights_gpu = cuda_make_array(l.weights, c*n*size*size);
        l.weight_updates_gpu = cuda_make_array(l.weight_updates, c*n*size*size);
@@ -312,7 +318,7 @@
void test_convolutional_layer()
{
    convolutional_layer l = make_convolutional_layer(1, 5, 5, 3, 2, 5, 2, 1, LEAKY, 1, 0, 0);
    convolutional_layer l = make_convolutional_layer(1, 5, 5, 3, 2, 5, 2, 1, LEAKY, 1, 0, 0, 0);
    l.batch_normalize = 1;
    float data[] = {1,1,1,1,1,
        1,1,1,1,1,
src/convolutional_layer.h
@@ -24,7 +24,7 @@
#endif
#endif
convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, int pad, ACTIVATION activation, int batch_normalization, int binary, int xnor);
convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, int padding, ACTIVATION activation, int batch_normalize, int binary, int xnor, int adam);
void denormalize_convolutional_layer(convolutional_layer l);
void resize_convolutional_layer(convolutional_layer *layer, int w, int h);
void forward_convolutional_layer(const convolutional_layer layer, network_state state);
src/crnn_layer.c
@@ -48,17 +48,17 @@
    l.input_layer = malloc(sizeof(layer));
    fprintf(stderr, "\t\t");
    *(l.input_layer) = make_convolutional_layer(batch*steps, h, w, c, hidden_filters, 3, 1, 1,  activation, batch_normalize, 0, 0);
    *(l.input_layer) = make_convolutional_layer(batch*steps, h, w, c, hidden_filters, 3, 1, 1,  activation, batch_normalize, 0, 0, 0);
    l.input_layer->batch = batch;
    l.self_layer = malloc(sizeof(layer));
    fprintf(stderr, "\t\t");
    *(l.self_layer) = make_convolutional_layer(batch*steps, h, w, hidden_filters, hidden_filters, 3, 1, 1,  activation, batch_normalize, 0, 0);
    *(l.self_layer) = make_convolutional_layer(batch*steps, h, w, hidden_filters, hidden_filters, 3, 1, 1,  activation, batch_normalize, 0, 0, 0);
    l.self_layer->batch = batch;
    l.output_layer = malloc(sizeof(layer));
    fprintf(stderr, "\t\t");
    *(l.output_layer) = make_convolutional_layer(batch*steps, h, w, hidden_filters, output_filters, 3, 1, 1,  activation, batch_normalize, 0, 0);
    *(l.output_layer) = make_convolutional_layer(batch*steps, h, w, hidden_filters, output_filters, 3, 1, 1,  activation, batch_normalize, 0, 0, 0);
    l.output_layer->batch = batch;
    l.output = l.output_layer->output;
src/layer.h
@@ -94,6 +94,14 @@
    int reorg;
    int log;
    int adam;
    float B1;
    float B2;
    float eps;
    float *m_gpu;
    float *v_gpu;
    int t;
    tree *softmax_tree;
    float alpha;
src/network.h
@@ -37,6 +37,11 @@
    int num_steps;
    int burn_in;
    int adam;
    float B1;
    float B2;
    float eps;
    int inputs;
    int h, w, c;
    int max_crop;
src/network_kernels.cu
@@ -83,6 +83,7 @@
    float rate = get_current_rate(net);
    for(i = 0; i < net.n; ++i){
        layer l = net.layers[i];
        l.t = get_current_batch(net);
        if(l.update_gpu){
            l.update_gpu(l, update_batch, rate, net.momentum, net.decay);
        }
@@ -134,7 +135,6 @@
    free(ptr);
    cuda_set_device(args.net.gpu_index);
    *args.err = train_network(args.net, args.d);
    printf("%d\n", args.net.gpu_index);
    return 0;
}
@@ -177,6 +177,7 @@
{
    int update_batch = net.batch*net.subdivisions;
    float rate = get_current_rate(net);
    l.t = get_current_batch(net);
    if(l.update_gpu){
        l.update_gpu(l, update_batch, rate, net.momentum, net.decay);
    }
src/parser.c
@@ -111,6 +111,7 @@
    int c;
    int index;
    int time_steps;
    network net;
} size_params;
local_layer parse_local(list *options, size_params params)
@@ -156,9 +157,14 @@
    int binary = option_find_int_quiet(options, "binary", 0);
    int xnor = option_find_int_quiet(options, "xnor", 0);
    convolutional_layer layer = make_convolutional_layer(batch,h,w,c,n,size,stride,padding,activation, batch_normalize, binary, xnor);
    convolutional_layer layer = make_convolutional_layer(batch,h,w,c,n,size,stride,padding,activation, batch_normalize, binary, xnor, params.net.adam);
    layer.flipped = option_find_int_quiet(options, "flipped", 0);
    layer.dot = option_find_float_quiet(options, "dot", 0);
    if(params.net.adam){
        layer.B1 = params.net.B1;
        layer.B2 = params.net.B2;
        layer.eps = params.net.eps;
    }
    return layer;
}
@@ -482,6 +488,13 @@
    net->batch *= net->time_steps;
    net->subdivisions = subdivs;
    net->adam = option_find_int_quiet(options, "adam", 0);
    if(net->adam){
        net->B1 = option_find_float(options, "B1", .9);
        net->B2 = option_find_float(options, "B2", .999);
        net->eps = option_find_float(options, "eps", .000001);
    }
    net->h = option_find_int_quiet(options, "height",0);
    net->w = option_find_int_quiet(options, "width",0);
    net->c = option_find_int_quiet(options, "channels",0);
@@ -564,6 +577,7 @@
    params.inputs = net.inputs;
    params.batch = net.batch;
    params.time_steps = net.time_steps;
    params.net = net;
    size_t workspace_size = 0;
    n = n->next;