Joseph Redmon
2015-11-09 8c5364f58569eaeb5582a4915b36b24fc5570c76
New YOLO
13 files modified
1 file renamed
4 files deleted
2084 ■■■■ changed files
Makefile 5 ●●●●● patch | view | raw | blame | history
cfg/yolo.cfg 6 ●●●● patch | view | raw | blame | history
data/scream.jpg patch | view | raw | blame | history
src/coco.c 4 ●●●● patch | view | raw | blame | history
src/darknet.c 3 ●●●●● patch | view | raw | blame | history
src/detection_layer.c 296 ●●●● patch | view | raw | blame | history
src/detection_layer.h 8 ●●●●● patch | view | raw | blame | history
src/layer.h 4 ●●●● patch | view | raw | blame | history
src/network.c 11 ●●●●● patch | view | raw | blame | history
src/network.h 1 ●●●● patch | view | raw | blame | history
src/network_kernels.cu 7 ●●●● patch | view | raw | blame | history
src/old.c 607 ●●●●● patch | view | raw | blame | history
src/parser.c 28 ●●●●● patch | view | raw | blame | history
src/region_layer.c 259 ●●●●● patch | view | raw | blame | history
src/region_layer.h 18 ●●●●● patch | view | raw | blame | history
src/swag.c 460 ●●●●● patch | view | raw | blame | history
src/yolo.c 352 ●●●●● patch | view | raw | blame | history
src/yolo_kernels.cu 15 ●●●● patch | view | raw | blame | history
Makefile
@@ -3,7 +3,6 @@
DEBUG=0
ARCH= --gpu-architecture=compute_20 --gpu-code=compute_20
ARCH= -arch=sm_52 --use_fast_math
VPATH=./src/
EXEC=darknet
@@ -35,9 +34,9 @@
LDFLAGS+= -L/usr/local/cuda/lib64 -lcuda -lcudart -lcublas -lcurand
endif
OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o darknet.o detection_layer.o imagenet.o captcha.o route_layer.o writing.o box.o nightmare.o normalization_layer.o avgpool_layer.o coco.o dice.o yolo.o region_layer.o layer.o compare.o swag.o classifier.o
OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o darknet.o detection_layer.o imagenet.o captcha.o route_layer.o writing.o box.o nightmare.o normalization_layer.o avgpool_layer.o coco.o dice.o yolo.o layer.o compare.o classifier.o
ifeq ($(GPU), 1) 
OBJ+=convolutional_kernels.o deconvolutional_kernels.o activation_kernels.o im2col_kernels.o col2im_kernels.o blas_kernels.o crop_layer_kernels.o dropout_layer_kernels.o maxpool_layer_kernels.o softmax_layer_kernels.o network_kernels.o avgpool_layer_kernels.o swag_kernels.o
OBJ+=convolutional_kernels.o deconvolutional_kernels.o activation_kernels.o im2col_kernels.o col2im_kernels.o blas_kernels.o crop_layer_kernels.o dropout_layer_kernels.o maxpool_layer_kernels.o softmax_layer_kernels.o network_kernels.o avgpool_layer_kernels.o yolo_kernels.o
endif
OBJS = $(addprefix $(OBJDIR), $(OBJ))
cfg/yolo.cfg
@@ -9,8 +9,8 @@
learning_rate=0.001
policy=steps
steps=100,200,300,400,500,600,700,20000,30000
scales=2,2,1.25,1.25,1.25,1.25,1.03,.1,.1
steps=200,400,600,20000,30000
scales=2.5,2,2,.1,.1
max_batches = 40000
[crop]
@@ -218,7 +218,7 @@
output= 1470
activation=linear
[region]
[detection]
classes=20
coords=4
rescore=1
data/scream.jpg

src/coco.c
@@ -1,7 +1,7 @@
#include <stdio.h>
#include "network.h"
#include "region_layer.h"
#include "detection_layer.h"
#include "cost_layer.h"
#include "utils.h"
#include "parser.h"
@@ -366,7 +366,7 @@
    if(weightfile){
        load_weights(&net, weightfile);
    }
    region_layer l = net.layers[net.n-1];
    detection_layer l = net.layers[net.n-1];
    set_batch_network(&net, 1);
    srand(2222222);
    clock_t time;
src/darknet.c
@@ -13,7 +13,6 @@
extern void run_imagenet(int argc, char **argv);
extern void run_yolo(int argc, char **argv);
extern void run_swag(int argc, char **argv);
extern void run_coco(int argc, char **argv);
extern void run_writing(int argc, char **argv);
extern void run_captcha(int argc, char **argv);
@@ -221,8 +220,6 @@
        average(argc, argv);
    } else if (0 == strcmp(argv[1], "yolo")){
        run_yolo(argc, argv);
    } else if (0 == strcmp(argv[1], "swag")){
        run_swag(argc, argv);
    } else if (0 == strcmp(argv[1], "coco")){
        run_coco(argc, argv);
    } else if (0 == strcmp(argv[1], "classifier")){
src/detection_layer.c
@@ -6,42 +6,32 @@
#include "cuda.h"
#include "utils.h"
#include <stdio.h>
#include <assert.h>
#include <string.h>
#include <stdlib.h>
// Returns the number of locations covered by the layer's input: l.inputs
// divided by the number of values stored per location.
// The per-location width is classes + coords + joint, plus one extra slot when
// either l.background or l.objectness is set (the || collapses both flags into
// a single 0/1 term). The result is returned as an int.
int get_detection_layer_locations(detection_layer l)
{
    return l.inputs / (l.classes+l.coords+l.joint+(l.background || l.objectness));
}
// Returns the layer's output length: the location count from
// get_detection_layer_locations(l) times the number of values emitted per
// location (one slot when background or objectness is set, plus classes, plus
// coords). Note that l.joint is not part of the per-location output width here,
// although it is part of the divisor used to compute the location count.
int get_detection_layer_output_size(detection_layer l)
{
    return get_detection_layer_locations(l)*((l.background || l.objectness) + l.classes + l.coords);
}
detection_layer make_detection_layer(int batch, int inputs, int classes, int coords, int joint, int rescore, int background, int objectness)
detection_layer make_detection_layer(int batch, int inputs, int n, int side, int classes, int coords, int rescore)
{
    detection_layer l = {0};
    l.type = DETECTION;
    l.n = n;
    l.batch = batch;
    l.inputs = inputs;
    l.classes = classes;
    l.coords = coords;
    l.rescore = rescore;
    l.objectness = objectness;
    l.background = background;
    l.joint = joint;
    l.side = side;
    assert(side*side*((1 + l.coords)*l.n + l.classes) == inputs);
    l.cost = calloc(1, sizeof(float));
    l.does_cost=1;
    int outputs = get_detection_layer_output_size(l);
    l.outputs = outputs;
    l.output = calloc(batch*outputs, sizeof(float));
    l.delta = calloc(batch*outputs, sizeof(float));
    #ifdef GPU
    l.output_gpu = cuda_make_array(l.output, batch*outputs);
    l.delta_gpu  = cuda_make_array(l.delta,  batch*outputs);
    #endif
    l.outputs = l.inputs;
    l.truths = l.side*l.side*(1+l.coords+l.classes);
    l.output = calloc(batch*l.outputs, sizeof(float));
    l.delta = calloc(batch*l.outputs, sizeof(float));
#ifdef GPU
    l.output_gpu = cuda_make_array(l.output, batch*l.outputs);
    l.delta_gpu = cuda_make_array(l.delta, batch*l.outputs);
#endif
    fprintf(stderr, "Detection Layer\n");
    srand(0);
@@ -51,124 +41,164 @@
void forward_detection_layer(const detection_layer l, network_state state)
{
    int in_i = 0;
    int out_i = 0;
    int locations = get_detection_layer_locations(l);
    int locations = l.side*l.side;
    int i,j;
    for(i = 0; i < l.batch*locations; ++i){
        int mask = (!state.truth || state.truth[out_i + (l.background || l.objectness) + l.classes + 2]);
        float scale = 1;
        if(l.joint) scale = state.input[in_i++];
        else if(l.objectness){
            l.output[out_i++] = 1-state.input[in_i++];
            scale = mask;
        }
        else if(l.background) l.output[out_i++] = scale*state.input[in_i++];
        for(j = 0; j < l.classes; ++j){
            l.output[out_i++] = scale*state.input[in_i++];
        }
        if(l.objectness){
        }else if(l.background){
            softmax_array(l.output + out_i - l.classes-l.background, l.classes+l.background, l.output + out_i - l.classes-l.background);
            activate_array(state.input+in_i, l.coords, LOGISTIC);
        }
        for(j = 0; j < l.coords; ++j){
            l.output[out_i++] = mask*state.input[in_i++];
    memcpy(l.output, state.input, l.outputs*l.batch*sizeof(float));
    int b;
    if (l.softmax){
        for(b = 0; b < l.batch; ++b){
            int index = b*l.inputs;
            for (i = 0; i < locations; ++i) {
                int offset = i*l.classes;
                softmax_array(l.output + index + offset, l.classes,
                        l.output + index + offset);
            }
            int offset = locations*l.classes;
            activate_array(l.output + index + offset, locations*l.n*(1+l.coords), LOGISTIC);
        }
    }
    float avg_iou = 0;
    int count = 0;
    if(l.does_cost && state.train){
    if(state.train){
        float avg_iou = 0;
        float avg_cat = 0;
        float avg_allcat = 0;
        float avg_obj = 0;
        float avg_anyobj = 0;
        int count = 0;
        *(l.cost) = 0;
        int size = get_detection_layer_output_size(l) * l.batch;
        int size = l.inputs * l.batch;
        memset(l.delta, 0, size * sizeof(float));
        for (i = 0; i < l.batch*locations; ++i) {
            int classes = (l.objectness || l.background)+l.classes;
            int offset = i*(classes+l.coords);
            for (j = offset; j < offset+classes; ++j) {
                *(l.cost) += pow(state.truth[j] - l.output[j], 2);
                l.delta[j] =  state.truth[j] - l.output[j];
                if(l.background && j == offset) l.delta[j] *= .1;
            }
            box truth;
            truth.x = state.truth[j+0]/7;
            truth.y = state.truth[j+1]/7;
            truth.w = pow(state.truth[j+2], 2);
            truth.h = pow(state.truth[j+3], 2);
            box out;
            out.x = l.output[j+0]/7;
            out.y = l.output[j+1]/7;
            out.w = pow(l.output[j+2], 2);
            out.h = pow(l.output[j+3], 2);
            if(!(truth.w*truth.h)) continue;
            float iou = box_iou(out, truth);
            avg_iou += iou;
            ++count;
            *(l.cost) += pow((1-iou), 2);
            l.delta[j+0] = 4 * (state.truth[j+0] - l.output[j+0]);
            l.delta[j+1] = 4 * (state.truth[j+1] - l.output[j+1]);
            l.delta[j+2] = 4 * (state.truth[j+2] - l.output[j+2]);
            l.delta[j+3] = 4 * (state.truth[j+3] - l.output[j+3]);
            if(l.rescore){
                if(l.objectness){
                    state.truth[offset] = iou;
                    l.delta[offset] = state.truth[offset] - l.output[offset];
        for (b = 0; b < l.batch; ++b){
            int index = b*l.inputs;
            for (i = 0; i < locations; ++i) {
                int truth_index = (b*locations + i)*(1+l.coords+l.classes);
                int is_obj = state.truth[truth_index];
                for (j = 0; j < l.n; ++j) {
                    int p_index = index + locations*l.classes + i*l.n + j;
                    l.delta[p_index] = l.noobject_scale*(0 - l.output[p_index]);
                    *(l.cost) += l.noobject_scale*pow(l.output[p_index], 2);
                    avg_anyobj += l.output[p_index];
                }
                else{
                    for (j = offset; j < offset+classes; ++j) {
                        if(state.truth[j]) state.truth[j] = iou;
                        l.delta[j] =  state.truth[j] - l.output[j];
                int best_index = -1;
                float best_iou = 0;
                float best_rmse = 20;
                if (!is_obj){
                    continue;
                }
                int class_index = index + i*l.classes;
                for(j = 0; j < l.classes; ++j) {
                    l.delta[class_index+j] = l.class_scale * (state.truth[truth_index+1+j] - l.output[class_index+j]);
                    *(l.cost) += l.class_scale * pow(state.truth[truth_index+1+j] - l.output[class_index+j], 2);
                    if(state.truth[truth_index + 1 + j]) avg_cat += l.output[class_index+j];
                    avg_allcat += l.output[class_index+j];
                }
                box truth = float_to_box(state.truth + truth_index + 1 + l.classes);
                truth.x /= l.side;
                truth.y /= l.side;
                for(j = 0; j < l.n; ++j){
                    int box_index = index + locations*(l.classes + l.n) + (i*l.n + j) * l.coords;
                    box out = float_to_box(l.output + box_index);
                    out.x /= l.side;
                    out.y /= l.side;
                    if (l.sqrt){
                        out.w = out.w*out.w;
                        out.h = out.h*out.h;
                    }
                    float iou  = box_iou(out, truth);
                    //iou = 0;
                    float rmse = box_rmse(out, truth);
                    if(best_iou > 0 || iou > 0){
                        if(iou > best_iou){
                            best_iou = iou;
                            best_index = j;
                        }
                    }else{
                        if(rmse < best_rmse){
                            best_rmse = rmse;
                            best_index = j;
                        }
                    }
                }
                if(l.forced){
                    if(truth.w*truth.h < .1){
                        best_index = 1;
                    }else{
                        best_index = 0;
                    }
                }
                int box_index = index + locations*(l.classes + l.n) + (i*l.n + best_index) * l.coords;
                int tbox_index = truth_index + 1 + l.classes;
                box out = float_to_box(l.output + box_index);
                out.x /= l.side;
                out.y /= l.side;
                if (l.sqrt) {
                    out.w = out.w*out.w;
                    out.h = out.h*out.h;
                }
                float iou  = box_iou(out, truth);
                //printf("%d", best_index);
                int p_index = index + locations*l.classes + i*l.n + best_index;
                *(l.cost) -= l.noobject_scale * pow(l.output[p_index], 2);
                *(l.cost) += l.object_scale * pow(1-l.output[p_index], 2);
                avg_obj += l.output[p_index];
                l.delta[p_index] = l.object_scale * (1.-l.output[p_index]);
                if(l.rescore){
                    l.delta[p_index] = l.object_scale * (iou - l.output[p_index]);
                }
                l.delta[box_index+0] = l.coord_scale*(state.truth[tbox_index + 0] - l.output[box_index + 0]);
                l.delta[box_index+1] = l.coord_scale*(state.truth[tbox_index + 1] - l.output[box_index + 1]);
                l.delta[box_index+2] = l.coord_scale*(state.truth[tbox_index + 2] - l.output[box_index + 2]);
                l.delta[box_index+3] = l.coord_scale*(state.truth[tbox_index + 3] - l.output[box_index + 3]);
                if(l.sqrt){
                    l.delta[box_index+2] = l.coord_scale*(sqrt(state.truth[tbox_index + 2]) - l.output[box_index + 2]);
                    l.delta[box_index+3] = l.coord_scale*(sqrt(state.truth[tbox_index + 3]) - l.output[box_index + 3]);
                }
                *(l.cost) += pow(1-iou, 2);
                avg_iou += iou;
                ++count;
            }
            if(l.softmax){
                gradient_array(l.output + index + locations*l.classes, locations*l.n*(1+l.coords),
                        LOGISTIC, l.delta + index + locations*l.classes);
            }
        }
        printf("Avg IOU: %f\n", avg_iou/count);
        printf("Detection Avg IOU: %f, Pos Cat: %f, All Cat: %f, Pos Obj: %f, Any Obj: %f, count: %d\n", avg_iou/count, avg_cat/count, avg_allcat/(count*l.classes), avg_obj/count, avg_anyobj/(l.batch*locations*l.n), count);
    }
}
// CPU backward pass: accumulates this layer's l.delta into state.delta.
// NOTE(review): this listing appears to be a rendered diff in which removed and
// added lines are shown together without +/- markers. Read literally, the
// per-location loop and the final axpy_cpu call both appear to add l.delta into
// state.delta; confirm against the real src/detection_layer.c which of the two
// is the live code.
void backward_detection_layer(const detection_layer l, network_state state)
{
    int locations = get_detection_layer_locations(l);
    int i,j;
    // in_i walks state.input/state.delta; out_i walks l.delta/l.output.
    int in_i = 0;
    int out_i = 0;
    for(i = 0; i < l.batch*locations; ++i){
        float scale = 1;
        // Sum of input*delta over this location's classes; only consumed when l.joint is set.
        float latent_delta = 0;
        // Leading slot: joint reads one input value to use as the class scale,
        // objectness accumulates the negated delta, background accumulates the delta.
        if(l.joint) scale = state.input[in_i++];
        else if (l.objectness)   state.delta[in_i++] += -l.delta[out_i++];
        else if (l.background) state.delta[in_i++] += scale*l.delta[out_i++];
        for(j = 0; j < l.classes; ++j){
            latent_delta += state.input[in_i]*l.delta[out_i];
            state.delta[in_i++] += scale*l.delta[out_i++];
        }
        // With background set (and objectness unset) the coord deltas are first
        // passed through gradient_array with LOGISTIC; objectness does nothing here.
        if (l.objectness) {
        }else if (l.background) gradient_array(l.output + out_i, l.coords, LOGISTIC, l.delta + out_i);
        for (j = 0; j < l.coords; ++j){
            state.delta[in_i++] += l.delta[out_i++];
        }
        // Index in_i-l.coords-l.classes-l.joint is the slot the scale was read
        // from when l.joint is 1; the accumulated class term is added there.
        if(l.joint) state.delta[in_i-l.coords-l.classes-l.joint] += latent_delta;
    }
    // presumably state.delta += 1 * l.delta over batch*inputs elements — confirm axpy_cpu's argument order.
    axpy_cpu(l.batch*l.inputs, 1, l.delta, 1, state.delta, 1);
}
#ifdef GPU
void forward_detection_layer_gpu(const detection_layer l, network_state state)
{
    int outputs = get_detection_layer_output_size(l);
    if(!state.train){
        copy_ongpu(l.batch*l.inputs, state.input, 1, l.output_gpu, 1);
        return;
    }
    float *in_cpu = calloc(l.batch*l.inputs, sizeof(float));
    float *truth_cpu = 0;
    if(state.truth){
        truth_cpu = calloc(l.batch*outputs, sizeof(float));
        cuda_pull_array(state.truth, truth_cpu, l.batch*outputs);
        int num_truth = l.batch*l.side*l.side*(1+l.coords+l.classes);
        truth_cpu = calloc(num_truth, sizeof(float));
        cuda_pull_array(state.truth, truth_cpu, num_truth);
    }
    cuda_pull_array(state.input, in_cpu, l.batch*l.inputs);
    network_state cpu_state;
@@ -176,38 +206,16 @@
    cpu_state.truth = truth_cpu;
    cpu_state.input = in_cpu;
    forward_detection_layer(l, cpu_state);
    cuda_push_array(l.output_gpu, l.output, l.batch*outputs);
    cuda_push_array(l.delta_gpu, l.delta, l.batch*outputs);
    cuda_push_array(l.output_gpu, l.output, l.batch*l.outputs);
    cuda_push_array(l.delta_gpu, l.delta, l.batch*l.inputs);
    free(cpu_state.input);
    if(cpu_state.truth) free(cpu_state.truth);
}
// GPU entry point for the detection layer's backward pass.
// NOTE(review): this listing appears to be a rendered diff showing removed and
// added lines together without +/- markers. Read literally, it stages buffers
// on the host, runs backward_detection_layer, pushes the result back, and then
// also calls axpy_ongpu on l.delta_gpu/state.delta; confirm against the real
// src/detection_layer.c which path is the live code.
void backward_detection_layer_gpu(detection_layer l, network_state state)
{
    int outputs = get_detection_layer_output_size(l);
    // Host scratch buffers, each batch*inputs floats, zero-initialized by calloc.
    float *in_cpu    = calloc(l.batch*l.inputs, sizeof(float));
    float *delta_cpu = calloc(l.batch*l.inputs, sizeof(float));
    float *truth_cpu = 0;
    // Truth is optional: it is staged only when state.truth is non-null.
    if(state.truth){
        truth_cpu = calloc(l.batch*outputs, sizeof(float));
        cuda_pull_array(state.truth, truth_cpu, l.batch*outputs);
    }
    // Shadow network_state whose pointers refer to the host buffers above.
    network_state cpu_state;
    cpu_state.train = state.train;
    cpu_state.input = in_cpu;
    cpu_state.truth = truth_cpu;
    cpu_state.delta = delta_cpu;
    // presumably device -> host copies — confirm cuda_pull_array's direction.
    cuda_pull_array(state.input, in_cpu,    l.batch*l.inputs);
    cuda_pull_array(state.delta, delta_cpu, l.batch*l.inputs);
    cuda_pull_array(l.delta_gpu, l.delta, l.batch*outputs);
    backward_detection_layer(l, cpu_state);
    // presumably host -> device copy of the updated delta — confirm cuda_push_array.
    cuda_push_array(state.delta, delta_cpu, l.batch*l.inputs);
    // Release the host scratch buffers.
    if (truth_cpu) free(truth_cpu);
    free(in_cpu);
    free(delta_cpu);
    // presumably state.delta += l.delta_gpu on the device — confirm axpy_ongpu's argument order.
    axpy_ongpu(l.batch*l.inputs, 1, l.delta_gpu, 1, state.delta, 1);
    //copy_ongpu(l.batch*l.inputs, l.delta_gpu, 1, state.delta, 1);
}
#endif
src/detection_layer.h
@@ -1,16 +1,14 @@
#ifndef DETECTION_LAYER_H
#define DETECTION_LAYER_H
#ifndef REGION_LAYER_H
#define REGION_LAYER_H
#include "params.h"
#include "layer.h"
typedef layer detection_layer;
detection_layer make_detection_layer(int batch, int inputs, int classes, int coords, int joint, int rescore, int background, int objectness);
detection_layer make_detection_layer(int batch, int inputs, int n, int size, int classes, int coords, int rescore);
void forward_detection_layer(const detection_layer l, network_state state);
void backward_detection_layer(const detection_layer l, network_state state);
int get_detection_layer_output_size(detection_layer l);
int get_detection_layer_locations(detection_layer l);
#ifdef GPU
void forward_detection_layer_gpu(const detection_layer l, network_state state);
src/layer.h
@@ -15,7 +15,6 @@
    ROUTE,
    COST,
    NORMALIZATION,
    REGION,
    AVGPOOL
} LAYER_TYPE;
@@ -30,9 +29,6 @@
    int batch_normalize;
    int batch;
    int forced;
    int object_logistic;
    int class_logistic;
    int coord_logistic;
    int inputs;
    int outputs;
    int truths;
src/network.c
@@ -11,7 +11,6 @@
#include "convolutional_layer.h"
#include "deconvolutional_layer.h"
#include "detection_layer.h"
#include "region_layer.h"
#include "normalization_layer.h"
#include "maxpool_layer.h"
#include "avgpool_layer.h"
@@ -72,8 +71,6 @@
            return "softmax";
        case DETECTION:
            return "detection";
        case REGION:
            return "region";
        case DROPOUT:
            return "dropout";
        case CROP:
@@ -119,8 +116,6 @@
            forward_normalization_layer(l, state);
        } else if(l.type == DETECTION){
            forward_detection_layer(l, state);
        } else if(l.type == REGION){
            forward_region_layer(l, state);
        } else if(l.type == CONNECTED){
            forward_connected_layer(l, state);
        } else if(l.type == CROP){
@@ -180,10 +175,6 @@
            sum += net.layers[i].cost[0];
            ++count;
        }
        if(net.layers[i].type == REGION){
            sum += net.layers[i].cost[0];
            ++count;
        }
    }
    return sum/count;
}
@@ -224,8 +215,6 @@
            backward_dropout_layer(l, state);
        } else if(l.type == DETECTION){
            backward_detection_layer(l, state);
        } else if(l.type == REGION){
            backward_region_layer(l, state);
        } else if(l.type == SOFTMAX){
            if(i != 0) backward_softmax_layer(l, state);
        } else if(l.type == CONNECTED){
src/network.h
@@ -89,7 +89,6 @@
void set_batch_network(network *net, int b);
int get_network_input_size(network net);
float get_network_cost(network net);
detection_layer get_network_detection_layer(network net);
int get_network_nuisance(network net);
int get_network_background(network net);
src/network_kernels.cu
@@ -13,7 +13,6 @@
#include "crop_layer.h"
#include "connected_layer.h"
#include "detection_layer.h"
#include "region_layer.h"
#include "convolutional_layer.h"
#include "deconvolutional_layer.h"
#include "maxpool_layer.h"
@@ -44,8 +43,6 @@
            forward_deconvolutional_layer_gpu(l, state);
        } else if(l.type == DETECTION){
            forward_detection_layer_gpu(l, state);
        } else if(l.type == REGION){
            forward_region_layer_gpu(l, state);
        } else if(l.type == CONNECTED){
            forward_connected_layer_gpu(l, state);
        } else if(l.type == CROP){
@@ -96,8 +93,6 @@
            backward_dropout_layer_gpu(l, state);
        } else if(l.type == DETECTION){
            backward_detection_layer_gpu(l, state);
        } else if(l.type == REGION){
            backward_region_layer_gpu(l, state);
        } else if(l.type == NORMALIZATION){
            backward_normalization_layer_gpu(l, state);
        } else if(l.type == SOFTMAX){
@@ -134,7 +129,7 @@
    network_state state;
    int x_size = get_network_input_size(net)*net.batch;
    int y_size = get_network_output_size(net)*net.batch;
    if(net.layers[net.n-1].type == REGION) y_size = net.layers[net.n-1].truths*net.batch;
    if(net.layers[net.n-1].type == DETECTION) y_size = net.layers[net.n-1].truths*net.batch;
    if(!*net.input_gpu){
        *net.input_gpu = cuda_make_array(x, x_size);
        *net.truth_gpu = cuda_make_array(y, y_size);
src/old.c
File was deleted
src/parser.c
@@ -14,7 +14,6 @@
#include "softmax_layer.h"
#include "dropout_layer.h"
#include "detection_layer.h"
#include "region_layer.h"
#include "avgpool_layer.h"
#include "route_layer.h"
#include "list.h"
@@ -38,7 +37,6 @@
int is_crop(section *s);
int is_cost(section *s);
int is_detection(section *s);
int is_region(section *s);
int is_route(section *s);
list *read_cfg(char *filename);
@@ -168,35 +166,19 @@
    int coords = option_find_int(options, "coords", 1);
    int classes = option_find_int(options, "classes", 1);
    int rescore = option_find_int(options, "rescore", 0);
    int joint = option_find_int(options, "joint", 0);
    int objectness = option_find_int(options, "objectness", 0);
    int background = option_find_int(options, "background", 0);
    detection_layer layer = make_detection_layer(params.batch, params.inputs, classes, coords, joint, rescore, background, objectness);
    return layer;
}
region_layer parse_region(list *options, size_params params)
{
    int coords = option_find_int(options, "coords", 1);
    int classes = option_find_int(options, "classes", 1);
    int rescore = option_find_int(options, "rescore", 0);
    int num = option_find_int(options, "num", 1);
    int side = option_find_int(options, "side", 7);
    region_layer layer = make_region_layer(params.batch, params.inputs, num, side, classes, coords, rescore);
    detection_layer layer = make_detection_layer(params.batch, params.inputs, num, side, classes, coords, rescore);
    layer.softmax = option_find_int(options, "softmax", 0);
    layer.sqrt = option_find_int(options, "sqrt", 0);
    layer.object_logistic = option_find_int(options, "object_logistic", 0);
    layer.class_logistic = option_find_int(options, "class_logistic", 0);
    layer.coord_logistic = option_find_int(options, "coord_logistic", 0);
    layer.coord_scale = option_find_float(options, "coord_scale", 1);
    layer.forced = option_find_int(options, "forced", 0);
    layer.object_scale = option_find_float(options, "object_scale", 1);
    layer.noobject_scale = option_find_float(options, "noobject_scale", 1);
    layer.class_scale = option_find_float(options, "class_scale", 1);
    layer.jitter = option_find_float(options, "jitter", .1);
    layer.jitter = option_find_float(options, "jitter", .2);
    return layer;
}
@@ -430,8 +412,6 @@
            l = parse_cost(options, params);
        }else if(is_detection(s)){
            l = parse_detection(options, params);
        }else if(is_region(s)){
            l = parse_region(options, params);
        }else if(is_softmax(s)){
            l = parse_softmax(options, params);
        }else if(is_normalization(s)){
@@ -485,10 +465,6 @@
{
    return (strcmp(s->type, "[detection]")==0);
}
// Returns nonzero when the section's type string is exactly "[region]",
// and 0 otherwise.
int is_region(section *s)
{
    return (strcmp(s->type, "[region]")==0);
}
int is_deconvolutional(section *s)
{
    return (strcmp(s->type, "[deconv]")==0
src/region_layer.c
File was deleted
src/region_layer.h
File was deleted
src/swag.c
File was deleted
src/yolo.c
@@ -9,44 +9,36 @@
#include "opencv2/highgui/highgui_c.h"
#endif
char *voc_class_names[] = {"aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"};
char *voc_names[] = {"aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"};
void draw_yolo(image im, float *box, int side, int objectness, char *label, float thresh)
void draw_yolo(image im, int num, float thresh, box *boxes, float **probs, char *label)
{
    int classes = 20;
    int elems = 4+classes+objectness;
    int j;
    int r, c;
    int i;
    for(r = 0; r < side; ++r){
        for(c = 0; c < side; ++c){
            j = (r*side + c) * elems;
            float scale = 1;
            if(objectness) scale = 1 - box[j++];
            int class = max_index(box+j, classes);
            if(scale * box[j+class] > thresh){
                int width = sqrt(scale*box[j+class])*5 + 1;
                printf("%f %s\n", scale * box[j+class], voc_class_names[class]);
                float red = get_color(0,class,classes);
                float green = get_color(1,class,classes);
                float blue = get_color(2,class,classes);
    for(i = 0; i < num; ++i){
        int class = max_index(probs[i], classes);
        float prob = probs[i][class];
        if(prob > thresh){
            int width = pow(prob, 1./2.)*10;
            printf("%f %s\n", prob, voc_names[class]);
            float red = get_color(0,class,classes);
            float green = get_color(1,class,classes);
            float blue = get_color(2,class,classes);
            //red = green = blue = 0;
            box b = boxes[i];
                j += classes;
                float x = box[j+0];
                float y = box[j+1];
                x = (x+c)/side;
                y = (y+r)/side;
                float w = box[j+2]; //*maxwidth;
                float h = box[j+3]; //*maxheight;
                h = h*h;
                w = w*w;
            int left  = (b.x-b.w/2.)*im.w;
            int right = (b.x+b.w/2.)*im.w;
            int top   = (b.y-b.h/2.)*im.h;
            int bot   = (b.y+b.h/2.)*im.h;
                int left  = (x-w/2)*im.w;
                int right = (x+w/2)*im.w;
                int top   = (y-h/2)*im.h;
                int bot   = (y+h/2)*im.h;
                draw_box_width(im, left, top, right, bot, width, red, green, blue);
            }
            if(left < 0) left = 0;
            if(right > im.w-1) right = im.w-1;
            if(top < 0) top = 0;
            if(bot > im.h-1) bot = im.h-1;
            draw_box_width(im, left, top, right, bot, width, red, green, blue);
        }
    }
    show_image(im, label);
@@ -54,7 +46,13 @@
void train_yolo(char *cfgfile, char *weightfile)
{
    //char *train_images = "/home/pjreddie/data/voc/person_detection/2010_person.txt";
    //char *train_images = "/home/pjreddie/data/people-art/train.txt";
    //char *train_images = "/home/pjreddie/data/voc/test/2012_trainval.txt";
    //char *train_images = "/home/pjreddie/data/voc/test/2010_trainval.txt";
    char *train_images = "/home/pjreddie/data/voc/test/train.txt";
    //char *train_images = "/home/pjreddie/data/voc/test/train_all.txt";
    //char *train_images = "/home/pjreddie/data/voc/test/2007_trainval.txt";
    char *backup_directory = "/home/pjreddie/backup/";
    srand(time(0));
    data_seed = time(0);
@@ -65,27 +63,21 @@
    if(weightfile){
        load_weights(&net, weightfile);
    }
    int imgs = 128;
    printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
    int imgs = net.batch*net.subdivisions;
    int i = *net.seen/imgs;
    char **paths;
    list *plist = get_paths(train_images);
    int N = plist->size;
    paths = (char **)list_to_array(plist);
    if(i*imgs > N*80){
        net.layers[net.n-1].objectness = 0;
        net.layers[net.n-1].joint = 1;
    }
    if(i*imgs > N*120){
        net.layers[net.n-1].rescore = 1;
    }
    data train, buffer;
    detection_layer layer = get_network_detection_layer(net);
    int classes = layer.classes;
    int background = layer.objectness;
    int side = sqrt(get_detection_layer_locations(layer));
    layer l = net.layers[net.n - 1];
    int side = l.side;
    int classes = l.classes;
    float jitter = l.jitter;
    list *plist = get_paths(train_images);
    //int N = plist->size;
    char **paths = (char **)list_to_array(plist);
    load_args args = {0};
    args.w = net.w;
@@ -94,13 +86,14 @@
    args.n = imgs;
    args.m = plist->size;
    args.classes = classes;
    args.jitter = jitter;
    args.num_boxes = side;
    args.background = background;
    args.d = &buffer;
    args.type = DETECTION_DATA;
    args.type = REGION_DATA;
    pthread_t load_thread = load_data_in_thread(args);
    clock_t time;
    //while(i*imgs < N*120){
    while(get_current_batch(net) < net.max_batches){
        i += 1;
        time=clock();
@@ -109,36 +102,21 @@
        load_thread = load_data_in_thread(args);
        printf("Loaded: %lf seconds\n", sec(clock()-time));
        /*
           image im = float_to_image(net.w, net.h, 3, train.X.vals[113]);
           image copy = copy_image(im);
           draw_yolo(copy, train.y.vals[113], 7, "truth");
           cvWaitKey(0);
           free_image(copy);
         */
        time=clock();
        float loss = train_network(net, train);
        if (avg_loss < 0) avg_loss = loss;
        avg_loss = avg_loss*.9 + loss*.1;
        printf("%d: %f, %f avg, %lf seconds, %f rate, %d images, epoch: %f\n", get_current_batch(net), loss, avg_loss, sec(clock()-time), get_current_rate(net), *net.seen, (float)*net.seen/N);
        if((i-1)*imgs <= 80*N && i*imgs > N*80){
            fprintf(stderr, "Second stage done.\n");
            char buff[256];
            sprintf(buff, "%s/%s_second_stage.weights", backup_directory, base);
            save_weights(net, buff);
            net.layers[net.n-1].joint = 1;
            net.layers[net.n-1].objectness = 0;
            background = 0;
            pthread_join(load_thread, 0);
            free_data(buffer);
            args.background = background;
            load_thread = load_data_in_thread(args);
        }
        if((i-1)*imgs <= 120*N && i*imgs > N*120){
            fprintf(stderr, "Third stage done.\n");
            char buff[256];
            sprintf(buff, "%s/%s_final.weights", backup_directory, base);
            net.layers[net.n-1].rescore = 1;
            save_weights(net, buff);
        }
        printf("%d: %f, %f avg, %f rate, %lf seconds, %d images\n", i, loss, avg_loss, get_current_rate(net), sec(clock()-time), i*imgs);
        if(i%1000==0){
            char buff[256];
            sprintf(buff, "%s/%s_%d.weights", backup_directory, base, i);
@@ -147,36 +125,42 @@
        free_data(train);
    }
    char buff[256];
    sprintf(buff, "%s/%s_rescore.weights", backup_directory, base);
    sprintf(buff, "%s/%s_final.weights", backup_directory, base);
    save_weights(net, buff);
}
/*
 * Decode a flat prediction vector into boxes and thresholded class scores.
 *
 * Layout of `predictions` as read by the indexing below, for a side x side
 * grid with `num` boxes per cell:
 *   [0, side*side*classes)                        `classes` values per cell
 *   [side*side*classes, side*side*(classes+num))  one scale value per box
 *   [side*side*(classes+num), ...)                4 values (x, y, w, h) per box
 *
 * predictions:     network output in the layout above.
 * classes:         number of class values stored per cell.
 * num:             number of boxes predicted per cell.
 * square:          when non-zero, the raw w/h values are squared before scaling.
 * side:            grid size; side*side cells are decoded.
 * w, h:            factors the box coordinates are multiplied by
 *                  (callers pass the image size, or 1 for relative units).
 * thresh:          scores not strictly above this are written as 0.
 * probs:           output, side*side*num rows; row k receives `classes` scores.
 * boxes:           output, side*side*num decoded boxes.
 * only_objectness: when non-zero, probs[k][0] is overwritten with the box's
 *                  raw (unthresholded) scale value.
 */
void convert_yolo_detections(float *predictions, int classes, int num, int square, int side, int w, int h, float thresh, float **probs, box *boxes, int only_objectness)
{
    int i, j, n;
    int cells = side*side;
    for (i = 0; i < cells; ++i){
        int row = i / side;
        int col = i % side;
        /* Class values are shared by every box of the cell. */
        int class_index = i*classes;
        for(n = 0; n < num; ++n){
            int index = i*num + n;
            int p_index = cells*classes + index;
            int box_index = cells*(classes + num) + index*4;
            float scale = predictions[p_index];

            /* x/y are offsets inside the cell; add the cell position and
               normalise by the grid size before scaling by w/h. */
            boxes[index].x = (predictions[box_index + 0] + col) / side * w;
            boxes[index].y = (predictions[box_index + 1] + row) / side * h;
            boxes[index].w = pow(predictions[box_index + 2], (square?2:1)) * w;
            boxes[index].h = pow(predictions[box_index + 3], (square?2:1)) * h;

            for(j = 0; j < classes; ++j){
                float prob = scale*predictions[class_index+j];
                probs[index][j] = (prob > thresh) ? prob : 0;
            }
            if(only_objectness){
                probs[index][0] = scale;
            }
        }
    }
}
void print_yolo_detections(FILE **fps, char *id, box *boxes, float **probs, int num_boxes, int classes, int w, int h)
void print_yolo_detections(FILE **fps, char *id, box *boxes, float **probs, int total, int classes, int w, int h)
{
    int i, j;
    for(i = 0; i < num_boxes*num_boxes; ++i){
    for(i = 0; i < total; ++i){
        float xmin = boxes[i].x - boxes[i].w/2.;
        float xmax = boxes[i].x + boxes[i].w/2.;
        float ymin = boxes[i].y - boxes[i].h/2.;
@@ -201,29 +185,33 @@
        load_weights(&net, weightfile);
    }
    set_batch_network(&net, 1);
    detection_layer layer = get_network_detection_layer(net);
    fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
    srand(time(0));
    char *base = "results/comp4_det_test_";
    //base = "/home/pjreddie/comp4_det_test_";
    //list *plist = get_paths("/home/pjreddie/data/people-art/test.txt");
    //list *plist = get_paths("/home/pjreddie/data/cubist/test.txt");
    list *plist = get_paths("/home/pjreddie/data/voc/test/2007_test.txt");
    //list *plist = get_paths("/home/pjreddie/data/voc/test_2012.txt");
    char **paths = (char **)list_to_array(plist);
    int classes = layer.classes;
    int objectness = layer.objectness;
    int background = layer.background;
    int num_boxes = sqrt(get_detection_layer_locations(layer));
    layer l = net.layers[net.n-1];
    int classes = l.classes;
    int square = l.sqrt;
    int side = l.side;
    int j;
    FILE **fps = calloc(classes, sizeof(FILE *));
    for(j = 0; j < classes; ++j){
        char buff[1024];
        snprintf(buff, 1024, "%s%s.txt", base, voc_class_names[j]);
        snprintf(buff, 1024, "%s%s.txt", base, voc_names[j]);
        fps[j] = fopen(buff, "w");
    }
    box *boxes = calloc(num_boxes*num_boxes, sizeof(box));
    float **probs = calloc(num_boxes*num_boxes, sizeof(float *));
    for(j = 0; j < num_boxes*num_boxes; ++j) probs[j] = calloc(classes, sizeof(float *));
    box *boxes = calloc(side*side*l.n, sizeof(box));
    float **probs = calloc(side*side*l.n, sizeof(float *));
    for(j = 0; j < side*side*l.n; ++j) probs[j] = calloc(classes, sizeof(float *));
    int m = plist->size;
    int i=0;
@@ -233,7 +221,7 @@
    int nms = 1;
    float iou_thresh = .5;
    int nthreads = 8;
    int nthreads = 2;
    image *val = calloc(nthreads, sizeof(image));
    image *val_resized = calloc(nthreads, sizeof(image));
    image *buf = calloc(nthreads, sizeof(image));
@@ -272,9 +260,9 @@
            float *predictions = network_predict(net, X);
            int w = val[t].w;
            int h = val[t].h;
            convert_yolo_detections(predictions, classes, objectness, background, num_boxes, w, h, thresh, probs, boxes);
            if (nms) do_nms(boxes, probs, num_boxes*num_boxes, classes, iou_thresh);
            print_yolo_detections(fps, id, boxes, probs, num_boxes, classes, w, h);
            convert_yolo_detections(predictions, classes, l.n, square, side, w, h, thresh, probs, boxes, 0);
            if (nms) do_nms_sort(boxes, probs, side*side*l.n, classes, iou_thresh);
            print_yolo_detections(fps, id, boxes, probs, side*side*l.n, classes, w, h);
            free(id);
            free_image(val[t]);
            free_image(val_resized[t]);
@@ -283,6 +271,93 @@
    fprintf(stderr, "Total Detection Time: %f Seconds\n", (double)(time(0) - start));
}
void validate_yolo_recall(char *cfgfile, char *weightfile)
{
    network net = parse_network_cfg(cfgfile);
    if(weightfile){
        load_weights(&net, weightfile);
    }
    set_batch_network(&net, 1);
    fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
    srand(time(0));
    char *base = "results/comp4_det_test_";
    list *plist = get_paths("/home/pjreddie/data/voc/test/2007_test.txt");
    char **paths = (char **)list_to_array(plist);
    layer l = net.layers[net.n-1];
    int classes = l.classes;
    int square = l.sqrt;
    int side = l.side;
    int j, k;
    FILE **fps = calloc(classes, sizeof(FILE *));
    for(j = 0; j < classes; ++j){
        char buff[1024];
        snprintf(buff, 1024, "%s%s.txt", base, voc_names[j]);
        fps[j] = fopen(buff, "w");
    }
    box *boxes = calloc(side*side*l.n, sizeof(box));
    float **probs = calloc(side*side*l.n, sizeof(float *));
    for(j = 0; j < side*side*l.n; ++j) probs[j] = calloc(classes, sizeof(float *));
    int m = plist->size;
    int i=0;
    float thresh = .001;
    int nms = 0;
    float iou_thresh = .5;
    float nms_thresh = .5;
    int total = 0;
    int correct = 0;
    int proposals = 0;
    float avg_iou = 0;
    for(i = 0; i < m; ++i){
        char *path = paths[i];
        image orig = load_image_color(path, 0, 0);
        image sized = resize_image(orig, net.w, net.h);
        char *id = basecfg(path);
        float *predictions = network_predict(net, sized.data);
        convert_yolo_detections(predictions, classes, l.n, square, side, 1, 1, thresh, probs, boxes, 1);
        if (nms) do_nms(boxes, probs, side*side*l.n, 1, nms_thresh);
        char *labelpath = find_replace(path, "images", "labels");
        labelpath = find_replace(labelpath, "JPEGImages", "labels");
        labelpath = find_replace(labelpath, ".jpg", ".txt");
        labelpath = find_replace(labelpath, ".JPEG", ".txt");
        int num_labels = 0;
        box_label *truth = read_boxes(labelpath, &num_labels);
        for(k = 0; k < side*side*l.n; ++k){
            if(probs[k][0] > thresh){
                ++proposals;
            }
        }
        for (j = 0; j < num_labels; ++j) {
            ++total;
            box t = {truth[j].x, truth[j].y, truth[j].w, truth[j].h};
            float best_iou = 0;
            for(k = 0; k < side*side*l.n; ++k){
                float iou = box_iou(boxes[k], t);
                if(probs[k][0] > thresh && iou > best_iou){
                    best_iou = iou;
                }
            }
            avg_iou += best_iou;
            if(best_iou > iou_thresh){
                ++correct;
            }
        }
        fprintf(stderr, "%5d %5d %5d\tRPs/Img: %.2f\tIOU: %.2f%%\tRecall:%.2f%%\n", i, correct, total, (float)proposals/(i+1), avg_iou*100/total, 100.*correct/total);
        free(id);
        free_image(orig);
        free_image(sized);
    }
}
void test_yolo(char *cfgfile, char *weightfile, char *filename, float thresh)
{
@@ -290,12 +365,18 @@
    if(weightfile){
        load_weights(&net, weightfile);
    }
    detection_layer layer = get_network_detection_layer(net);
    detection_layer l = net.layers[net.n-1];
    set_batch_network(&net, 1);
    srand(2222222);
    clock_t time;
    char buff[256];
    char *input = buff;
    int j;
    float nms=.5;
    printf("%d %d %d", l.side, l.n, l.classes);
    box *boxes = calloc(l.side*l.side*l.n, sizeof(box));
    float **probs = calloc(l.side*l.side*l.n, sizeof(float *));
    for(j = 0; j < l.side*l.side*l.n; ++j) probs[j] = calloc(l.classes, sizeof(float *));
    while(1){
        if(filename){
            strncpy(input, filename, 256);
@@ -312,7 +393,11 @@
        time=clock();
        float *predictions = network_predict(net, X);
        printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time));
        draw_yolo(im, predictions, 7, layer.objectness, "predictions", thresh);
        convert_yolo_detections(predictions, l.classes, l.n, l.sqrt, l.side, 1, 1, thresh, probs, boxes, 0);
        if (nms) do_nms_sort(boxes, probs, l.side*l.side*l.n, l.classes, nms);
        draw_yolo(im, l.side*l.side*l.n, thresh, boxes, probs, "predictions");
        show_image(sized, "resized");
        free_image(im);
        free_image(sized);
#ifdef OPENCV
@@ -323,6 +408,47 @@
    }
}
/*
#ifdef OPENCV
image ipl_to_image(IplImage* src);
#include "opencv2/highgui/highgui_c.h"
#include "opencv2/imgproc/imgproc_c.h"
void demo_swag(char *cfgfile, char *weightfile, float thresh)
{
network net = parse_network_cfg(cfgfile);
if(weightfile){
load_weights(&net, weightfile);
}
detection_layer layer = net.layers[net.n-1];
CvCapture *capture = cvCaptureFromCAM(-1);
set_batch_network(&net, 1);
srand(2222222);
while(1){
IplImage* frame = cvQueryFrame(capture);
image im = ipl_to_image(frame);
cvReleaseImage(&frame);
rgbgr_image(im);
image sized = resize_image(im, net.w, net.h);
float *X = sized.data;
float *predictions = network_predict(net, X);
draw_swag(im, predictions, layer.side, layer.n, "predictions", thresh);
free_image(im);
free_image(sized);
cvWaitKey(10);
}
}
#else
void demo_swag(char *cfgfile, char *weightfile, float thresh){}
#endif
 */
/* Live demo entry point. A definition is supplied here only for builds
   without GPU; otherwise it is expected to come from the GPU sources. */
void demo_yolo(char *cfgfile, char *weightfile, float thresh);
#ifndef GPU
/* Fallback for builds without GPU: report that the demo is unavailable
   instead of returning silently. */
void demo_yolo(char *cfgfile, char *weightfile, float thresh)
{
    (void)cfgfile;
    (void)weightfile;
    (void)thresh;
    fprintf(stderr, "yolo demo is not available: darknet was built without GPU defined.\n");
}
#endif
void run_yolo(int argc, char **argv)
{
    float thresh = find_float_arg(argc, argv, "-thresh", .2);
@@ -337,4 +463,6 @@
    if(0==strcmp(argv[2], "test")) test_yolo(cfg, weights, filename, thresh);
    else if(0==strcmp(argv[2], "train")) train_yolo(cfg, weights);
    else if(0==strcmp(argv[2], "valid")) validate_yolo(cfg, weights);
    else if(0==strcmp(argv[2], "recall")) validate_yolo_recall(cfg, weights);
    else if(0==strcmp(argv[2], "demo")) demo_yolo(cfg, weights, thresh);
}
src/yolo_kernels.cu
File was renamed from src/swag_kernels.cu
@@ -1,6 +1,5 @@
extern "C" {
#include "network.h"
#include "region_layer.h"
#include "detection_layer.h"
#include "cost_layer.h"
#include "utils.h"
@@ -13,16 +12,16 @@
#include "opencv2/highgui/highgui.hpp"
#include "opencv2/imgproc/imgproc.hpp"
extern "C" image ipl_to_image(IplImage* src);
extern "C" void convert_swag_detections(float *predictions, int classes, int num, int square, int side, int w, int h, float thresh, float **probs, box *boxes, int only_objectness);
extern "C" void draw_swag(image im, int num, float thresh, box *boxes, float **probs, char *label);
extern "C" void convert_yolo_detections(float *predictions, int classes, int num, int square, int side, int w, int h, float thresh, float **probs, box *boxes, int only_objectness);
extern "C" void draw_yolo(image im, int num, float thresh, box *boxes, float **probs, char *label);
extern "C" void demo_swag(char *cfgfile, char *weightfile, float thresh)
extern "C" void demo_yolo(char *cfgfile, char *weightfile, float thresh)
{
    network net = parse_network_cfg(cfgfile);
    if(weightfile){
        load_weights(&net, weightfile);
    }
    region_layer l = net.layers[net.n-1];
    detection_layer l = net.layers[net.n-1];
    cv::VideoCapture cap(0);
    set_batch_network(&net, 1);
@@ -43,12 +42,12 @@
        image sized = resize_image(im, net.w, net.h);
        float *X = sized.data;
        float *predictions = network_predict(net, X);
        convert_swag_detections(predictions, l.classes, l.n, l.sqrt, l.side, 1, 1, thresh, probs, boxes, 0);
        convert_yolo_detections(predictions, l.classes, l.n, l.sqrt, l.side, 1, 1, thresh, probs, boxes, 0);
        if (nms > 0) do_nms(boxes, probs, l.side*l.side*l.n, l.classes, nms);
        printf("\033[2J");
        printf("\033[1;1H");
        printf("\nObjects:\n\n");
        draw_swag(im, l.side*l.side*l.n, thresh, boxes, probs, "predictions");
        draw_yolo(im, l.side*l.side*l.n, thresh, boxes, probs, "predictions");
        free_image(im);
        free_image(sized);
@@ -56,6 +55,6 @@
    }
}
#else
extern "C" void demo_swag(char *cfgfile, char *weightfile, float thresh){}
extern "C" void demo_yolo(char *cfgfile, char *weightfile, float thresh){}
#endif