Joseph Redmon
2015-04-17 f199fd3b6464e644566d76676c0b5f1824d26c4e
src/detection.c
@@ -1,9 +1,11 @@
#include "network.h"
#include "detection_layer.h"
#include "utils.h"
#include "parser.h"
char *class_names[] = {"aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"};
char *inet_class_names[] = {"bg", "accordion", "airplane", "ant", "antelope", "apple", "armadillo", "artichoke", "axe", "baby bed", "backpack", "bagel", "balance beam", "banana", "band aid", "banjo", "baseball", "basketball", "bathing cap", "beaker", "bear", "bee", "bell pepper", "bench", "bicycle", "binder", "bird", "bookshelf", "bow tie", "bow", "bowl", "brassiere", "burrito", "bus", "butterfly", "camel", "can opener", "car", "cart", "cattle", "cello", "centipede", "chain saw", "chair", "chime", "cocktail shaker", "coffee maker", "computer keyboard", "computer mouse", "corkscrew", "cream", "croquet ball", "crutch", "cucumber", "cup or mug", "diaper", "digital clock", "dishwasher", "dog", "domestic cat", "dragonfly", "drum", "dumbbell", "electric fan", "elephant", "face powder", "fig", "filing cabinet", "flower pot", "flute", "fox", "french horn", "frog", "frying pan", "giant panda", "goldfish", "golf ball", "golfcart", "guacamole", "guitar", "hair dryer", "hair spray", "hamburger", "hammer", "hamster", "harmonica", "harp", "hat with a wide brim", "head cabbage", "helmet", "hippopotamus", "horizontal bar", "horse", "hotdog", "iPod", "isopod", "jellyfish", "koala bear", "ladle", "ladybug", "lamp", "laptop", "lemon", "lion", "lipstick", "lizard", "lobster", "maillot", "maraca", "microphone", "microwave", "milk can", "miniskirt", "monkey", "motorcycle", "mushroom", "nail", "neck brace", "oboe", "orange", "otter", "pencil box", "pencil sharpener", "perfume", "person", "piano", "pineapple", "ping-pong ball", "pitcher", "pizza", "plastic bag", "plate rack", "pomegranate", "popsicle", "porcupine", "power drill", "pretzel", "printer", "puck", "punching bag", "purse", "rabbit", "racket", "ray", "red panda", "refrigerator", "remote control", "rubber eraser", "rugby ball", "ruler", "salt or pepper shaker", "saxophone", "scorpion", "screwdriver", "seal", "sheep", "ski", "skunk", "snail", "snake", "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sofa", "spatula", "squirrel", "starfish", "stethoscope", "stove", "strainer", "strawberry", "stretcher", "sunglasses", "swimming trunks", "swine", "syringe", "table", "tape player", "tennis ball", "tick", "tie", "tiger", "toaster", "traffic light", "train", "trombone", "trumpet", "turtle", "tv or monitor", "unicycle", "vacuum", "violin", "volleyball", "waffle iron", "washer", "water bottle", "watercraft", "whale", "wine bottle", "zebra"};
#define AMNT 3
void draw_detection(image im, float *box, int side)
{
@@ -18,7 +20,7 @@
            //printf("%d\n", j);
            //printf("Prob: %f\n", box[j]);
            int class = max_index(box+j, classes);
            if(box[j+class] > .02 || 1){
            if(box[j+class] > .2){
                //int z;
                //for(z = 0; z < classes; ++z) printf("%f %s\n", box[j+z], class_names[z]);
                printf("%f %s\n", box[j+class], class_names[class]);
@@ -26,13 +28,24 @@
                float green = get_color(1,class,classes);
                float blue = get_color(2,class,classes);
                //float maxheight = distance_from_edge(r, side);
                //float maxwidth  = distance_from_edge(c, side);
                j += classes;
                int d = im.w/side;
                int y = r*d+box[j]*d;
                int x = c*d+box[j+1]*d;
                int h = box[j+2]*im.h;
                int w = box[j+3]*im.w;
                draw_box(im, x-w/2, y-h/2, x+w/2, y+h/2,red,green,blue);
                float y = box[j+0];
                float x = box[j+1];
                x = (x+c)/side;
                y = (y+r)/side;
                float h = box[j+2]; //*maxheight;
                float w = box[j+3]; //*maxwidth;
                h = h*h;
                w = w*w;
                //printf("coords %f %f %f %f\n", x, y, w, h);
                int left  = (x-w/2)*im.w;
                int right = (x+w/2)*im.w;
                int top   = (y-h/2)*im.h;
                int bot   = (y+h/2)*im.h;
                draw_box(im, left, top, right, bot, red, green, blue);
            }
        }
    }
@@ -43,45 +56,57 @@
void train_detection(char *cfgfile, char *weightfile)
{
    srand(time(0));
    data_seed = time(0);
    int imgnet = 0;
    char *base = basecfg(cfgfile);
    printf("%s\n", base);
    float avg_loss = 1;
    float avg_loss = -1;
    network net = parse_network_cfg(cfgfile);
    if(weightfile){
        load_weights(&net, weightfile);
    }
    detection_layer *layer = get_network_detection_layer(net);
    printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
    int imgs = 128;
    srand(time(0));
    //srand(23410);
    int i = net.seen/imgs;
    list *plist = get_paths("/home/pjreddie/data/voc/train.txt");
    char **paths = (char **)list_to_array(plist);
    printf("%d\n", plist->size);
    data train, buffer;
    int im_dim = 512;
    int jitter = 64;
    int classes = 21;
    pthread_t load_thread = load_data_detection_thread(imgs, paths, plist->size, classes, im_dim, im_dim, 7, 7, jitter, &buffer);
    int classes = layer->classes;
    int background = layer->background;
    int side = sqrt(get_detection_layer_locations(*layer));
    char **paths;
    list *plist;
    if (imgnet){
        plist = get_paths("/home/pjreddie/data/imagenet/det.train.list");
    }else{
        plist = get_paths("/home/pjreddie/data/voc/trainall.txt");
        //plist = get_paths("/home/pjreddie/data/coco/trainval.txt");
        //plist = get_paths("/home/pjreddie/data/voc/all2007-2012.txt");
    }
    paths = (char **)list_to_array(plist);
    pthread_t load_thread = load_data_detection_thread(imgs, paths, plist->size, classes, net.w, net.h, side, side, background, &buffer);
    clock_t time;
    while(1){
        i += 1;
        time=clock();
        pthread_join(load_thread, 0);
        train = buffer;
        load_thread = load_data_detection_thread(imgs, paths, plist->size, classes, im_dim, im_dim, 7, 7, jitter, &buffer);
        load_thread = load_data_detection_thread(imgs, paths, plist->size, classes, net.w, net.h, side, side, background, &buffer);
        /*
           image im = float_to_image(im_dim - jitter, im_dim-jitter, 3, train.X.vals[0]);
           draw_detection(im, train.y.vals[0], 7);
           show_image(im, "truth");
           cvWaitKey(0);
         */
/*
 image im = float_to_image(net.w, net.h, 3, train.X.vals[114]);
 image copy = copy_image(im);
 draw_detection(copy, train.y.vals[114], 7);
 free_image(copy);
 */
        printf("Loaded: %lf seconds\n", sec(clock()-time));
        time=clock();
        float loss = train_network(net, train);
        net.seen += imgs;
        if (avg_loss < 0) avg_loss = loss;
        avg_loss = avg_loss*.9 + loss*.1;
        printf("%d: %f, %f avg, %lf seconds, %d images\n", i, loss, avg_loss, sec(clock()-time), i*imgs);
        if(i%100==0){
@@ -99,14 +124,22 @@
    if(weightfile){
        load_weights(&net, weightfile);
    }
    detection_layer *layer = get_network_detection_layer(net);
    fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
    srand(time(0));
    list *plist = get_paths("/home/pjreddie/data/voc/val.txt");
    //list *plist = get_paths("/home/pjreddie/data/voc/val.expanded.txt");
    //list *plist = get_paths("/home/pjreddie/data/voc/train.txt");
    char **paths = (char **)list_to_array(plist);
    int num_output = 1225;
    int im_size = 448;
    int classes = 21;
    int classes = layer->classes;
    int nuisance = layer->nuisance;
    int background = (layer->background && !nuisance);
    int num_boxes = sqrt(get_detection_layer_locations(*layer));
    int per_box = 4+classes+background+nuisance;
    int num_output = num_boxes*num_boxes*per_box;
    int m = plist->size;
    int i = 0;
@@ -115,7 +148,7 @@
    fprintf(stderr, "%d\n", m);
    data val, buffer;
    pthread_t load_thread = load_data_thread(paths, num, 0, 0, num_output, im_size, im_size, &buffer);
    pthread_t load_thread = load_data_thread(paths, num, 0, 0, num_output, net.w, net.h, &buffer);
    clock_t time;
    for(i = 1; i <= splits; ++i){
        time=clock();
@@ -124,35 +157,32 @@
        num = (i+1)*m/splits - i*m/splits;
        char **part = paths+(i*m/splits);
        if(i != splits) load_thread = load_data_thread(part, num, 0, 0, num_output, im_size, im_size, &buffer);
        if(i != splits) load_thread = load_data_thread(part, num, 0, 0, num_output, net.w, net.h, &buffer);
        fprintf(stderr, "%d: Loaded: %lf seconds\n", i, sec(clock()-time));
        matrix pred = network_predict_data(net, val);
        int j, k, class;
        for(j = 0; j < pred.rows; ++j){
            for(k = 0; k < pred.cols; k += classes+4){
                /*
                   int z;
                   for(z = 0; z < 25; ++z) printf("%f, ", pred.vals[j][k+z]);
                   printf("\n");
                 */
                //if (pred.vals[j][k] > .001){
                for(class = 0; class < classes-1; ++class){
                    int index = (k)/(classes+4);
                    int r = index/7;
                    int c = index%7;
                    float y = (r + pred.vals[j][k+0+classes])/7.;
                    float x = (c + pred.vals[j][k+1+classes])/7.;
                    float h = pred.vals[j][k+2+classes];
                    float w = pred.vals[j][k+3+classes];
                    printf("%d %d %f %f %f %f %f\n", (i-1)*m/splits + j, class, pred.vals[j][k+class], y, x, h, w);
            for(k = 0; k < pred.cols; k += per_box){
                float scale = 1.;
                int index = k/per_box;
                int row = index / num_boxes;
                int col = index % num_boxes;
                if (nuisance) scale = 1.-pred.vals[j][k];
                for (class = 0; class < classes; ++class){
                    int ci = k+classes+background+nuisance;
                    float y = (pred.vals[j][ci + 0] + row)/num_boxes;
                    float x = (pred.vals[j][ci + 1] + col)/num_boxes;
                    float h = pred.vals[j][ci + 2]; //* distance_from_edge(row, num_boxes);
                    h = h*h;
                    float w = pred.vals[j][ci + 3]; //* distance_from_edge(col, num_boxes);
                    w = w*w;
                    float prob = scale*pred.vals[j][k+class+background+nuisance];
                    if(prob < .001) continue;
                    printf("%d %d %f %f %f %f %f\n", (i-1)*m/splits + j, class, prob, y, x, h, w);
                }
                //}
            }
        }
        time=clock();
        free_data(val);
    }
@@ -173,8 +203,6 @@
        fgets(filename, 256, stdin);
        strtok(filename, "\n");
        image im = load_image_color(filename, im_size, im_size);
        translate_image(im, -128);
        scale_image(im, 1/128.);
        printf("%d %d %d\n", im.h, im.w, im.c);
        float *X = im.data;
        time=clock();