Joseph Redmon
2015-03-04 fb9e0fe33681280112e4e33939c5844dba994dca
Big changes to detection
17 files modified
425 lines changed
.gitignore 6 ●●●●● patch | view | raw | blame | history
Makefile 4 ●●●● patch | view | raw | blame | history
src/cost_layer.c 33 ●●●●● patch | view | raw | blame | history
src/cost_layer.h 4 ●●● patch | view | raw | blame | history
src/cuda.c 2 ●●●●● patch | view | raw | blame | history
src/darknet.c 60 ●●●●● patch | view | raw | blame | history
src/data.c 11 ●●●● patch | view | raw | blame | history
src/detection_layer.c 145 ●●●●● patch | view | raw | blame | history
src/detection_layer.h 46 ●●●●● patch | view | raw | blame | history
src/image.c 2 ●●● patch | view | raw | blame | history
src/network.c 35 ●●●● patch | view | raw | blame | history
src/network.h 3 ●●●● patch | view | raw | blame | history
src/network_kernels.cu 28 ●●●●● patch | view | raw | blame | history
src/option_list.c 7 ●●●●● patch | view | raw | blame | history
src/option_list.h 1 ●●●● patch | view | raw | blame | history
src/parser.c 37 ●●●●● patch | view | raw | blame | history
src/softmax_layer.h 1 ●●●● patch | view | raw | blame | history
.gitignore
@@ -2,12 +2,18 @@
*.dSYM
*.csv
*.out
*.png
*.sh
mnist/
data/
caffe/
grasp/
images/
opencv/
convnet/
decaf/
submission/
cfg/
darknet
# OS Generated #
Makefile
@@ -9,7 +9,7 @@
CC=gcc
NVCC=nvcc
OPTS=-O3
LDFLAGS=`pkg-config --libs opencv` -lm -pthread
LDFLAGS=`pkg-config --libs opencv` -lm -pthread -lstdc++
COMMON=`pkg-config --cflags opencv` -I/usr/local/cuda/include/
CFLAGS=-Wall -Wfatal-errors
@@ -25,7 +25,7 @@
LDFLAGS+= -L/usr/local/cuda/lib64 -lcuda -lcudart -lcublas
endif
OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o normalization_layer.o parser.o option_list.o darknet.o
OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o normalization_layer.o parser.o option_list.o darknet.o detection_layer.o
ifeq ($(GPU), 1) 
OBJ+=convolutional_kernels.o deconvolutional_kernels.o activation_kernels.o im2col_kernels.o col2im_kernels.o blas_kernels.o crop_layer_kernels.o dropout_layer_kernels.o maxpool_layer_kernels.o softmax_layer_kernels.o network_kernels.o
endif
src/cost_layer.c
@@ -10,7 +10,6 @@
/*
 * Map a cost-type name read from the config file to its COST_TYPE value.
 * Recognised names are "sse" and "detection". Any other name falls back to
 * SSE after printing a warning to stderr.
 */
COST_TYPE get_cost_type(char *s)
{
    if (strcmp(s, "sse")==0) return SSE;
    if (strcmp(s, "detection")==0) return DETECTION;
    /* This looks up a cost type, not an activation; say so in the warning. */
    fprintf(stderr, "Couldn't find cost type %s, going with SSE\n", s);
    return SSE;
}
@@ -20,8 +19,6 @@
    switch(a){
        case SSE:
            return "sse";
        case DETECTION:
            return "detection";
    }
    return "sse";
}
@@ -41,17 +38,20 @@
    return layer;
}
/* Copy batch*inputs floats from layer.delta_gpu back into the host buffer layer.delta. */
void pull_cost_layer(cost_layer layer)
{
    int count = layer.batch*layer.inputs;
    cuda_pull_array(layer.delta_gpu, layer.delta, count);
}
/* Copy batch*inputs floats from the host buffer layer.delta into layer.delta_gpu. */
void push_cost_layer(cost_layer layer)
{
    int count = layer.batch*layer.inputs;
    cuda_push_array(layer.delta_gpu, layer.delta, count);
}
/*
 * CPU forward pass of the cost layer.
 * Returns immediately when truth is NULL. Otherwise it fills layer.delta with
 * (truth - input) over batch*inputs elements (copy of truth, then axpy with -1).
 * For DETECTION it masks parts of that delta. It finally stores the sum of
 * squared deltas in *layer.output.
 */
void forward_cost_layer(cost_layer layer, float *input, float *truth)
{
    if (!truth) return;
    copy_cpu(layer.batch*layer.inputs, truth, 1, layer.delta, 1);
    axpy_cpu(layer.batch*layer.inputs, -1, input, 1, layer.delta, 1);
    if(layer.type == DETECTION){
        /* Treat the data as consecutive groups of 25 values. In every group whose
           first truth value is 0, zero all deltas except the group's first one.
           NOTE(review): 25 looks like the old per-cell record length
           (4 coords + 20 classes + 1) - it is hard-coded here; confirm it still
           matches the truth layout produced by the data loader. */
        int i;
        for(i = 0; i < layer.batch*layer.inputs; ++i){
            if((i%25) && !truth[(i/25)*25]) layer.delta[i] = 0;
        }
    }
    /* Scalar cost: dot(delta, delta), i.e. the sum of squared errors. */
    *(layer.output) = dot_cpu(layer.batch*layer.inputs, layer.delta, 1, layer.delta, 1);
    //printf("cost: %f\n", *layer.output);
}
@@ -67,13 +67,20 @@
{
    if (!truth) return;
    /*
    float *in = calloc(layer.inputs*layer.batch, sizeof(float));
    float *t = calloc(layer.inputs*layer.batch, sizeof(float));
    cuda_pull_array(input, in, layer.batch*layer.inputs);
    cuda_pull_array(truth, t, layer.batch*layer.inputs);
    forward_cost_layer(layer, in, t);
    cuda_push_array(layer.delta_gpu, layer.delta, layer.batch*layer.inputs);
    free(in);
    free(t);
    */
    copy_ongpu(layer.batch*layer.inputs, truth, 1, layer.delta_gpu, 1);
    axpy_ongpu(layer.batch*layer.inputs, -1, input, 1, layer.delta_gpu, 1);
    if(layer.type==DETECTION){
        mask_ongpu(layer.inputs*layer.batch, layer.delta_gpu, truth, 25);
    }
    cuda_pull_array(layer.delta_gpu, layer.delta, layer.batch*layer.inputs);
    *(layer.output) = dot_cpu(layer.batch*layer.inputs, layer.delta, 1, layer.delta, 1);
    //printf("cost: %f\n", *layer.output);
src/cost_layer.h
@@ -2,12 +2,14 @@
#define COST_LAYER_H
typedef enum{
    SSE, DETECTION
    SSE
} COST_TYPE;
typedef struct {
    int inputs;
    int batch;
    int coords;
    int classes;
    float *delta;
    float *output;
    COST_TYPE type;
src/cuda.c
@@ -5,6 +5,7 @@
#include "cuda.h"
#include "utils.h"
#include "blas.h"
#include "assert.h"
#include <stdlib.h>
@@ -15,6 +16,7 @@
        const char *s = cudaGetErrorString(status);
        char buffer[256];
        printf("CUDA Error: %s\n", s);
        assert(0);
        snprintf(buffer, 256, "CUDA Error: %s", s);
        error(buffer);
    } 
src/darknet.c
@@ -36,42 +36,30 @@
void draw_detection(image im, float *box, int side)
{
    int classes = 20;
    int elems = 4+classes+1;
    int elems = 4+classes;
    int j;
    int r, c;
    float amount[AMNT] = {0};
    for(r = 0; r < side*side; ++r){
        float val = box[r*elems];
        for(j = 0; j < AMNT; ++j){
            if(val > amount[j]) {
                float swap = val;
                val = amount[j];
                amount[j] = swap;
            }
        }
    }
    float smallest = amount[AMNT-1];
    for(r = 0; r < side; ++r){
        for(c = 0; c < side; ++c){
            j = (r*side + c) * elems;
            //printf("%d\n", j);
            //printf("Prob: %f\n", box[j]);
            if(box[j] >= smallest){
                int class = max_index(box+j+1, classes);
                int z;
                for(z = 0; z < classes; ++z) printf("%f %s\n", box[j+1+z], class_names[z]);
                printf("%f %s\n", box[j+1+class], class_names[class]);
            int class = max_index(box+j, classes);
            if(box[j+class] > .02 || 1){
                //int z;
                //for(z = 0; z < classes; ++z) printf("%f %s\n", box[j+z], class_names[z]);
                printf("%f %s\n", box[j+class], class_names[class]);
                float red = get_color(0,class,classes);
                float green = get_color(1,class,classes);
                float blue = get_color(2,class,classes);
                j += classes;
                int d = im.w/side;
                int y = r*d+box[j+1]*d;
                int x = c*d+box[j+2]*d;
                int h = box[j+3]*im.h;
                int w = box[j+4]*im.w;
                int y = r*d+box[j]*d;
                int x = c*d+box[j+1]*d;
                int h = box[j+2]*im.h;
                int w = box[j+3]*im.w;
                draw_box(im, x-w/2, y-h/2, x+w/2, y+h/2,red,green,blue);
            }
        }
@@ -117,14 +105,15 @@
    data train, buffer;
    int im_dim = 512;
    int jitter = 64;
    pthread_t load_thread = load_data_detection_thread(imgs, paths, plist->size, 20, im_dim, im_dim, 7, 7, jitter, &buffer);
    int classes = 21;
    pthread_t load_thread = load_data_detection_thread(imgs, paths, plist->size, classes, im_dim, im_dim, 7, 7, jitter, &buffer);
    clock_t time;
    while(1){
        i += 1;
        time=clock();
        pthread_join(load_thread, 0);
        train = buffer;
        load_thread = load_data_detection_thread(imgs, paths, plist->size, 20, im_dim, im_dim, 7, 7, jitter, &buffer);
        load_thread = load_data_detection_thread(imgs, paths, plist->size, classes, im_dim, im_dim, 7, 7, jitter, &buffer);
/*
        image im = float_to_image(im_dim - jitter, im_dim-jitter, 3, train.X.vals[0]);
@@ -139,7 +128,7 @@
        net.seen += imgs;
        avg_loss = avg_loss*.9 + loss*.1;
        printf("%d: %f, %f avg, %lf seconds, %d images\n", i, loss, avg_loss, sec(clock()-time), i*imgs);
        if(i%800==0){
        if(i%100==0){
            char buff[256];
            sprintf(buff, "/home/pjreddie/imagenet_backup/%s_%d.weights",base, i);
            save_weights(net, buff);
@@ -161,7 +150,7 @@
    char **paths = (char **)list_to_array(plist);
    int num_output = 1225;
    int im_size = 448;
    int classes = 20;
    int classes = 21;
    int m = plist->size;
    int i = 0;
@@ -185,7 +174,7 @@
        matrix pred = network_predict_data(net, val);
        int j, k, class;
        for(j = 0; j < pred.rows; ++j){
            for(k = 0; k < pred.cols; k += classes+4+1){
            for(k = 0; k < pred.cols; k += classes+4){
                /*
                int z;
@@ -193,17 +182,16 @@
                printf("\n");
                */
                float p = pred.vals[j][k];
                //if (pred.vals[j][k] > .001){
                for(class = 0; class < classes; ++class){
                    int index = (k)/(classes+4+1);
                for(class = 0; class < classes-1; ++class){
                    int index = (k)/(classes+4);
                    int r = index/7;
                    int c = index%7;
                    float y = (r + pred.vals[j][k+1+classes])/7.;
                    float x = (c + pred.vals[j][k+2+classes])/7.;
                    float h = pred.vals[j][k+3+classes];
                    float w = pred.vals[j][k+4+classes];
                    printf("%d %d %f %f %f %f %f\n", (i-1)*m/splits + j, class, p*pred.vals[j][k+class+1], y, x, h, w);
                    float y = (r + pred.vals[j][k+0+classes])/7.;
                    float x = (c + pred.vals[j][k+1+classes])/7.;
                    float h = pred.vals[j][k+2+classes];
                    float w = pred.vals[j][k+3+classes];
                    printf("%d %d %f %f %f %f %f\n", (i-1)*m/splits + j, class, pred.vals[j][k+class], y, x, h, w);
                }
                //}
            }
@@ -462,7 +450,7 @@
    if(weightfile){
        load_weights(&net, weightfile);
    }
    int im_size = 224;
    int im_size = 448;
    set_batch_network(&net, 1);
    srand(2222222);
    clock_t time;
src/data.c
@@ -89,8 +89,7 @@
        float dw = (x - i*box_width)/box_width;
        float dh = (y - j*box_height)/box_height;
        //printf("%d %d %d %f %f\n", id, i, j, dh, dw);
        int index = (i+j*num_width)*(4+classes+1);
        truth[index++] = 1;
        int index = (i+j*num_width)*(4+classes);
        truth[index+id] = 1;
        index += classes;
        truth[index++] = dh;
@@ -98,6 +97,12 @@
        truth[index++] = h*(height+jitter)/height;
        truth[index++] = w*(width+jitter)/width;
    }
    int i, j;
    for(i = 0; i < num_height*num_width*(4+classes); i += 4+classes){
        int background = 1;
        for(j = i; j < i+classes; ++j) if (truth[j]) background = 0;
        truth[i+classes-1] = background;
    }
    fclose(file);
}
@@ -209,7 +214,7 @@
    data d;
    d.shallow = 0;
    d.X = load_image_paths(random_paths, n, h, w);
    int k = nh*nw*(4+classes+1);
    int k = nh*nw*(4+classes);
    d.y = make_matrix(n, k);
    for(i = 0; i < n; ++i){
        int dx = rand()%jitter;
src/detection_layer.c
@@ -1,72 +1,123 @@
int detection_out_height(detection_layer layer)
#include "detection_layer.h"
#include "activations.h"
#include "softmax_layer.h"
#include "blas.h"
#include "cuda.h"
#include <stdio.h>
#include <stdlib.h>
int get_detection_layer_locations(detection_layer layer)
{
    return layer.size + layer.h*layer.stride;
    return layer.inputs / (layer.classes+layer.coords+layer.rescore);
}
int detection_out_width(detection_layer layer)
int get_detection_layer_output_size(detection_layer layer)
{
    return layer.size + layer.w*layer.stride;
    return get_detection_layer_locations(layer)*(layer.classes+layer.coords);
}
detection_layer *make_detection_layer(int batch, int h, int w, int c, int n, int size, int stride, ACTIVATION activation)
detection_layer *make_detection_layer(int batch, int inputs, int classes, int coords, int rescore)
{
    int i;
    size = 2*(size/2)+1; //HA! And you thought you'd use an even sized filter...
    detection_layer *layer = calloc(1, sizeof(detection_layer));
    layer->h = h;
    layer->w = w;
    layer->c = c;
    layer->n = n;
    layer->batch = batch;
    layer->stride = stride;
    layer->size = size;
    assert(c%n == 0);
    layer->inputs = inputs;
    layer->classes = classes;
    layer->coords = coords;
    layer->rescore = rescore;
    int outputs = get_detection_layer_output_size(*layer);
    layer->output = calloc(batch*outputs, sizeof(float));
    layer->delta = calloc(batch*outputs, sizeof(float));
    #ifdef GPU
    layer->output_gpu = cuda_make_array(0, batch*outputs);
    layer->delta_gpu = cuda_make_array(0, batch*outputs);
    #endif
    layer->filters = calloc(c*size*size, sizeof(float));
    layer->filter_updates = calloc(c*size*size, sizeof(float));
    layer->filter_momentum = calloc(c*size*size, sizeof(float));
    float scale = 1./(size*size*c);
    for(i = 0; i < c*n*size*size; ++i) layer->filters[i] = scale*(rand_uniform());
    int out_h = detection_out_height(*layer);
    int out_w = detection_out_width(*layer);
    layer->output = calloc(layer->batch * out_h * out_w * n, sizeof(float));
    layer->delta  = calloc(layer->batch * out_h * out_w * n, sizeof(float));
    layer->activation = activation;
    fprintf(stderr, "Convolutional Layer: %d x %d x %d image, %d filters -> %d x %d x %d image\n", h,w,c,n, out_h, out_w, n);
    fprintf(stderr, "Detection Layer\n");
    srand(0);
    return layer;
}
void forward_detection_layer(const detection_layer layer, float *in)
void forward_detection_layer(const detection_layer layer, float *in, float *truth)
{
    int out_h = detection_out_height(layer);
    int out_w = detection_out_width(layer);
    int i,j,fh, fw,c;
    memset(layer.output, 0, layer->batch*layer->n*out_h*out_w*sizeof(float));
    for(c = 0; c < layer.c; ++c){
        for(i = 0; i < layer.h; ++i){
            for(j = 0; j < layer.w; ++j){
                float val = layer->input[j+(i + c*layer.h)*layer.w];
                for(fh = 0; fh < layer.size; ++fh){
                    for(fw = 0; fw < layer.size; ++fw){
                        int h = i*layer.stride + fh;
                        int w = j*layer.stride + fw;
                        layer.output[w+(h+c/n*out_h)*out_w] += val*layer->filters[fw+(fh+c*layer.size)*layer.size];
    int in_i = 0;
    int out_i = 0;
    int locations = get_detection_layer_locations(layer);
    int i,j;
    for(i = 0; i < layer.batch*locations; ++i){
        int mask = (!truth || !truth[out_i + layer.classes - 1]);
        float scale = 1;
        if(layer.rescore) scale = in[in_i++];
        for(j = 0; j < layer.classes; ++j){
            layer.output[out_i++] = scale*in[in_i++];
                    }
        softmax_array(layer.output + out_i - layer.classes, layer.classes, layer.output + out_i - layer.classes);
        activate_array(layer.output+out_i, layer.coords, SIGMOID);
        for(j = 0; j < layer.coords; ++j){
            layer.output[out_i++] = mask*in[in_i++];
                }
            }
        }
        //printf("%d\n", mask);
        //for(j = 0; j < layer.classes+layer.coords; ++j) printf("%f ", layer.output[i*(layer.classes+layer.coords)+j]);
        //printf ("\n");
    }
}
void backward_detection_layer(const detection_layer layer, float *delta)
void backward_detection_layer(const detection_layer layer, float *in, float *delta)
{
    int locations = get_detection_layer_locations(layer);
    int i,j;
    int in_i = 0;
    int out_i = 0;
    for(i = 0; i < layer.batch*locations; ++i){
        float scale = 1;
        float latent_delta = 0;
        if(layer.rescore) scale = in[in_i++];
        for(j = 0; j < layer.classes; ++j){
            latent_delta += in[in_i]*layer.delta[out_i];
            delta[in_i++] = scale*layer.delta[out_i++];
}
        for(j = 0; j < layer.coords; ++j){
            delta[in_i++] = layer.delta[out_i++];
        }
        gradient_array(in + in_i - layer.coords, layer.coords, SIGMOID, layer.delta + out_i - layer.coords);
        if(layer.rescore) delta[in_i-layer.coords-layer.classes-layer.rescore] = latent_delta;
    }
}
#ifdef GPU
void forward_detection_layer_gpu(const detection_layer layer, float *in, float *truth)
{
    int outputs = get_detection_layer_output_size(layer);
    float *in_cpu = calloc(layer.batch*layer.inputs, sizeof(float));
    float *truth_cpu = 0;
    if(truth){
        truth_cpu = calloc(layer.batch*outputs, sizeof(float));
        cuda_pull_array(truth, truth_cpu, layer.batch*outputs);
    }
    cuda_pull_array(in, in_cpu, layer.batch*layer.inputs);
    forward_detection_layer(layer, in_cpu, truth_cpu);
    cuda_push_array(layer.output_gpu, layer.output, layer.batch*outputs);
    free(in_cpu);
    if(truth_cpu) free(truth_cpu);
}
/*
 * GPU entry point for the detection layer's backward pass.
 * It round-trips through the CPU implementation:
 *   - the layer input (batch*inputs floats) is pulled into a host buffer, and
 *     layer.delta_gpu (batch*output_size floats) is pulled into layer.delta;
 *   - backward_detection_layer computes the input gradient into a host buffer;
 *   - that gradient is uploaded to `delta`.
 *
 * delta: device buffer of batch*inputs floats that receives the input gradient.
 *        It may be NULL. NOTE(review): presumably the case when this is the
 *        first layer of the network; confirm against backward_network_gpu.
 *        When it is NULL the upload is skipped instead of copying to a null
 *        device pointer.
 */
void backward_detection_layer_gpu(detection_layer layer, float *in, float *delta)
{
    int outputs = get_detection_layer_output_size(layer);
    float *in_cpu =    calloc(layer.batch*layer.inputs, sizeof(float));
    float *delta_cpu = calloc(layer.batch*layer.inputs, sizeof(float));
    cuda_pull_array(in, in_cpu, layer.batch*layer.inputs);
    cuda_pull_array(layer.delta_gpu, layer.delta, layer.batch*outputs);
    backward_detection_layer(layer, in_cpu, delta_cpu);
    /* Nowhere to send the gradient if the caller passed no device buffer. */
    if(delta) cuda_push_array(delta, delta_cpu, layer.batch*layer.inputs);
    free(in_cpu);
    free(delta_cpu);
}
#endif
src/detection_layer.h
@@ -3,38 +3,26 @@
typedef struct {
    int batch;
    int h,w,c;
    int n;
    int size;
    int stride;
    float *filters;
    float *filter_updates;
    float *filter_momentum;
    float *biases;
    float *bias_updates;
    float *bias_momentum;
    float *col_image;
    float *delta;
    int inputs;
    int classes;
    int coords;
    int rescore;
    float *output;
    float *delta;
    #ifdef GPU
    float * output_gpu;
    float * delta_gpu;
    #endif
} detection_layer;
detection_layer *make_detection_layer(int batch, int inputs, int classes, int coords, int rescore);
void forward_detection_layer(const detection_layer layer, float *in, float *truth);
void backward_detection_layer(const detection_layer layer, float *in, float *delta);
int get_detection_layer_output_size(detection_layer layer);
    #ifdef GPU
    cl_mem filters_cl;
    cl_mem filter_updates_cl;
    cl_mem filter_momentum_cl;
    cl_mem biases_cl;
    cl_mem bias_updates_cl;
    cl_mem bias_momentum_cl;
    cl_mem col_image_cl;
    cl_mem delta_cl;
    cl_mem output_cl;
void forward_detection_layer_gpu(const detection_layer layer, float *in, float *truth);
void backward_detection_layer_gpu(detection_layer layer, float *in, float *delta);
    #endif
    ACTIVATION activation;
} convolutional_layer;
#endif
src/image.c
@@ -13,7 +13,7 @@
    int j = ceil(ratio);
    ratio -= i;
    float r = (1-ratio) * colors[i][c] + ratio*colors[j][c];
    printf("%f\n", r);
    //printf("%f\n", r);
    return r;
}
src/network.c
@@ -9,6 +9,7 @@
#include "connected_layer.h"
#include "convolutional_layer.h"
#include "deconvolutional_layer.h"
#include "detection_layer.h"
#include "maxpool_layer.h"
#include "cost_layer.h"
#include "normalization_layer.h"
@@ -29,6 +30,8 @@
            return "maxpool";
        case SOFTMAX:
            return "softmax";
        case DETECTION:
            return "detection";
        case NORMALIZATION:
            return "normalization";
        case DROPOUT:
@@ -76,6 +79,11 @@
            forward_deconvolutional_layer(layer, input);
            input = layer.output;
        }
        else if(net.types[i] == DETECTION){
            detection_layer layer = *(detection_layer *)net.layers[i];
            forward_detection_layer(layer, input, truth);
            input = layer.output;
        }
        else if(net.types[i] == CONNECTED){
            connected_layer layer = *(connected_layer *)net.layers[i];
            forward_connected_layer(layer, input);
@@ -152,6 +160,9 @@
    } else if(net.types[i] == MAXPOOL){
        maxpool_layer layer = *(maxpool_layer *)net.layers[i];
        return layer.output;
    } else if(net.types[i] == DETECTION){
        detection_layer layer = *(detection_layer *)net.layers[i];
        return layer.output;
    } else if(net.types[i] == SOFTMAX){
        softmax_layer layer = *(softmax_layer *)net.layers[i];
        return layer.output;
@@ -193,6 +204,9 @@
    } else if(net.types[i] == SOFTMAX){
        softmax_layer layer = *(softmax_layer *)net.layers[i];
        return layer.delta;
    } else if(net.types[i] == DETECTION){
        detection_layer layer = *(detection_layer *)net.layers[i];
        return layer.delta;
    } else if(net.types[i] == DROPOUT){
        if(i == 0) return 0;
        return get_network_delta_layer(net, i-1);
@@ -243,7 +257,7 @@
    return max_index(out, k);
}
void backward_network(network net, float *input)
void backward_network(network net, float *input, float *truth)
{
    int i;
    float *prev_input;
@@ -272,6 +286,10 @@
            dropout_layer layer = *(dropout_layer *)net.layers[i];
            backward_dropout_layer(layer, prev_delta);
        }
        else if(net.types[i] == DETECTION){
            detection_layer layer = *(detection_layer *)net.layers[i];
            backward_detection_layer(layer, prev_input, prev_delta);
        }
        else if(net.types[i] == NORMALIZATION){
            normalization_layer layer = *(normalization_layer *)net.layers[i];
            if(i != 0) backward_normalization_layer(layer, prev_input, prev_delta);
@@ -297,7 +315,7 @@
    if(gpu_index >= 0) return train_network_datum_gpu(net, x, y);
    #endif
    forward_network(net, x, y, 1);
    backward_network(net, x);
    backward_network(net, x, y);
    float error = get_network_cost(net);
    update_network(net);
    return error;
@@ -351,7 +369,7 @@
            float *x = d.X.vals[index];
            float *y = d.y.vals[index];
            forward_network(net, x, y, 1);
            backward_network(net, x);
            backward_network(net, x, y);
            sum += get_network_cost(net);
        }
        update_network(net);
@@ -381,7 +399,6 @@
    }
}
void set_batch_network(network *net, int b)
{
    net->batch = b;
@@ -404,6 +421,9 @@
        } else if(net->types[i] == DROPOUT){
            dropout_layer *layer = (dropout_layer *) net->layers[i];
            layer->batch = b;
        } else if(net->types[i] == DETECTION){
            detection_layer *layer = (detection_layer *) net->layers[i];
            layer->batch = b;
        }
        else if(net->types[i] == FREEWEIGHT){
            freeweight_layer *layer = (freeweight_layer *) net->layers[i];
@@ -445,6 +465,9 @@
    } else if(net.types[i] == DROPOUT){
        dropout_layer layer = *(dropout_layer *) net.layers[i];
        return layer.inputs;
    } else if(net.types[i] == DETECTION){
        detection_layer layer = *(detection_layer *) net.layers[i];
        return layer.inputs;
    } else if(net.types[i] == CROP){
        crop_layer layer = *(crop_layer *) net.layers[i];
        return layer.c*layer.h*layer.w;
@@ -473,6 +496,10 @@
        image output = get_deconvolutional_image(layer);
        return output.h*output.w*output.c;
    }
    else if(net.types[i] == DETECTION){
        detection_layer layer = *(detection_layer *)net.layers[i];
        return get_detection_layer_output_size(layer);
    }
    else if(net.types[i] == MAXPOOL){
        maxpool_layer layer = *(maxpool_layer *)net.layers[i];
        image output = get_maxpool_image(layer);
src/network.h
@@ -11,6 +11,7 @@
    CONNECTED,
    MAXPOOL,
    SOFTMAX,
    DETECTION,
    NORMALIZATION,
    DROPOUT,
    FREEWEIGHT,
@@ -48,7 +49,7 @@
network make_network(int n, int batch);
void forward_network(network net, float *input, float *truth, int train);
void backward_network(network net, float *input);
void backward_network(network net, float *input, float *truth);
void update_network(network net);
float train_network(network net, data d);
src/network_kernels.cu
@@ -9,6 +9,7 @@
#include "crop_layer.h"
#include "connected_layer.h"
#include "detection_layer.h"
#include "convolutional_layer.h"
#include "deconvolutional_layer.h"
#include "maxpool_layer.h"
@@ -47,6 +48,11 @@
            forward_connected_layer_gpu(layer, input);
            input = layer.output_gpu;
        }
        else if(net.types[i] == DETECTION){
            detection_layer layer = *(detection_layer *)net.layers[i];
            forward_detection_layer_gpu(layer, input, truth);
            input = layer.output_gpu;
        }
        else if(net.types[i] == MAXPOOL){
            maxpool_layer layer = *(maxpool_layer *)net.layers[i];
            forward_maxpool_layer_gpu(layer, input);
@@ -73,7 +79,7 @@
    }
}
void backward_network_gpu(network net, float * input)
void backward_network_gpu(network net, float * input, float *truth)
{
    int i;
    float * prev_input;
@@ -103,6 +109,10 @@
            connected_layer layer = *(connected_layer *)net.layers[i];
            backward_connected_layer_gpu(layer, prev_input, prev_delta);
        }
        else if(net.types[i] == DETECTION){
            detection_layer layer = *(detection_layer *)net.layers[i];
            backward_detection_layer_gpu(layer, prev_input, prev_delta);
        }
        else if(net.types[i] == MAXPOOL){
            maxpool_layer layer = *(maxpool_layer *)net.layers[i];
            backward_maxpool_layer_gpu(layer, prev_delta);
@@ -148,6 +158,10 @@
        deconvolutional_layer layer = *(deconvolutional_layer *)net.layers[i];
        return layer.output_gpu;
    }
    else if(net.types[i] == DETECTION){
        detection_layer layer = *(detection_layer *)net.layers[i];
        return layer.output_gpu;
    }
    else if(net.types[i] == CONNECTED){
        connected_layer layer = *(connected_layer *)net.layers[i];
        return layer.output_gpu;
@@ -176,6 +190,10 @@
        convolutional_layer layer = *(convolutional_layer *)net.layers[i];
        return layer.delta_gpu;
    }
    else if(net.types[i] == DETECTION){
        detection_layer layer = *(detection_layer *)net.layers[i];
        return layer.delta_gpu;
    }
    else if(net.types[i] == DECONVOLUTIONAL){
        deconvolutional_layer layer = *(deconvolutional_layer *)net.layers[i];
        return layer.delta_gpu;
@@ -215,7 +233,7 @@
    forward_network_gpu(net, *net.input_gpu, *net.truth_gpu, 1);
  //printf("forw %f\n", sec(clock() - time));
  //time = clock();
    backward_network_gpu(net, *net.input_gpu);
    backward_network_gpu(net, *net.input_gpu, *net.truth_gpu);
  //printf("back %f\n", sec(clock() - time));
  //time = clock();
    update_network_gpu(net);
@@ -244,6 +262,12 @@
        cuda_pull_array(layer.output_gpu, layer.output, layer.outputs*layer.batch);
        return layer.output;
    }
    else if(net.types[i] == DETECTION){
        detection_layer layer = *(detection_layer *)net.layers[i];
        int outputs = get_detection_layer_output_size(layer);
        cuda_pull_array(layer.output_gpu, layer.output, outputs*layer.batch);
        return layer.output;
    }
    else if(net.types[i] == MAXPOOL){
        maxpool_layer layer = *(maxpool_layer *)net.layers[i];
        return layer.output;
src/option_list.c
@@ -53,6 +53,13 @@
    return def;
}
/*
 * Look up `key` in option list `l` and parse its value with atoi.
 * Returns `def` when option_find finds no value. This function itself
 * prints nothing.
 */
int option_find_int_quiet(list *l, char *key, int def)
{
    char *value = option_find(l, key);
    return value ? atoi(value) : def;
}
float option_find_float_quiet(list *l, char *key, float def)
{
    char *v = option_find(l, key);
src/option_list.h
@@ -13,6 +13,7 @@
char *option_find(list *l, char *key);
char *option_find_str(list *l, char *key, char *def);
int option_find_int(list *l, char *key, int def);
int option_find_int_quiet(list *l, char *key, int def);
float option_find_float(list *l, char *key, float def);
float option_find_float_quiet(list *l, char *key, float def);
void option_unused(list *l);
src/parser.c
@@ -13,6 +13,7 @@
#include "normalization_layer.h"
#include "softmax_layer.h"
#include "dropout_layer.h"
#include "detection_layer.h"
#include "freeweight_layer.h"
#include "list.h"
#include "option_list.h"
@@ -32,6 +33,7 @@
int is_softmax(section *s);
int is_crop(section *s);
int is_cost(section *s);
int is_detection(section *s);
int is_normalization(section *s);
list *read_cfg(char *filename);
@@ -204,6 +206,24 @@
    return layer;
}
/*
 * Build a detection layer from a [detection] config section.
 * The input size depends on the layer's position:
 *   - first layer (count == 0): the size comes from the "input" option
 *     (default 1), and "batch"/"seen" are read into net;
 *   - later layers: the size is the previous layer's output size.
 * "coords", "classes" and "rescore" each default to 1. Unused options are
 * reported via option_unused.
 */
detection_layer *parse_detection(list *options, network *net, int count)
{
    int inputs;
    if(count != 0){
        inputs = get_network_output_size_layer(*net, count-1);
    }else{
        inputs = option_find_int(options, "input",1);
        net->batch = option_find_int(options, "batch",1);
        net->seen = option_find_int(options, "seen",0);
    }

    int coords  = option_find_int(options, "coords", 1);
    int classes = option_find_int(options, "classes", 1);
    int rescore = option_find_int(options, "rescore", 1);

    detection_layer *parsed = make_detection_layer(net->batch, inputs, classes, coords, rescore);
    option_unused(options);
    return parsed;
}
cost_layer *parse_cost(list *options, network *net, int count)
{
    int input;
@@ -368,6 +388,10 @@
            cost_layer *layer = parse_cost(options, &net, count);
            net.types[count] = COST;
            net.layers[count] = layer;
        }else if(is_detection(s)){
            detection_layer *layer = parse_detection(options, &net, count);
            net.types[count] = DETECTION;
            net.layers[count] = layer;
        }else if(is_softmax(s)){
            softmax_layer *layer = parse_softmax(options, &net, count);
            net.types[count] = SOFTMAX;
@@ -410,6 +434,10 @@
{
    return (strcmp(s->type, "[cost]")==0);
}
/* Nonzero (1) when the section header is exactly "[detection]", else 0. */
int is_detection(section *s)
{
    return !strcmp(s->type, "[detection]");
}
int is_deconvolutional(section *s)
{
    return (strcmp(s->type, "[deconv]")==0
@@ -684,6 +712,13 @@
    fprintf(fp, "\n");
}
/*
 * Write a detection layer back out in cfg form: the "[detection]" header,
 * its classes/coords/rescore values, then a blank line.
 * `net` and `count` are not used here; they keep the signature consistent
 * with the other print_*_cfg functions.
 */
void print_detection_cfg(FILE *fp, detection_layer *l, network net, int count)
{
    int classes = l->classes;
    int coords = l->coords;
    int rescore = l->rescore;
    fprintf(fp, "[detection]\n");
    fprintf(fp, "classes=%d\ncoords=%d\nrescore=%d\n", classes, coords, rescore);
    fprintf(fp, "\n");
}
void print_cost_cfg(FILE *fp, cost_layer *l, network net, int count)
{
    fprintf(fp, "[cost]\ntype=%s\n", get_cost_string(l->type));
@@ -815,6 +850,8 @@
            print_normalization_cfg(fp, (normalization_layer *)net.layers[i], net, i);
        else if(net.types[i] == SOFTMAX)
            print_softmax_cfg(fp, (softmax_layer *)net.layers[i], net, i);
        else if(net.types[i] == DETECTION)
            print_detection_cfg(fp, (detection_layer *)net.layers[i], net, i);
        else if(net.types[i] == COST)
            print_cost_cfg(fp, (cost_layer *)net.layers[i], net, i);
    }
src/softmax_layer.h
@@ -13,6 +13,7 @@
    #endif
} softmax_layer;
void softmax_array(float *input, int n, float *output);
softmax_layer *make_softmax_layer(int batch, int groups, int inputs);
void forward_softmax_layer(const softmax_layer layer, float *input);
void backward_softmax_layer(const softmax_layer layer, float *delta);