Joseph Redmon
2015-09-16 c53e03348c65462bcba33f6352087dd6b9268e8f
yolo working w/ regions
11 files modified
1 files renamed
330 ■■■■ changed files
Makefile 2 ●●● patch | view | raw | blame | history
cfg/darknet.cfg 10 ●●●●● patch | view | raw | blame | history
src/coco.c 14 ●●●●● patch | view | raw | blame | history
src/compare.c 4 ●●●● patch | view | raw | blame | history
src/darknet.c 6 ●●●● patch | view | raw | blame | history
src/data.c 4 ●●● patch | view | raw | blame | history
src/layer.h 3 ●●●●● patch | view | raw | blame | history
src/network.c 2 ●●● patch | view | raw | blame | history
src/network_kernels.cu 1 ●●●● patch | view | raw | blame | history
src/parser.c 4 ●●●● patch | view | raw | blame | history
src/region_layer.c 139 ●●●● patch | view | raw | blame | history
src/swag.c 141 ●●●●● patch | view | raw | blame | history
Makefile
@@ -34,7 +34,7 @@
LDFLAGS+= -L/usr/local/cuda/lib64 -lcuda -lcudart -lcublas -lcurand
endif
OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o darknet.o detection_layer.o imagenet.o captcha.o route_layer.o writing.o box.o nightmare.o normalization_layer.o avgpool_layer.o coco.o dice.o yolo.o region_layer.o layer.o compare.o yoloplus.o
OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o darknet.o detection_layer.o imagenet.o captcha.o route_layer.o writing.o box.o nightmare.o normalization_layer.o avgpool_layer.o coco.o dice.o yolo.o region_layer.o layer.o compare.o swag.o
ifeq ($(GPU), 1) 
OBJ+=convolutional_kernels.o deconvolutional_kernels.o activation_kernels.o im2col_kernels.o col2im_kernels.o blas_kernels.o crop_layer_kernels.o dropout_layer_kernels.o maxpool_layer_kernels.o softmax_layer_kernels.o network_kernels.o avgpool_layer_kernels.o
endif
cfg/darknet.cfg
@@ -1,15 +1,17 @@
[net]
batch=128
batch=256
subdivisions=1
height=256
width=256
channels=3
momentum=0.9
decay=0.0005
learning_rate=0.01
policy=poly
power=.5
max_batches=600000
policy=step
scale=.1
step=100000
max_batches=400000
[crop]
crop_height=224
src/coco.c
@@ -111,20 +111,6 @@
        avg_loss = avg_loss*.9 + loss*.1;
        printf("%d: %f, %f avg, %lf seconds, %d images\n", i, loss, avg_loss, sec(clock()-time), i*imgs);
        if((i-1)*imgs <= N && i*imgs > N){
            fprintf(stderr, "First stage done\n");
            net.learning_rate *= 10;
            char buff[256];
            sprintf(buff, "%s/%s_first_stage.weights", backup_directory, base);
            save_weights(net, buff);
        }
        if((i-1)*imgs <= 80*N && i*imgs > N*80){
            fprintf(stderr, "Second stage done.\n");
            char buff[256];
            sprintf(buff, "%s/%s_second_stage.weights", backup_directory, base);
            save_weights(net, buff);
        }
        if(i%1000==0){
            char buff[256];
            sprintf(buff, "%s/%s_%d.weights", backup_directory, base, i);
src/compare.c
@@ -175,8 +175,8 @@
    image im1 = load_image_color(box1.filename, net.w, net.h);
    image im2 = load_image_color(box2.filename, net.w, net.h);
    float *X  = calloc(net.w*net.h*net.c, sizeof(float));
    memcpy(X,                   im1.data, im1.w*im1.h*im1.c);
    memcpy(X+im1.w*im1.h*im1.c, im2.data, im2.w*im2.h*im2.c);
    memcpy(X,                   im1.data, im1.w*im1.h*im1.c*sizeof(float));
    memcpy(X+im1.w*im1.h*im1.c, im2.data, im2.w*im2.h*im2.c*sizeof(float));
    float *predictions = network_predict(net, X);
    
    free_image(im1);
src/darknet.c
@@ -13,7 +13,7 @@
extern void run_imagenet(int argc, char **argv);
extern void run_yolo(int argc, char **argv);
extern void run_yoloplus(int argc, char **argv);
extern void run_swag(int argc, char **argv);
extern void run_coco(int argc, char **argv);
extern void run_writing(int argc, char **argv);
extern void run_captcha(int argc, char **argv);
@@ -179,8 +179,8 @@
        average(argc, argv);
    } else if (0 == strcmp(argv[1], "yolo")){
        run_yolo(argc, argv);
    } else if (0 == strcmp(argv[1], "yoloplus")){
        run_yoloplus(argc, argv);
    } else if (0 == strcmp(argv[1], "swag")){
        run_swag(argc, argv);
    } else if (0 == strcmp(argv[1], "coco")){
        run_coco(argc, argv);
    } else if (0 == strcmp(argv[1], "compare")){
src/data.c
@@ -176,8 +176,10 @@
        int index = (col+row*num_boxes)*(5+classes);
        if (truth[index]) continue;
        truth[index++] = 1;
        if (classes) truth[index+id] = 1;
        if (id < classes) truth[index+id] = 1;
        index += classes;
        truth[index++] = x;
        truth[index++] = y;
        truth[index++] = w;
src/layer.h
@@ -30,6 +30,7 @@
    int batch;
    int inputs;
    int outputs;
    int truths;
    int h,w,c;
    int out_h, out_w, out_c;
    int n;
@@ -40,10 +41,12 @@
    int pad;
    int crop_width;
    int crop_height;
    int sqrt;
    int flip;
    float angle;
    float saturation;
    float exposure;
    int softmax;
    int classes;
    int coords;
    int background;
src/network.c
@@ -48,7 +48,7 @@
        case POLY:
            return net.learning_rate * pow(1 - (float)batch_num / net.max_batches, net.power);
        case SIG:
            return net.learning_rate * (1/(1+exp(net.gamma*(batch_num - net.step))));
            return net.learning_rate * (1./(1.+exp(net.gamma*(batch_num - net.step))));
        default:
            fprintf(stderr, "Policy is weird!\n");
            return net.learning_rate;
src/network_kernels.cu
@@ -134,6 +134,7 @@
    network_state state;
    int x_size = get_network_input_size(net)*net.batch;
    int y_size = get_network_output_size(net)*net.batch;
    if(net.layers[net.n-1].type == REGION) y_size = net.layers[net.n-1].truths*net.batch;
    if(!*net.input_gpu){
        *net.input_gpu = cuda_make_array(x, x_size);
        *net.truth_gpu = cuda_make_array(y, y_size);
src/parser.c
@@ -182,6 +182,10 @@
    int num = option_find_int(options, "num", 1);
    int side = option_find_int(options, "side", 7);
    region_layer layer = make_region_layer(params.batch, params.inputs, num, side, classes, coords, rescore);
    int softmax = option_find_int(options, "softmax", 0);
    int sqrt = option_find_int(options, "sqrt", 0);
    layer.softmax = softmax;
    layer.sqrt = sqrt;
    return layer;
}
src/region_layer.c
@@ -22,15 +22,15 @@
    l.coords = coords;
    l.rescore = rescore;
    l.side = side;
    assert(side*side*l.coords*l.n == inputs);
    assert(side*side*((1 + l.coords)*l.n + l.classes) == inputs);
    l.cost = calloc(1, sizeof(float));
    int outputs = l.n*5*side*side;
    l.outputs = outputs;
    l.output = calloc(batch*outputs, sizeof(float));
    l.delta = calloc(batch*inputs, sizeof(float));
    l.outputs = l.inputs;
    l.truths = l.side*l.side*(1+l.coords+l.classes);
    l.output = calloc(batch*l.outputs, sizeof(float));
    l.delta = calloc(batch*l.outputs, sizeof(float));
    #ifdef GPU
    l.output_gpu = cuda_make_array(l.output, batch*outputs);
    l.delta_gpu = cuda_make_array(l.delta, batch*inputs);
    l.output_gpu = cuda_make_array(l.output, batch*l.outputs);
    l.delta_gpu = cuda_make_array(l.delta, batch*l.outputs);
#endif
    fprintf(stderr, "Region Layer\n");
@@ -43,64 +43,69 @@
{
    int locations = l.side*l.side;
    int i,j;
    memcpy(l.output, state.input, l.outputs*l.batch*sizeof(float));
    for(i = 0; i < l.batch*locations; ++i){
        for(j = 0; j < l.n; ++j){
            int in_index =  i*l.n*l.coords + j*l.coords;
            int out_index = i*l.n*5 + j*5;
            float prob =  state.input[in_index+0];
            float x =     state.input[in_index+1];
            float y =     state.input[in_index+2];
            float w =     state.input[in_index+3];
            float h =     state.input[in_index+4];
            /*
            float min_w = state.input[in_index+5];
            float max_w = state.input[in_index+6];
            float min_h = state.input[in_index+7];
            float max_h = state.input[in_index+8];
            */
            l.output[out_index+0] = prob;
            l.output[out_index+1] = x;
            l.output[out_index+2] = y;
            l.output[out_index+3] = w;
            l.output[out_index+4] = h;
        int index = i*((1+l.coords)*l.n + l.classes);
        if(l.softmax){
            activate_array(l.output + index, l.n*(1+l.coords), LOGISTIC);
            int offset = l.n*(1+l.coords);
            softmax_array(l.output + index + offset, l.classes,
                    l.output + index + offset);
        }
    }
    if(state.train){
        float avg_iou = 0;
        float avg_cat = 0;
        float avg_obj = 0;
        float avg_anyobj = 0;
        int count = 0;
        *(l.cost) = 0;
        int size = l.inputs * l.batch;
        memset(l.delta, 0, size * sizeof(float));
        for (i = 0; i < l.batch*locations; ++i) {
            int index = i*((1+l.coords)*l.n + l.classes);
            for(j = 0; j < l.n; ++j){
                int in_index = i*l.n*l.coords + j*l.coords;
                l.delta[in_index+0] = .1*(0-state.input[in_index+0]);
                int prob_index = index + j*(1 + l.coords);
                l.delta[prob_index] = (1./l.n)*(0-l.output[prob_index]);
                if(l.softmax){
                    l.delta[prob_index] = 1./(l.n*l.side)*(0-l.output[prob_index]);
                }
                *(l.cost) += (1./l.n)*pow(l.output[prob_index], 2);
                //printf("%f\n", l.output[prob_index]);
                avg_anyobj += l.output[prob_index];
            }
            int truth_index = i*5;
            int truth_index = i*(1 + l.coords + l.classes);
            int best_index = -1;
            float best_iou = 0;
            float best_rmse = 4;
            int bg = !state.truth[truth_index];
            if(bg) continue;
            if(bg) {
                continue;
            }
            box truth = {state.truth[truth_index+1], state.truth[truth_index+2], state.truth[truth_index+3], state.truth[truth_index+4]};
            int class_index = index + l.n*(1+l.coords);
            for(j = 0; j < l.classes; ++j) {
                l.delta[class_index+j] = state.truth[truth_index+1+j] - l.output[class_index+j];
                *(l.cost) += pow(state.truth[truth_index+1+j] - l.output[class_index+j], 2);
                if(state.truth[truth_index + 1 + j]) avg_cat += l.output[class_index+j];
            }
            truth_index += l.classes + 1;
            box truth = {state.truth[truth_index+0], state.truth[truth_index+1], state.truth[truth_index+2], state.truth[truth_index+3]};
            truth.x /= l.side;
            truth.y /= l.side;
            for(j = 0; j < l.n; ++j){
                int out_index = i*l.n*5 + j*5;
                int out_index = index + j*(1+l.coords);
                box out = {l.output[out_index+1], l.output[out_index+2], l.output[out_index+3], l.output[out_index+4]};
                //printf("\n%f %f %f %f %f\n", l.output[out_index+0], out.x, out.y, out.w, out.h);
                out.x /= l.side;
                out.y /= l.side;
                if (l.sqrt){
                    out.w = out.w*out.w;
                    out.h = out.h*out.h;
                }
                float iou  = box_iou(out, truth);
                float rmse = box_rmse(out, truth);
@@ -116,46 +121,41 @@
                    }
                }
            }
            printf("%d", best_index);
            //int out_index = i*l.n*5 + best_index*5;
            //box out = {l.output[out_index+1], l.output[out_index+2], l.output[out_index+3], l.output[out_index+4]};
            int in_index =  i*l.n*l.coords + best_index*l.coords;
            //printf("%d", best_index);
            int in_index = index + best_index*(1+l.coords);
            *(l.cost) -= pow(l.output[in_index], 2);
            *(l.cost) += pow(1-l.output[in_index], 2);
            avg_obj += l.output[in_index];
            l.delta[in_index+0] = (1.-l.output[in_index]);
            if(l.softmax){
                l.delta[in_index+0] = 5*(1.-l.output[in_index]);
            }
            //printf("%f\n", l.output[in_index]);
            l.delta[in_index+0] = (1-state.input[in_index+0]);
            l.delta[in_index+1] = state.truth[truth_index+1] - state.input[in_index+1];
            l.delta[in_index+2] = state.truth[truth_index+2] - state.input[in_index+2];
            l.delta[in_index+3] = state.truth[truth_index+3] - state.input[in_index+3];
            l.delta[in_index+4] = state.truth[truth_index+4] - state.input[in_index+4];
            /*
            l.delta[in_index+5] = 0 - state.input[in_index+5];
            l.delta[in_index+6] = 1 - state.input[in_index+6];
            l.delta[in_index+7] = 0 - state.input[in_index+7];
            l.delta[in_index+8] = 1 - state.input[in_index+8];
            */
            l.delta[in_index+1] = 5*(state.truth[truth_index+0] - l.output[in_index+1]);
            l.delta[in_index+2] = 5*(state.truth[truth_index+1] - l.output[in_index+2]);
            if(l.sqrt){
                l.delta[in_index+3] = 5*(sqrt(state.truth[truth_index+2]) - l.output[in_index+3]);
                l.delta[in_index+4] = 5*(sqrt(state.truth[truth_index+3]) - l.output[in_index+4]);
            }else{
                l.delta[in_index+3] = 5*(state.truth[truth_index+2] - l.output[in_index+3]);
                l.delta[in_index+4] = 5*(state.truth[truth_index+3] - l.output[in_index+4]);
            }
            /*
            float x =     state.input[in_index+1];
            float y =     state.input[in_index+2];
            float w =     state.input[in_index+3];
            float h =     state.input[in_index+4];
            float min_w = state.input[in_index+5];
            float max_w = state.input[in_index+6];
            float min_h = state.input[in_index+7];
            float max_h = state.input[in_index+8];
            */
            *(l.cost) += pow(1-best_iou, 2);
            avg_iou += best_iou;
            ++count;
            if(l.softmax){
                gradient_array(l.output + index, l.n*(1+l.coords), LOGISTIC, l.delta + index);
        }
        printf("\nAvg IOU: %f %d\n", avg_iou/count, count);
        }
        printf("Avg IOU: %f, Avg Cat Pred: %f, Avg Obj: %f, Avg Any: %f, count: %d\n", avg_iou/count, avg_cat/count, avg_obj/count, avg_anyobj/(l.batch*locations*l.n), count);
    }
}
void backward_region_layer(const region_layer l, network_state state)
{
    axpy_cpu(l.batch*l.inputs, 1, l.delta, 1, state.delta, 1);
    //copy_cpu(l.batch*l.inputs, l.delta, 1, state.delta, 1);
}
#ifdef GPU
@@ -165,8 +165,9 @@
    float *in_cpu = calloc(l.batch*l.inputs, sizeof(float));
    float *truth_cpu = 0;
    if(state.truth){
        truth_cpu = calloc(l.batch*l.outputs, sizeof(float));
        cuda_pull_array(state.truth, truth_cpu, l.batch*l.outputs);
        int num_truth = l.batch*l.side*l.side*(1+l.coords+l.classes);
        truth_cpu = calloc(num_truth, sizeof(float));
        cuda_pull_array(state.truth, truth_cpu, num_truth);
    }
    cuda_pull_array(state.input, in_cpu, l.batch*l.inputs);
    network_state cpu_state;
src/swag.c
File was renamed from src/yoloplus.c
@@ -11,7 +11,7 @@
char *voc_names[] = {"aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"};
void draw_yoloplus(image im, float *box, int side, int objectness, char *label, float thresh)
void draw_swag(image im, float *box, int side, int objectness, char *label, float thresh)
{
    int classes = 20;
    int elems = 4+classes+objectness;
@@ -52,7 +52,7 @@
    show_image(im, label);
}
void train_yoloplus(char *cfgfile, char *weightfile)
void train_swag(char *cfgfile, char *weightfile)
{
    char *train_images = "/home/pjreddie/data/voc/test/train.txt";
    char *backup_directory = "/home/pjreddie/backup/";
@@ -65,23 +65,20 @@
    if(weightfile){
        load_weights(&net, weightfile);
    }
    detection_layer layer = get_network_detection_layer(net);
    int imgs = 128;
    printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
    int imgs = net.batch*net.subdivisions;
    int i = *net.seen/imgs;
    char **paths;
    list *plist = get_paths(train_images);
    int N = plist->size;
    paths = (char **)list_to_array(plist);
    if(i*imgs > N*120){
        net.layers[net.n-1].rescore = 1;
    }
    data train, buffer;
    int classes = layer.classes;
    int background = layer.objectness;
    int side = sqrt(get_detection_layer_locations(layer));
    layer l = net.layers[net.n - 1];
    int side = l.side;
    int classes = l.classes;
    list *plist = get_paths(train_images);
    int N = plist->size;
    char **paths = (char **)list_to_array(plist);
    load_args args = {0};
    args.w = net.w;
@@ -91,12 +88,12 @@
    args.m = plist->size;
    args.classes = classes;
    args.num_boxes = side;
    args.background = background;
    args.d = &buffer;
    args.type = DETECTION_DATA;
    args.type = REGION_DATA;
    pthread_t load_thread = load_data_in_thread(args);
    clock_t time;
    //while(i*imgs < N*120){
    while(get_current_batch(net) < net.max_batches){
        i += 1;
        time=clock();
@@ -105,36 +102,21 @@
        load_thread = load_data_in_thread(args);
        printf("Loaded: %lf seconds\n", sec(clock()-time));
/*
        image im = float_to_image(net.w, net.h, 3, train.X.vals[113]);
        image copy = copy_image(im);
        draw_swag(copy, train.y.vals[113], 7, "truth");
        cvWaitKey(0);
        free_image(copy);
        */
        time=clock();
        float loss = train_network(net, train);
        if (avg_loss < 0) avg_loss = loss;
        avg_loss = avg_loss*.9 + loss*.1;
        printf("%d: %f, %f avg, %lf seconds, %f rate, %d images, epoch: %f\n", get_current_batch(net), loss, avg_loss, sec(clock()-time), get_current_rate(net), *net.seen, (float)*net.seen/N);
        if((i-1)*imgs <= 80*N && i*imgs > N*80){
            fprintf(stderr, "Second stage done.\n");
            char buff[256];
            sprintf(buff, "%s/%s_second_stage.weights", backup_directory, base);
            save_weights(net, buff);
            net.layers[net.n-1].joint = 1;
            net.layers[net.n-1].objectness = 0;
            background = 0;
            pthread_join(load_thread, 0);
            free_data(buffer);
            args.background = background;
            load_thread = load_data_in_thread(args);
        }
        if((i-1)*imgs <= 120*N && i*imgs > N*120){
            fprintf(stderr, "Third stage done.\n");
            char buff[256];
            sprintf(buff, "%s/%s_final.weights", backup_directory, base);
            net.layers[net.n-1].rescore = 1;
            save_weights(net, buff);
        }
        printf("%d: %f, %f avg, %lf seconds, %d images\n", i, loss, avg_loss, sec(clock()-time), i*imgs);
        if(i%1000==0){
            char buff[256];
            sprintf(buff, "%s/%s_%d.weights", backup_directory, base, i);
@@ -143,36 +125,38 @@
        free_data(train);
    }
    char buff[256];
    sprintf(buff, "%s/%s_rescore.weights", backup_directory, base);
    sprintf(buff, "%s/%s_final.weights", backup_directory, base);
    save_weights(net, buff);
}
void convert_yoloplus_detections(float *predictions, int classes, int objectness, int background, int num_boxes, int w, int h, float thresh, float **probs, box *boxes)
void convert_swag_detections(float *predictions, int classes, int num, int square, int side, int w, int h, float thresh, float **probs, box *boxes)
{
    int i,j;
    int per_box = 4+classes+(background || objectness);
    for (i = 0; i < num_boxes*num_boxes; ++i){
        float scale = 1;
        if(objectness) scale = 1-predictions[i*per_box];
        int offset = i*per_box+(background||objectness);
    int i,j,n;
    int per_cell = 5*num+classes;
    for (i = 0; i < side*side; ++i){
        int row = i / side;
        int col = i % side;
        for(n = 0; n < num; ++n){
            int offset = i*per_cell + 5*n;
            float scale = predictions[offset];
            int index = i*num + n;
            boxes[index].x = (predictions[offset + 1] + col) / side * w;
            boxes[index].y = (predictions[offset + 2] + row) / side * h;
            boxes[index].w = pow(predictions[offset + 3], (square?2:1)) * w;
            boxes[index].h = pow(predictions[offset + 4], (square?2:1)) * h;
        for(j = 0; j < classes; ++j){
                offset = i*per_cell + 5*num;
            float prob = scale*predictions[offset+j];
            probs[i][j] = (prob > thresh) ? prob : 0;
                probs[index][j] = (prob > thresh) ? prob : 0;
        }
        int row = i / num_boxes;
        int col = i % num_boxes;
        offset += classes;
        boxes[i].x = (predictions[offset + 0] + col) / num_boxes * w;
        boxes[i].y = (predictions[offset + 1] + row) / num_boxes * h;
        boxes[i].w = pow(predictions[offset + 2], 2) * w;
        boxes[i].h = pow(predictions[offset + 3], 2) * h;
        }
    }
}
void print_yoloplus_detections(FILE **fps, char *id, box *boxes, float **probs, int num_boxes, int classes, int w, int h)
void print_swag_detections(FILE **fps, char *id, box *boxes, float **probs, int total, int classes, int w, int h)
{
    int i, j;
    for(i = 0; i < num_boxes*num_boxes; ++i){
    for(i = 0; i < total; ++i){
        float xmin = boxes[i].x - boxes[i].w/2.;
        float xmax = boxes[i].x + boxes[i].w/2.;
        float ymin = boxes[i].y - boxes[i].h/2.;
@@ -190,14 +174,13 @@
    }
}
void validate_yoloplus(char *cfgfile, char *weightfile)
void validate_swag(char *cfgfile, char *weightfile)
{
    network net = parse_network_cfg(cfgfile);
    if(weightfile){
        load_weights(&net, weightfile);
    }
    set_batch_network(&net, 1);
    detection_layer layer = get_network_detection_layer(net);
    fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
    srand(time(0));
@@ -205,10 +188,10 @@
    list *plist = get_paths("/home/pjreddie/data/voc/test/2007_test.txt");
    char **paths = (char **)list_to_array(plist);
    int classes = layer.classes;
    int objectness = layer.objectness;
    int background = layer.background;
    int num_boxes = sqrt(get_detection_layer_locations(layer));
    layer l = net.layers[net.n-1];
    int classes = l.classes;
    int square = l.sqrt;
    int side = l.side;
    int j;
    FILE **fps = calloc(classes, sizeof(FILE *));
@@ -217,9 +200,9 @@
        snprintf(buff, 1024, "%s%s.txt", base, voc_names[j]);
        fps[j] = fopen(buff, "w");
    }
    box *boxes = calloc(num_boxes*num_boxes, sizeof(box));
    float **probs = calloc(num_boxes*num_boxes, sizeof(float *));
    for(j = 0; j < num_boxes*num_boxes; ++j) probs[j] = calloc(classes, sizeof(float *));
    box *boxes = calloc(side*side*l.n, sizeof(box));
    float **probs = calloc(side*side*l.n, sizeof(float *));
    for(j = 0; j < side*side*l.n; ++j) probs[j] = calloc(classes, sizeof(float *));
    int m = plist->size;
    int i=0;
@@ -268,9 +251,9 @@
            float *predictions = network_predict(net, X);
            int w = val[t].w;
            int h = val[t].h;
            convert_yoloplus_detections(predictions, classes, objectness, background, num_boxes, w, h, thresh, probs, boxes);
            if (nms) do_nms(boxes, probs, num_boxes*num_boxes, classes, iou_thresh);
            print_yoloplus_detections(fps, id, boxes, probs, num_boxes, classes, w, h);
            convert_swag_detections(predictions, classes, l.n, square, side, w, h, thresh, probs, boxes);
            if (nms) do_nms(boxes, probs, side*side*l.n, classes, iou_thresh);
            print_swag_detections(fps, id, boxes, probs, side*side*l.n, classes, w, h);
            free(id);
            free_image(val[t]);
            free_image(val_resized[t]);
@@ -279,7 +262,7 @@
    fprintf(stderr, "Total Detection Time: %f Seconds\n", (double)(time(0) - start));
}
void test_yoloplus(char *cfgfile, char *weightfile, char *filename, float thresh)
void test_swag(char *cfgfile, char *weightfile, char *filename, float thresh)
{
    network net = parse_network_cfg(cfgfile);
@@ -306,7 +289,7 @@
        time=clock();
        float *predictions = network_predict(net, X);
        printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time));
        draw_yoloplus(im, predictions, 7, layer.objectness, "predictions", thresh);
        draw_swag(im, predictions, 7, layer.objectness, "predictions", thresh);
        free_image(im);
        free_image(sized);
#ifdef OPENCV
@@ -317,7 +300,7 @@
    }
}
void run_yoloplus(int argc, char **argv)
void run_swag(int argc, char **argv)
{
    float thresh = find_float_arg(argc, argv, "-thresh", .2);
    if(argc < 4){
@@ -328,7 +311,7 @@
    char *cfg = argv[3];
    char *weights = (argc > 4) ? argv[4] : 0;
    char *filename = (argc > 5) ? argv[5]: 0;
    if(0==strcmp(argv[2], "test")) test_yoloplus(cfg, weights, filename, thresh);
    else if(0==strcmp(argv[2], "train")) train_yoloplus(cfg, weights);
    else if(0==strcmp(argv[2], "valid")) validate_yoloplus(cfg, weights);
    if(0==strcmp(argv[2], "test")) test_swag(cfg, weights, filename, thresh);
    else if(0==strcmp(argv[2], "train")) train_swag(cfg, weights);
    else if(0==strcmp(argv[2], "valid")) validate_swag(cfg, weights);
}