Joseph Redmon
2016-11-11 9a01e6ccb7a74ff77e99060cf18acd6cfdb74b8e
:fire: crush. crush. admit. :fire:
4 files modified
498 ■■■■■ changed files
src/classifier.c 203 ●●●● patch | view | raw | blame | history
src/darknet.c 55 ●●●●● patch | view | raw | blame | history
src/detector.c 165 ●●●● patch | view | raw | blame | history
src/region_layer.c 75 ●●●●● patch | view | raw | blame | history
src/classifier.c
@@ -25,9 +25,8 @@
    return v;
}
void train_classifier_multi(char *datacfg, char *cfgfile, char *weightfile, int *gpus, int ngpus, int clear)
void train_classifier(char *datacfg, char *cfgfile, char *weightfile, int *gpus, int ngpus, int clear)
{
#ifdef GPU
    int i;
    float avg_loss = -1;
@@ -40,7 +39,9 @@
    int seed = rand();
    for(i = 0; i < ngpus; ++i){
        srand(seed);
#ifdef GPU
        cuda_set_device(gpus[i]);
#endif
        nets[i] = parse_network_cfg(cfgfile);
        if(weightfile){
            load_weights(&nets[i], weightfile);
@@ -107,7 +108,16 @@
        printf("Loaded: %lf seconds\n", sec(clock()-time));
        time=clock();
        float loss = train_networks(nets, ngpus, train, 4);
        float loss = 0;
#ifdef GPU
        if(ngpus == 1){
            loss = train_network(net, train);
        } else {
            loss = train_networks(nets, ngpus, train, 4);
        }
#else
        loss = train_network(net, train);
#endif
        if(avg_loss == -1) avg_loss = loss;
        avg_loss = avg_loss*.9 + loss*.1;
        printf("%d, %.3f: %f, %f avg, %f rate, %lf seconds, %d images\n", get_current_batch(net), (float)(*net.seen)/N, loss, avg_loss, get_current_rate(net), sec(clock()-time), *net.seen);
@@ -133,117 +143,118 @@
    free_ptrs((void**)paths, plist->size);
    free_list(plist);
    free(base);
#endif
}
void train_classifier(char *datacfg, char *cfgfile, char *weightfile, int clear)
{
    srand(time(0));
    float avg_loss = -1;
    char *base = basecfg(cfgfile);
    printf("%s\n", base);
    network net = parse_network_cfg(cfgfile);
    if(weightfile){
        load_weights(&net, weightfile);
    }
    if(clear) *net.seen = 0;
/*
   void train_classifier(char *datacfg, char *cfgfile, char *weightfile, int clear)
   {
   srand(time(0));
   float avg_loss = -1;
   char *base = basecfg(cfgfile);
   printf("%s\n", base);
   network net = parse_network_cfg(cfgfile);
   if(weightfile){
   load_weights(&net, weightfile);
   }
   if(clear) *net.seen = 0;
    int imgs = net.batch * net.subdivisions;
   int imgs = net.batch * net.subdivisions;
    printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
    list *options = read_data_cfg(datacfg);
   printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
   list *options = read_data_cfg(datacfg);
    char *backup_directory = option_find_str(options, "backup", "/backup/");
    char *label_list = option_find_str(options, "labels", "data/labels.list");
    char *train_list = option_find_str(options, "train", "data/train.list");
    int classes = option_find_int(options, "classes", 2);
   char *backup_directory = option_find_str(options, "backup", "/backup/");
   char *label_list = option_find_str(options, "labels", "data/labels.list");
   char *train_list = option_find_str(options, "train", "data/train.list");
   int classes = option_find_int(options, "classes", 2);
    char **labels = get_labels(label_list);
    list *plist = get_paths(train_list);
    char **paths = (char **)list_to_array(plist);
    printf("%d\n", plist->size);
    int N = plist->size;
    clock_t time;
   char **labels = get_labels(label_list);
   list *plist = get_paths(train_list);
   char **paths = (char **)list_to_array(plist);
   printf("%d\n", plist->size);
   int N = plist->size;
   clock_t time;
    load_args args = {0};
    args.w = net.w;
    args.h = net.h;
    args.threads = 8;
   load_args args = {0};
   args.w = net.w;
   args.h = net.h;
   args.threads = 8;
    args.min = net.min_crop;
    args.max = net.max_crop;
    args.angle = net.angle;
    args.aspect = net.aspect;
    args.exposure = net.exposure;
    args.saturation = net.saturation;
    args.hue = net.hue;
    args.size = net.w;
    args.hierarchy = net.hierarchy;
   args.min = net.min_crop;
   args.max = net.max_crop;
   args.angle = net.angle;
   args.aspect = net.aspect;
   args.exposure = net.exposure;
   args.saturation = net.saturation;
   args.hue = net.hue;
   args.size = net.w;
   args.hierarchy = net.hierarchy;
    args.paths = paths;
    args.classes = classes;
    args.n = imgs;
    args.m = N;
    args.labels = labels;
    args.type = CLASSIFICATION_DATA;
   args.paths = paths;
   args.classes = classes;
   args.n = imgs;
   args.m = N;
   args.labels = labels;
   args.type = CLASSIFICATION_DATA;
    data train;
    data buffer;
    pthread_t load_thread;
    args.d = &buffer;
    load_thread = load_data(args);
   data train;
   data buffer;
   pthread_t load_thread;
   args.d = &buffer;
   load_thread = load_data(args);
    int epoch = (*net.seen)/N;
    while(get_current_batch(net) < net.max_batches || net.max_batches == 0){
        time=clock();
   int epoch = (*net.seen)/N;
   while(get_current_batch(net) < net.max_batches || net.max_batches == 0){
   time=clock();
        pthread_join(load_thread, 0);
        train = buffer;
        load_thread = load_data(args);
   pthread_join(load_thread, 0);
   train = buffer;
   load_thread = load_data(args);
        printf("Loaded: %lf seconds\n", sec(clock()-time));
        time=clock();
   printf("Loaded: %lf seconds\n", sec(clock()-time));
   time=clock();
#ifdef OPENCV
        if(0){
            int u;
            for(u = 0; u < imgs; ++u){
                image im = float_to_image(net.w, net.h, 3, train.X.vals[u]);
                show_image(im, "loaded");
                cvWaitKey(0);
            }
        }
if(0){
int u;
for(u = 0; u < imgs; ++u){
    image im = float_to_image(net.w, net.h, 3, train.X.vals[u]);
    show_image(im, "loaded");
    cvWaitKey(0);
}
}
#endif
        float loss = train_network(net, train);
        free_data(train);
float loss = train_network(net, train);
free_data(train);
        if(avg_loss == -1) avg_loss = loss;
        avg_loss = avg_loss*.9 + loss*.1;
        printf("%d, %.3f: %f, %f avg, %f rate, %lf seconds, %d images\n", get_current_batch(net), (float)(*net.seen)/N, loss, avg_loss, get_current_rate(net), sec(clock()-time), *net.seen);
        if(*net.seen/N > epoch){
            epoch = *net.seen/N;
            char buff[256];
            sprintf(buff, "%s/%s_%d.weights",backup_directory,base, epoch);
            save_weights(net, buff);
        }
        if(get_current_batch(net)%100 == 0){
            char buff[256];
            sprintf(buff, "%s/%s.backup",backup_directory,base);
            save_weights(net, buff);
        }
    }
if(avg_loss == -1) avg_loss = loss;
avg_loss = avg_loss*.9 + loss*.1;
printf("%d, %.3f: %f, %f avg, %f rate, %lf seconds, %d images\n", get_current_batch(net), (float)(*net.seen)/N, loss, avg_loss, get_current_rate(net), sec(clock()-time), *net.seen);
if(*net.seen/N > epoch){
    epoch = *net.seen/N;
    char buff[256];
    sprintf(buff, "%s/%s.weights", backup_directory, base);
    sprintf(buff, "%s/%s_%d.weights",backup_directory,base, epoch);
    save_weights(net, buff);
    free_network(net);
    free_ptrs((void**)labels, classes);
    free_ptrs((void**)paths, plist->size);
    free_list(plist);
    free(base);
}
if(get_current_batch(net)%100 == 0){
    char buff[256];
    sprintf(buff, "%s/%s.backup",backup_directory,base);
    save_weights(net, buff);
}
}
char buff[256];
sprintf(buff, "%s/%s.weights", backup_directory, base);
save_weights(net, buff);
free_network(net);
free_ptrs((void**)labels, classes);
free_ptrs((void**)paths, plist->size);
free_list(plist);
free(base);
}
*/
void validate_classifier_crop(char *datacfg, char *filename, char *weightfile)
{
@@ -1108,6 +1119,7 @@
    char *gpu_list = find_char_arg(argc, argv, "-gpus", 0);
    int *gpus = 0;
    int gpu = 0;
    int ngpus = 0;
    if(gpu_list){
        printf("%s\n", gpu_list);
@@ -1122,6 +1134,10 @@
            gpus[i] = atoi(gpu_list);
            gpu_list = strchr(gpu_list, ',')+1;
        }
    } else {
        gpu = gpu_index;
        gpus = &gpu;
        ngpus = 1;
    }
    int cam_index = find_int_arg(argc, argv, "-c", 0);
@@ -1135,8 +1151,7 @@
    int layer = layer_s ? atoi(layer_s) : -1;
    if(0==strcmp(argv[2], "predict")) predict_classifier(data, cfg, weights, filename, top);
    else if(0==strcmp(argv[2], "try")) try_classifier(data, cfg, weights, filename, atoi(layer_s));
    else if(0==strcmp(argv[2], "train")) train_classifier(data, cfg, weights, clear);
    else if(0==strcmp(argv[2], "trainm")) train_classifier_multi(data, cfg, weights, gpus, ngpus, clear);
    else if(0==strcmp(argv[2], "train")) train_classifier(data, cfg, weights, gpus, ngpus, clear);
    else if(0==strcmp(argv[2], "demo")) demo_classifier(data, cfg, weights, cam_index, filename);
    else if(0==strcmp(argv[2], "gun")) gun_classifier(data, cfg, weights, cam_index, filename);
    else if(0==strcmp(argv[2], "threat")) threat_classifier(data, cfg, weights, cam_index, filename);
src/darknet.c
@@ -30,20 +30,6 @@
extern void run_art(int argc, char **argv);
extern void run_super(int argc, char **argv);
void change_rate(char *filename, float scale, float add)
{
    // Ready for some weird shit??
    FILE *fp = fopen(filename, "r+b");
    if(!fp) file_error(filename);
    float rate = 0;
    fread(&rate, sizeof(float), 1, fp);
    printf("Scaling learning rate from %f to %f\n", rate, rate*scale+add);
    rate = rate*scale + add;
    fseek(fp, 0, SEEK_SET);
    fwrite(&rate, sizeof(float), 1, fp);
    fclose(fp);
}
void average(int argc, char *argv[])
{
    char *cfgfile = argv[2];
@@ -67,6 +53,11 @@
                int num = l.n*l.c*l.size*l.size;
                axpy_cpu(l.n, 1, l.biases, 1, out.biases, 1);
                axpy_cpu(num, 1, l.weights, 1, out.weights, 1);
                if(l.batch_normalize){
                    axpy_cpu(l.n, 1, l.scales, 1, out.scales, 1);
                    axpy_cpu(l.n, 1, l.rolling_mean, 1, out.rolling_mean, 1);
                    axpy_cpu(l.n, 1, l.rolling_variance, 1, out.rolling_variance, 1);
                }
            }
            if(l.type == CONNECTED){
                axpy_cpu(l.outputs, 1, l.biases, 1, out.biases, 1);
@@ -81,6 +72,11 @@
            int num = l.n*l.c*l.size*l.size;
            scal_cpu(l.n, 1./n, l.biases, 1);
            scal_cpu(num, 1./n, l.weights, 1);
                if(l.batch_normalize){
                    scal_cpu(l.n, 1./n, l.scales, 1);
                    scal_cpu(l.n, 1./n, l.rolling_mean, 1);
                    scal_cpu(l.n, 1./n, l.rolling_variance, 1);
                }
        }
        if(l.type == CONNECTED){
            scal_cpu(l.outputs, 1./n, l.biases, 1);
@@ -125,6 +121,31 @@
    printf("Floating Point Operations: %.2f Bn\n", (float)ops/1000000000.);
}
void oneoff(char *cfgfile, char *weightfile, char *outfile)
{
    gpu_index = -1;
    network net = parse_network_cfg(cfgfile);
    int oldn = net.layers[net.n - 2].n;
    int c = net.layers[net.n - 2].c;
    net.layers[net.n - 2].n = 7879;
    net.layers[net.n - 2].biases += 5;
    net.layers[net.n - 2].weights += 5*c;
    if(weightfile){
        load_weights(&net, weightfile);
    }
    net.layers[net.n - 2].biases -= 5;
    net.layers[net.n - 2].weights -= 5*c;
    net.layers[net.n - 2].n = oldn;
    printf("%d\n", oldn);
    layer l = net.layers[net.n - 2];
    copy_cpu(l.n/3, l.biases, 1, l.biases +   l.n/3, 1);
    copy_cpu(l.n/3, l.biases, 1, l.biases + 2*l.n/3, 1);
    copy_cpu(l.n/3*l.c, l.weights, 1, l.weights +   l.n/3*l.c, 1);
    copy_cpu(l.n/3*l.c, l.weights, 1, l.weights + 2*l.n/3*l.c, 1);
    *net.seen = 0;
    save_weights(net, outfile);
}
void partial(char *cfgfile, char *weightfile, char *outfile, int max)
{
    gpu_index = -1;
@@ -387,8 +408,6 @@
        run_captcha(argc, argv);
    } else if (0 == strcmp(argv[1], "nightmare")){
        run_nightmare(argc, argv);
    } else if (0 == strcmp(argv[1], "change")){
        change_rate(argv[2], atof(argv[3]), (argc > 4) ? atof(argv[4]) : 0);
    } else if (0 == strcmp(argv[1], "rgbgr")){
        rgbgr_net(argv[2], argv[3], argv[4]);
    } else if (0 == strcmp(argv[1], "reset")){
@@ -404,7 +423,9 @@
    } else if (0 == strcmp(argv[1], "ops")){
        operations(argv[2]);
    } else if (0 == strcmp(argv[1], "speed")){
        speed(argv[2], (argc > 3) ? atoi(argv[3]) : 0);
        speed(argv[2], (argc > 3 && argv[3]) ? atoi(argv[3]) : 0);
    } else if (0 == strcmp(argv[1], "oneoff")){
        oneoff(argv[2], argv[3], argv[4]);
    } else if (0 == strcmp(argv[1], "partial")){
        partial(argv[2], argv[3], argv[4], atoi(argv[5]));
    } else if (0 == strcmp(argv[1], "average")){
src/detector.c
@@ -10,8 +10,9 @@
#ifdef OPENCV
#include "opencv2/highgui/highgui_c.h"
#endif
static int coco_ids[] = {1,2,3,4,5,6,7,8,9,10,11,13,14,15,16,17,18,19,20,21,22,23,24,25,27,28,31,32,33,34,35,36,37,38,39,40,41,42,43,44,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,67,70,72,73,74,75,76,77,78,79,80,81,82,84,85,86,87,88,89,90};
void train_detector(char *datacfg, char *cfgfile, char *weightfile, int clear)
void train_detector(char *datacfg, char *cfgfile, char *weightfile, int *gpus, int ngpus, int clear)
{
    list *options = read_data_cfg(datacfg);
    char *train_images = option_find_str(options, "train", "data/train.list");
@@ -21,14 +22,28 @@
    char *base = basecfg(cfgfile);
    printf("%s\n", base);
    float avg_loss = -1;
    network net = parse_network_cfg(cfgfile);
    if(weightfile){
        load_weights(&net, weightfile);
    network *nets = calloc(ngpus, sizeof(network));
    srand(time(0));
    int seed = rand();
    int i;
    for(i = 0; i < ngpus; ++i){
        srand(seed);
#ifdef GPU
        cuda_set_device(gpus[i]);
#endif
        nets[i] = parse_network_cfg(cfgfile);
        if(weightfile){
            load_weights(&nets[i], weightfile);
        }
        if(clear) *nets[i].seen = 0;
        nets[i].learning_rate *= ngpus;
    }
    if(clear) *net.seen = 0;
    srand(time(0));
    network net = nets[0];
    int imgs = net.batch * net.subdivisions * ngpus;
    printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
    int imgs = net.batch*net.subdivisions;
    int i = *net.seen/imgs;
    data train, buffer;
    layer l = net.layers[net.n - 1];
@@ -62,37 +77,46 @@
    clock_t time;
    //while(i*imgs < N*120){
    while(get_current_batch(net) < net.max_batches){
        i += 1;
        time=clock();
        pthread_join(load_thread, 0);
        train = buffer;
        load_thread = load_data(args);
/*
        int k;
        for(k = 0; k < l.max_boxes; ++k){
            box b = float_to_box(train.y.vals[10] + 1 + k*5);
            if(!b.x) break;
            printf("loaded: %f %f %f %f\n", b.x, b.y, b.w, b.h);
        }
        image im = float_to_image(448, 448, 3, train.X.vals[10]);
        int k;
        for(k = 0; k < l.max_boxes; ++k){
            box b = float_to_box(train.y.vals[10] + 1 + k*5);
            printf("%d %d %d %d\n", truth.x, truth.y, truth.w, truth.h);
            draw_bbox(im, b, 8, 1,0,0);
        }
        save_image(im, "truth11");
*/
        /*
           int k;
           for(k = 0; k < l.max_boxes; ++k){
           box b = float_to_box(train.y.vals[10] + 1 + k*5);
           if(!b.x) break;
           printf("loaded: %f %f %f %f\n", b.x, b.y, b.w, b.h);
           }
           image im = float_to_image(448, 448, 3, train.X.vals[10]);
           int k;
           for(k = 0; k < l.max_boxes; ++k){
           box b = float_to_box(train.y.vals[10] + 1 + k*5);
           printf("%d %d %d %d\n", truth.x, truth.y, truth.w, truth.h);
           draw_bbox(im, b, 8, 1,0,0);
           }
           save_image(im, "truth11");
         */
        printf("Loaded: %lf seconds\n", sec(clock()-time));
        time=clock();
        float loss = train_network(net, train);
        float loss = 0;
#ifdef GPU
        if(ngpus == 1){
            loss = train_network(net, train);
        } else {
            loss = train_networks(nets, ngpus, train, 4);
        }
#else
        loss = train_network(net, train);
#endif
        if (avg_loss < 0) avg_loss = loss;
        avg_loss = avg_loss*.9 + loss*.1;
        printf("%d: %f, %f avg, %f rate, %lf seconds, %d images\n", i, loss, avg_loss, get_current_rate(net), sec(clock()-time), i*imgs);
        i = get_current_batch(net);
        printf("%d: %f, %f avg, %f rate, %lf seconds, %d images\n", get_current_batch(net), loss, avg_loss, get_current_rate(net), sec(clock()-time), i*imgs);
        if(i%1000==0 || (i < 1000 && i%100 == 0)){
            char buff[256];
            sprintf(buff, "%s/%s_%d.weights", backup_directory, base, i);
@@ -105,6 +129,39 @@
    save_weights(net, buff);
}
static int get_coco_image_id(char *filename)
{
    char *p = strrchr(filename, '_');
    return atoi(p+1);
}
static void print_cocos(FILE *fp, char *image_path, box *boxes, float **probs, int num_boxes, int classes, int w, int h)
{
    int i, j;
    int image_id = get_coco_image_id(image_path);
    for(i = 0; i < num_boxes; ++i){
        float xmin = boxes[i].x - boxes[i].w/2.;
        float xmax = boxes[i].x + boxes[i].w/2.;
        float ymin = boxes[i].y - boxes[i].h/2.;
        float ymax = boxes[i].y + boxes[i].h/2.;
        if (xmin < 0) xmin = 0;
        if (ymin < 0) ymin = 0;
        if (xmax > w) xmax = w;
        if (ymax > h) ymax = h;
        float bx = xmin;
        float by = ymin;
        float bw = xmax - xmin;
        float bh = ymax - ymin;
        for(j = 0; j < classes; ++j){
            if (probs[i][j]) fprintf(fp, "{\"image_id\":%d, \"category_id\":%d, \"bbox\":[%f, %f, %f, %f], \"score\":%f},\n", image_id, coco_ids[j], bx, by, bw, bh, probs[i][j]);
        }
    }
}
void print_detector_detections(FILE **fps, char *id, box *boxes, float **probs, int total, int classes, int w, int h)
{
    int i, j;
@@ -131,8 +188,19 @@
    list *options = read_data_cfg(datacfg);
    char *valid_images = option_find_str(options, "valid", "data/train.list");
    char *name_list = option_find_str(options, "names", "data/names.list");
    char *prefix = option_find_str(options, "results", "results");
    char **names = get_labels(name_list);
    char buff[1024];
    int coco = option_find_int_quiet(options, "coco", 0);
    FILE *coco_fp = 0;
    if(coco){
        snprintf(buff, 1024, "%s/coco_results.json", prefix);
        coco_fp = fopen(buff, "w");
        fprintf(coco_fp, "[\n");
    }
    network net = parse_network_cfg(cfgfile);
    if(weightfile){
        load_weights(&net, weightfile);
@@ -141,7 +209,7 @@
    fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
    srand(time(0));
    char *base = "results/comp4_det_test_";
    char *base = "comp4_det_test_";
    list *plist = get_paths(valid_images);
    char **paths = (char **)list_to_array(plist);
@@ -151,8 +219,7 @@
    int j;
    FILE **fps = calloc(classes, sizeof(FILE *));
    for(j = 0; j < classes; ++j){
        char buff[1024];
        snprintf(buff, 1024, "%s%s.txt", base, names[j]);
        snprintf(buff, 1024, "%s/%s%s.txt", prefix, base, names[j]);
        fps[j] = fopen(buff, "w");
    }
    box *boxes = calloc(l.w*l.h*l.n, sizeof(box));
@@ -207,7 +274,11 @@
            int h = val[t].h;
            get_region_boxes(l, w, h, thresh, probs, boxes, 0);
            if (nms) do_nms_sort(boxes, probs, l.w*l.h*l.n, classes, nms);
            print_detector_detections(fps, id, boxes, probs, l.w*l.h*l.n, classes, w, h);
            if(coco_fp){
                print_cocos(coco_fp, path, boxes, probs, l.w*l.h*l.n, classes, w, h);
            }else{
                print_detector_detections(fps, id, boxes, probs, l.w*l.h*l.n, classes, w, h);
            }
            free(id);
            free_image(val[t]);
            free_image(val_resized[t]);
@@ -216,6 +287,11 @@
    for(j = 0; j < classes; ++j){
        fclose(fps[j]);
    }
    if(coco_fp){
        fseek(coco_fp, -2, SEEK_CUR);
        fprintf(coco_fp, "\n]\n");
        fclose(coco_fp);
    }
    fprintf(stderr, "Total Detection Time: %f Seconds\n", (double)(time(0) - start));
}
@@ -300,8 +376,8 @@
void test_detector(char *datacfg, char *cfgfile, char *weightfile, char *filename, float thresh)
{
    list *options = read_data_cfg(datacfg);
        char *name_list = option_find_str(options, "names", "data/names.list");
        char **names = get_labels(name_list);
    char *name_list = option_find_str(options, "names", "data/names.list");
    char **names = get_labels(name_list);
    image **alphabet = load_alphabet();
    network net = parse_network_cfg(cfgfile);
@@ -361,6 +437,29 @@
        fprintf(stderr, "usage: %s %s [train/test/valid] [cfg] [weights (optional)]\n", argv[0], argv[1]);
        return;
    }
    char *gpu_list = find_char_arg(argc, argv, "-gpus", 0);
    int *gpus = 0;
    int gpu = 0;
    int ngpus = 0;
    if(gpu_list){
        printf("%s\n", gpu_list);
        int len = strlen(gpu_list);
        ngpus = 1;
        int i;
        for(i = 0; i < len; ++i){
            if (gpu_list[i] == ',') ++ngpus;
        }
        gpus = calloc(ngpus, sizeof(int));
        for(i = 0; i < ngpus; ++i){
            gpus[i] = atoi(gpu_list);
            gpu_list = strchr(gpu_list, ',')+1;
        }
    } else {
        gpu = gpu_index;
        gpus = &gpu;
        ngpus = 1;
    }
    int clear = find_arg(argc, argv, "-clear");
    char *datacfg = argv[3];
@@ -368,7 +467,7 @@
    char *weights = (argc > 5) ? argv[5] : 0;
    char *filename = (argc > 6) ? argv[6]: 0;
    if(0==strcmp(argv[2], "test")) test_detector(datacfg, cfg, weights, filename, thresh);
    else if(0==strcmp(argv[2], "train")) train_detector(datacfg, cfg, weights, clear);
    else if(0==strcmp(argv[2], "train")) train_detector(datacfg, cfg, weights, gpus, ngpus, clear);
    else if(0==strcmp(argv[2], "valid")) validate_detector(datacfg, cfg, weights);
    else if(0==strcmp(argv[2], "recall")) validate_detector_recall(cfg, weights);
    else if(0==strcmp(argv[2], "demo")) {
src/region_layer.c
@@ -48,19 +48,18 @@
    return l;
}
#define LOG 1
#define DOABS 1
box get_region_box(float *x, float *biases, int n, int index, int i, int j, int w, int h)
{
    box b;
    b.x = (i + .5)/w + x[index + 0] * biases[2*n];
    b.y = (j + .5)/h + x[index + 1] * biases[2*n + 1];
    if(LOG){
        b.x = (i + logistic_activate(x[index + 0])) / w;
        b.y = (j + logistic_activate(x[index + 1])) / h;
    }
    b.x = (i + logistic_activate(x[index + 0])) / w;
    b.y = (j + logistic_activate(x[index + 1])) / h;
    b.w = exp(x[index + 2]) * biases[2*n];
    b.h = exp(x[index + 3]) * biases[2*n+1];
    if(DOABS){
        b.w = exp(x[index + 2]) * biases[2*n]   / w;
        b.h = exp(x[index + 3]) * biases[2*n+1] / h;
    }
    return b;
}
@@ -69,21 +68,17 @@
    box pred = get_region_box(x, biases, n, index, i, j, w, h);
    float iou = box_iou(pred, truth);
    float tx = (truth.x - (i + .5)/w) / biases[2*n];
    float ty = (truth.y - (j + .5)/h) / biases[2*n + 1];
    if(LOG){
        tx = (truth.x*w - i);
        ty = (truth.y*h - j);
    }
    float tx = (truth.x*w - i);
    float ty = (truth.y*h - j);
    float tw = log(truth.w / biases[2*n]);
    float th = log(truth.h / biases[2*n + 1]);
    delta[index + 0] = scale * (tx - x[index + 0]);
    delta[index + 1] = scale * (ty - x[index + 1]);
    if(LOG){
        delta[index + 0] = scale * (tx - logistic_activate(x[index + 0])) * logistic_gradient(logistic_activate(x[index + 0]));
        delta[index + 1] = scale * (ty - logistic_activate(x[index + 1])) * logistic_gradient(logistic_activate(x[index + 1]));
    if(DOABS){
        tw = log(truth.w*w / biases[2*n]);
        th = log(truth.h*h / biases[2*n + 1]);
    }
    delta[index + 0] = scale * (tx - logistic_activate(x[index + 0])) * logistic_gradient(logistic_activate(x[index + 0]));
    delta[index + 1] = scale * (ty - logistic_activate(x[index + 1])) * logistic_gradient(logistic_activate(x[index + 1]));
    delta[index + 2] = scale * (tw - x[index + 2]);
    delta[index + 3] = scale * (th - x[index + 3]);
    return iou;
@@ -135,9 +130,33 @@
        for(i = 0; i < l.h*l.w*l.n; ++i){
            int index = size*i + b*l.outputs;
            l.output[index + 4] = logistic_activate(l.output[index + 4]);
            if(l.softmax_tree){
        }
    }
    if (l.softmax_tree){
#ifdef GPU
        cuda_push_array(l.output_gpu, l.output, l.batch*l.outputs);
        int i;
        int count = 5;
        for (i = 0; i < l.softmax_tree->groups; ++i) {
            int group_size = l.softmax_tree->group_size[i];
            softmax_gpu(l.output_gpu+count, group_size, l.classes + 5, l.w*l.h*l.n*l.batch, 1, l.output_gpu + count);
            count += group_size;
        }
        cuda_pull_array(l.output_gpu, l.output, l.batch*l.outputs);
#else
        for (b = 0; b < l.batch; ++b){
            for(i = 0; i < l.h*l.w*l.n; ++i){
                int index = size*i + b*l.outputs;
                softmax_tree(l.output + index + 5, 1, 0, 1, l.softmax_tree, l.output + index + 5);
            } else if(l.softmax){
            }
        }
#endif
    } else if (l.softmax){
        for (b = 0; b < l.batch; ++b){
            for(i = 0; i < l.h*l.w*l.n; ++i){
                int index = size*i + b*l.outputs;
                softmax(l.output + index + 5, l.classes, 1, l.output + index + 5);
            }
        }
@@ -188,11 +207,11 @@
                        truth.y = (j + .5)/l.h;
                        truth.w = l.biases[2*n];
                        truth.h = l.biases[2*n+1];
                        if(DOABS){
                            truth.w = l.biases[2*n]/l.w;
                            truth.h = l.biases[2*n+1]/l.h;
                        }
                        delta_region_box(truth, l.output, l.biases, n, index, i, j, l.w, l.h, l.delta, .01);
                        //l.delta[index + 0] = .1 * (0 - l.output[index + 0]);
                        //l.delta[index + 1] = .1 * (0 - l.output[index + 1]);
                        //l.delta[index + 2] = .1 * (0 - l.output[index + 2]);
                        //l.delta[index + 3] = .1 * (0 - l.output[index + 3]);
                    }
                }
            }
@@ -217,6 +236,10 @@
                if(l.bias_match){
                    pred.w = l.biases[2*n];
                    pred.h = l.biases[2*n+1];
                    if(DOABS){
                        pred.w = l.biases[2*n]/l.w;
                        pred.h = l.biases[2*n+1]/l.h;
                    }
                }
                //printf("pred: (%f, %f) %f x %f\n", pred.x, pred.y, pred.w, pred.h);
                pred.x = 0;