Joseph Redmon
2015-10-09 c40cdeb4021fc1a638969563972f13c9f5e90d74
lots of comparator stuff
18 files modified
313 ■■■■ changed files
Makefile 2 ●●● patch | view | raw | blame | history
cfg/darknet.cfg 1 ●●●● patch | view | raw | blame | history
src/coco.c 1 ●●●● patch | view | raw | blame | history
src/compare.c 104 ●●●● patch | view | raw | blame | history
src/convolutional_layer.c patch | view | raw | blame | history
src/darknet.c 3 ●●●●● patch | view | raw | blame | history
src/data.c 8 ●●●● patch | view | raw | blame | history
src/data.h 1 ●●●● patch | view | raw | blame | history
src/dice.c 2 ●●● patch | view | raw | blame | history
src/imagenet.c 2 ●●● patch | view | raw | blame | history
src/layer.h 4 ●●●● patch | view | raw | blame | history
src/network.c 4 ●●●● patch | view | raw | blame | history
src/network.h 2 ●●● patch | view | raw | blame | history
src/option_list.c 18 ●●●●● patch | view | raw | blame | history
src/option_list.h 1 ●●●● patch | view | raw | blame | history
src/parser.c 23 ●●●● patch | view | raw | blame | history
src/region_layer.c 38 ●●●●● patch | view | raw | blame | history
src/swag.c 99 ●●●●● patch | view | raw | blame | history
Makefile
@@ -34,7 +34,7 @@
LDFLAGS+= -L/usr/local/cuda/lib64 -lcuda -lcudart -lcublas -lcurand
endif
OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o darknet.o detection_layer.o imagenet.o captcha.o route_layer.o writing.o box.o nightmare.o normalization_layer.o avgpool_layer.o coco.o dice.o yolo.o region_layer.o layer.o compare.o swag.o
OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o darknet.o detection_layer.o imagenet.o captcha.o route_layer.o writing.o box.o nightmare.o normalization_layer.o avgpool_layer.o coco.o dice.o yolo.o region_layer.o layer.o compare.o swag.o classifier.o
ifeq ($(GPU), 1) 
OBJ+=convolutional_kernels.o deconvolutional_kernels.o activation_kernels.o im2col_kernels.o col2im_kernels.o blas_kernels.o crop_layer_kernels.o dropout_layer_kernels.o maxpool_layer_kernels.o softmax_layer_kernels.o network_kernels.o avgpool_layer_kernels.o
endif
cfg/darknet.cfg
@@ -104,6 +104,7 @@
activation=leaky
[softmax]
groups=1
[cost]
type=sse
src/coco.c
@@ -135,6 +135,7 @@
        }
    }
}
void get_boxes(float *predictions, int n, int num_boxes, int per_box, box *boxes)
{
    int i,j;
src/compare.c
@@ -150,17 +150,20 @@
    network net;
    char *filename;
    int class;
    int classes;
    float elo;
    float *elos;
} sortable_bbox;
int total_compares = 0;
int current_class = 0;
int elo_comparator(const void*a, const void *b)
{
    sortable_bbox box1 = *(sortable_bbox*)a;
    sortable_bbox box2 = *(sortable_bbox*)b;
    if(box1.elo == box2.elo) return 0;
    if(box1.elo >  box2.elo) return -1;
    if(box1.elos[current_class] == box2.elos[current_class]) return 0;
    if(box1.elos[current_class] >  box2.elos[current_class]) return -1;
    return 1;
}
@@ -188,16 +191,38 @@
    return -1;
}
void bbox_fight(sortable_bbox *a, sortable_bbox *b)
void bbox_update(sortable_bbox *a, sortable_bbox *b, int class, int result)
{
    int k = 32;
    int result = bbox_comparator(a,b);
    float EA = 1./(1+pow(10, (b->elo - a->elo)/400.));
    float EB = 1./(1+pow(10, (a->elo - b->elo)/400.));
    float SA = 1.*(result > 0);
    float SB = 1.*(result < 0);
    a->elo = a->elo + k*(SA - EA);
    b->elo = b->elo + k*(SB - EB);
    float EA = 1./(1+pow(10, (b->elos[class] - a->elos[class])/400.));
    float EB = 1./(1+pow(10, (a->elos[class] - b->elos[class])/400.));
    float SA = result ? 1 : 0;
    float SB = result ? 0 : 1;
    a->elos[class] += k*(SA - EA);
    b->elos[class] += k*(SB - EB);
}
void bbox_fight(network net, sortable_bbox *a, sortable_bbox *b, int classes, int class)
{
    image im1 = load_image_color(a->filename, net.w, net.h);
    image im2 = load_image_color(b->filename, net.w, net.h);
    float *X  = calloc(net.w*net.h*net.c, sizeof(float));
    memcpy(X,                   im1.data, im1.w*im1.h*im1.c*sizeof(float));
    memcpy(X+im1.w*im1.h*im1.c, im2.data, im2.w*im2.h*im2.c*sizeof(float));
    float *predictions = network_predict(net, X);
    ++total_compares;
    int i;
    for(i = 0; i < classes; ++i){
        if(class < 0 || class == i){
            int result = predictions[i*2] > predictions[i*2+1];
            bbox_update(a, b, i, result);
        }
    }
    free_image(im1);
    free_image(im2);
    free(X);
}
void SortMaster3000(char *filename, char *weightfile)
@@ -233,7 +258,8 @@
void BattleRoyaleWithCheese(char *filename, char *weightfile)
{
    int i = 0;
    int classes = 20;
    int i,j;
    network net = parse_network_cfg(filename);
    if(weightfile){
        load_weights(&net, weightfile);
@@ -241,47 +267,67 @@
    srand(time(0));
    set_batch_network(&net, 1);
    //list *plist = get_paths("data/compare.sort.list");
    list *plist = get_paths("data/compare.cat.list");
    list *plist = get_paths("data/compare.sort.list");
    //list *plist = get_paths("data/compare.small.list");
    //list *plist = get_paths("data/compare.cat.list");
    //list *plist = get_paths("data/compare.val.old");
    char **paths = (char **)list_to_array(plist);
    int N = plist->size;
    int total = N;
    free_list(plist);
    sortable_bbox *boxes = calloc(N, sizeof(sortable_bbox));
    printf("Battling %d boxes...\n", N);
    for(i = 0; i < N; ++i){
        boxes[i].filename = paths[i];
        boxes[i].net = net;
        boxes[i].class = 7;
        boxes[i].elo = 1500;
        boxes[i].classes = classes;
        boxes[i].elos = calloc(classes, sizeof(float));;
        for(j = 0; j < classes; ++j){
            boxes[i].elos[j] = 1500;
        }
    }
    int round;
    clock_t time=clock();
    for(round = 1; round <= 500; ++round){
    for(round = 1; round <= 4; ++round){
        clock_t round_time=clock();
        printf("Round: %d\n", round);
        qsort(boxes, N, sizeof(sortable_bbox), elo_comparator);
        sorta_shuffle(boxes, N, sizeof(sortable_bbox), 10);
        shuffle(boxes, N, sizeof(sortable_bbox));
        for(i = 0; i < N/2; ++i){
            bbox_fight(boxes+i*2, boxes+i*2+1);
        }
        if(round >= 4 && 0){
            qsort(boxes, N, sizeof(sortable_bbox), elo_comparator);
            if(round == 4){
                N = N/2;
            }else{
                N = (N*9/10)/2*2;
            }
            bbox_fight(net, boxes+i*2, boxes+i*2+1, classes, -1);
        }
        printf("Round: %f secs, %d remaining\n", sec(clock()-round_time), N);
    }
    int class;
    for (class = 0; class < classes; ++class){
        N = total;
        current_class = class;
    qsort(boxes, N, sizeof(sortable_bbox), elo_comparator);
    FILE *outfp = fopen("results/battle.log", "w");
        N /= 2;
        for(round = 1; round <= 20; ++round){
            clock_t round_time=clock();
            printf("Round: %d\n", round);
            sorta_shuffle(boxes, N, sizeof(sortable_bbox), 10);
            for(i = 0; i < N/2; ++i){
                bbox_fight(net, boxes+i*2, boxes+i*2+1, classes, class);
            }
            qsort(boxes, N, sizeof(sortable_bbox), elo_comparator);
            N = (N*9/10)/2*2;
            printf("Round: %f secs, %d remaining\n", sec(clock()-round_time), N);
        }
        char buff[256];
        sprintf(buff, "results/battle_%d.log", class);
        FILE *outfp = fopen(buff, "w");
    for(i = 0; i < N; ++i){
        fprintf(outfp, "%s %f\n", boxes[i].filename, boxes[i].elo);
            fprintf(outfp, "%s %f\n", boxes[i].filename, boxes[i].elos[class]);
    }
    fclose(outfp);
    }
    printf("Tournament in %d compares, %f secs\n", total_compares, sec(clock()-time));
}
src/convolutional_layer.c
src/darknet.c
@@ -20,6 +20,7 @@
extern void run_nightmare(int argc, char **argv);
extern void run_dice(int argc, char **argv);
extern void run_compare(int argc, char **argv);
extern void run_classifier(int argc, char **argv);
void change_rate(char *filename, float scale, float add)
{
@@ -183,6 +184,8 @@
        run_swag(argc, argv);
    } else if (0 == strcmp(argv[1], "coco")){
        run_coco(argc, argv);
    } else if (0 == strcmp(argv[1], "classifier")){
        run_classifier(argc, argv);
    } else if (0 == strcmp(argv[1], "compare")){
        run_compare(argc, argv);
    } else if (0 == strcmp(argv[1], "dice")){
src/data.c
@@ -366,7 +366,7 @@
    }
}
data load_data_region(int n, char **paths, int m, int w, int h, int size, int classes)
data load_data_region(int n, char **paths, int m, int w, int h, int size, int classes, float jitter)
{
    char **random_paths = get_random_paths(paths, n, m);
    int i;
@@ -385,8 +385,8 @@
        int oh = orig.h;
        int ow = orig.w;
        int dw = ow/10;
        int dh = oh/10;
        int dw = (ow*jitter);
        int dh = (oh*jitter);
        int pleft  = (rand_uniform() * 2*dw - dw);
        int pright = (rand_uniform() * 2*dw - dw);
@@ -556,7 +556,7 @@
    } else if (a.type == WRITING_DATA){
        *a.d = load_data_writing(a.paths, a.n, a.m, a.w, a.h, a.out_w, a.out_h);
    } else if (a.type == REGION_DATA){
        *a.d = load_data_region(a.n, a.paths, a.m, a.w, a.h, a.num_boxes, a.classes);
        *a.d = load_data_region(a.n, a.paths, a.m, a.w, a.h, a.num_boxes, a.classes, a.jitter);
    } else if (a.type == COMPARE_DATA){
        *a.d = load_data_compare(a.n, a.paths, a.m, a.classes, a.w, a.h);
    } else if (a.type == IMAGE_DATA){
src/data.h
@@ -44,6 +44,7 @@
    int num_boxes;
    int classes;
    int background;
    float jitter;
    data *d;
    image *im;
    image *resized;
src/dice.c
@@ -61,7 +61,7 @@
    free_list(plist);
    data val = load_data(paths, m, 0, labels, 6, net.w, net.h);
    float *acc = network_accuracies(net, val);
    float *acc = network_accuracies(net, val, 2);
    printf("Validation Accuracy: %f, %d images\n", acc[0], m);
    free_data(val);
}
src/imagenet.c
@@ -133,7 +133,7 @@
        printf("Loaded: %d images in %lf seconds\n", val.X.rows, sec(clock()-time));
        time=clock();
        float *acc = network_accuracies(net, val);
        float *acc = network_accuracies(net, val, 5);
        avg_acc += acc[0];
        avg_top5 += acc[1];
        printf("%d: top1: %f, top5: %f, %lf seconds, %d images\n", i, avg_acc/i, avg_top5/i, sec(clock()-time), val.X.rows);
src/layer.h
@@ -29,6 +29,9 @@
    COST_TYPE cost_type;
    int batch;
    int forced;
    int object_logistic;
    int class_logistic;
    int coord_logistic;
    int inputs;
    int outputs;
    int truths;
@@ -45,6 +48,7 @@
    int sqrt;
    int flip;
    float angle;
    float jitter;
    float saturation;
    float exposure;
    int softmax;
src/network.c
@@ -540,12 +540,12 @@
    return acc;
}
float *network_accuracies(network net, data d)
float *network_accuracies(network net, data d, int n)
{
    static float acc[2];
    matrix guess = network_predict_data(net, d);
    acc[0] = matrix_topk_accuracy(d.y, guess,1);
    acc[1] = matrix_topk_accuracy(d.y, guess,5);
    acc[1] = matrix_topk_accuracy(d.y, guess, n);
    free_matrix(guess);
    return acc;
}
src/network.h
@@ -70,7 +70,7 @@
matrix network_predict_data(network net, data test);
float *network_predict(network net, float *input);
float network_accuracy(network net, data d);
float *network_accuracies(network net, data d);
float *network_accuracies(network net, data d, int n);
float network_accuracy_multi(network net, data d, int n);
void top_predictions(network net, int n, int *index);
float *get_network_output(network net);
src/option_list.c
@@ -3,6 +3,24 @@
#include <string.h>
#include "option_list.h"
int read_option(char *s, list *options)
{
    size_t i;
    size_t len = strlen(s);
    char *val = 0;
    for(i = 0; i < len; ++i){
        if(s[i] == '='){
            s[i] = '\0';
            val = s+i+1;
            break;
        }
    }
    if(i == len-1) return 0;
    char *key = s;
    option_insert(options, key, val);
    return 1;
}
void option_insert(list *l, char *key, char *val)
{
    kvp *p = malloc(sizeof(kvp));
src/option_list.h
@@ -9,6 +9,7 @@
} kvp;
int read_option(char *s, list *options);
void option_insert(list *l, char *key, char *val);
char *option_find(list *l, char *key);
char *option_find_str(list *l, char *key, char *def);
src/parser.c
@@ -186,11 +186,16 @@
    layer.softmax = option_find_int(options, "softmax", 0);
    layer.sqrt = option_find_int(options, "sqrt", 0);
    layer.object_logistic = option_find_int(options, "object_logistic", 0);
    layer.class_logistic = option_find_int(options, "class_logistic", 0);
    layer.coord_logistic = option_find_int(options, "coord_logistic", 0);
    layer.coord_scale = option_find_float(options, "coord_scale", 1);
    layer.forced = option_find_int(options, "forced", 0);
    layer.object_scale = option_find_float(options, "object_scale", 1);
    layer.noobject_scale = option_find_float(options, "noobject_scale", 1);
    layer.class_scale = option_find_float(options, "class_scale", 1);
    layer.jitter = option_find_float(options, "jitter", .1);
    return layer;
}
@@ -532,24 +537,6 @@
    return (strcmp(s->type, "[route]")==0);
}
int read_option(char *s, list *options)
{
    size_t i;
    size_t len = strlen(s);
    char *val = 0;
    for(i = 0; i < len; ++i){
        if(s[i] == '='){
            s[i] = '\0';
            val = s+i+1;
            break;
        }
    }
    if(i == len-1) return 0;
    char *key = s;
    option_insert(options, key, val);
    return 1;
}
list *read_cfg(char *filename)
{
    FILE *file = fopen(filename, "r");
src/region_layer.c
@@ -57,6 +57,28 @@
            activate_array(l.output + index + offset, locations*l.n*(1+l.coords), LOGISTIC);
        }
    }
    if (l.object_logistic) {
        for(b = 0; b < l.batch; ++b){
            int index = b*l.inputs;
            int p_index = index + locations*l.classes;
            activate_array(l.output + p_index, locations*l.n, LOGISTIC);
        }
    }
    if (l.coord_logistic) {
        for(b = 0; b < l.batch; ++b){
            int index = b*l.inputs;
            int coord_index = index + locations*(l.classes + l.n);
            activate_array(l.output + coord_index, locations*l.n*l.coords, LOGISTIC);
        }
    }
    if (l.class_logistic) {
        for(b = 0; b < l.batch; ++b){
            int class_index = b*l.inputs;
            activate_array(l.output + class_index, locations*l.classes, LOGISTIC);
        }
    }
    if(state.train){
        float avg_iou = 0;
@@ -85,7 +107,6 @@
                float best_rmse = 20;
                if (!is_obj){
                    //printf(".");
                    continue;
                }
@@ -113,6 +134,7 @@
                    }
                    float iou  = box_iou(out, truth);
                    //iou = 0;
                    float rmse = box_rmse(out, truth);
                    if(best_iou > 0 || iou > 0){
                        if(iou > best_iou){
@@ -175,6 +197,20 @@
                gradient_array(l.output + index + locations*l.classes, locations*l.n*(1+l.coords), 
                        LOGISTIC, l.delta + index + locations*l.classes);
            }
            if (l.object_logistic) {
                int p_index = index + locations*l.classes;
                gradient_array(l.output + p_index, locations*l.n, LOGISTIC, l.delta + p_index);
            }
            if (l.class_logistic) {
                int class_index = index;
                gradient_array(l.output + class_index, locations*l.classes, LOGISTIC, l.delta + class_index);
            }
            if (l.coord_logistic) {
                    int coord_index = index + locations*(l.classes + l.n);
                gradient_array(l.output + coord_index, locations*l.n*l.coords, LOGISTIC, l.delta + coord_index);
            }
            //printf("\n");
        }
        printf("Region Avg IOU: %f, Pos Cat: %f, All Cat: %f, Pos Obj: %f, Any Obj: %f, count: %d\n", avg_iou/count, avg_cat/count, avg_allcat/(count*l.classes), avg_obj/count, avg_anyobj/(l.batch*locations*l.n), count);
src/swag.c
@@ -73,6 +73,7 @@
    int side = l.side;
    int classes = l.classes;
    float jitter = l.jitter;
    list *plist = get_paths(train_images);
    //int N = plist->size;
@@ -85,6 +86,7 @@
    args.n = imgs;
    args.m = plist->size;
    args.classes = classes;
    args.jitter = jitter;
    args.num_boxes = side;
    args.d = &buffer;
    args.type = REGION_DATA;
@@ -127,7 +129,7 @@
    save_weights(net, buff);
}
void convert_swag_detections(float *predictions, int classes, int num, int square, int side, int w, int h, float thresh, float **probs, box *boxes)
void convert_swag_detections(float *predictions, int classes, int num, int square, int side, int w, int h, float thresh, float **probs, box *boxes, int only_objectness)
{
    int i,j,n;
    //int per_cell = 5*num+classes;
@@ -148,6 +150,9 @@
                float prob = scale*predictions[class_index+j];
                probs[index][j] = (prob > thresh) ? prob : 0;
            }
            if(only_objectness){
                probs[index][0] = scale;
            }
        }
    }
}
@@ -250,7 +255,7 @@
            float *predictions = network_predict(net, X);
            int w = val[t].w;
            int h = val[t].h;
            convert_swag_detections(predictions, classes, l.n, square, side, w, h, thresh, probs, boxes);
            convert_swag_detections(predictions, classes, l.n, square, side, w, h, thresh, probs, boxes, 0);
            if (nms) do_nms(boxes, probs, side*side*l.n, classes, iou_thresh);
            print_swag_detections(fps, id, boxes, probs, side*side*l.n, classes, w, h);
            free(id);
@@ -261,6 +266,95 @@
    fprintf(stderr, "Total Detection Time: %f Seconds\n", (double)(time(0) - start));
}
void validate_swag_recall(char *cfgfile, char *weightfile)
{
    network net = parse_network_cfg(cfgfile);
    if(weightfile){
        load_weights(&net, weightfile);
    }
    set_batch_network(&net, 1);
    fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
    srand(time(0));
    char *base = "results/comp4_det_test_";
    list *plist = get_paths("/home/pjreddie/data/voc/test/2007_test.txt");
    char **paths = (char **)list_to_array(plist);
    layer l = net.layers[net.n-1];
    int classes = l.classes;
    int square = l.sqrt;
    int side = l.side;
    int j, k;
    FILE **fps = calloc(classes, sizeof(FILE *));
    for(j = 0; j < classes; ++j){
        char buff[1024];
        snprintf(buff, 1024, "%s%s.txt", base, voc_names[j]);
        fps[j] = fopen(buff, "w");
    }
    box *boxes = calloc(side*side*l.n, sizeof(box));
    float **probs = calloc(side*side*l.n, sizeof(float *));
    for(j = 0; j < side*side*l.n; ++j) probs[j] = calloc(classes, sizeof(float *));
    int m = plist->size;
    int i=0;
    float thresh = .001;
    int nms = 0;
    float iou_thresh = .5;
    float nms_thresh = .5;
    int total = 0;
    int correct = 0;
    int proposals = 0;
    float avg_iou = 0;
    for(i = 0; i < m; ++i){
        char *path = paths[i];
        image orig = load_image_color(path, 0, 0);
        image sized = resize_image(orig, net.w, net.h);
        char *id = basecfg(path);
        float *predictions = network_predict(net, sized.data);
        int w = orig.w;
        int h = orig.h;
        convert_swag_detections(predictions, classes, l.n, square, side, 1, 1, thresh, probs, boxes, 1);
        if (nms) do_nms(boxes, probs, side*side*l.n, 1, nms_thresh);
        char *labelpath = find_replace(path, "images", "labels");
        labelpath = find_replace(labelpath, "JPEGImages", "labels");
        labelpath = find_replace(labelpath, ".jpg", ".txt");
        labelpath = find_replace(labelpath, ".JPEG", ".txt");
        int num_labels = 0;
        box_label *truth = read_boxes(labelpath, &num_labels);
        for(k = 0; k < side*side*l.n; ++k){
            if(probs[k][0] > thresh){
                ++proposals;
            }
        }
        for (j = 0; j < num_labels; ++j) {
            ++total;
            box t = {truth[j].x, truth[j].y, truth[j].w, truth[j].h};
            float best_iou = 0;
            for(k = 0; k < side*side*l.n; ++k){
                float iou = box_iou(boxes[k], t);
                if(probs[k][0] > thresh && iou > best_iou){
                    best_iou = iou;
                }
            }
            avg_iou += best_iou;
            if(best_iou > iou_thresh){
                ++correct;
            }
        }
        fprintf(stderr, "%5d %5d %5d\tRPs/Img: %.2f\tIOU: %.2f%%\tRecall:%.2f%%\n", i, correct, total, (float)proposals/(i+1), avg_iou*100/total, 100.*correct/total);
        free(id);
        free_image(orig);
        free_image(sized);
    }
}
void test_swag(char *cfgfile, char *weightfile, char *filename, float thresh)
{
@@ -316,4 +410,5 @@
    if(0==strcmp(argv[2], "test")) test_swag(cfg, weights, filename, thresh);
    else if(0==strcmp(argv[2], "train")) train_swag(cfg, weights);
    else if(0==strcmp(argv[2], "valid")) validate_swag(cfg, weights);
    else if(0==strcmp(argv[2], "recall")) validate_swag_recall(cfg, weights);
}