hey
Joseph Redmon
2016-11-16 0d6b107ed20c22412ccf3a5056cffdb35bc25534
hey
22 files modified
432 ■■■■ changed files
src/batchnorm_layer.c 8 ●●●● patch | view | raw | blame | history
src/blas.c 10 ●●●● patch | view | raw | blame | history
src/blas.h 4 ●●● patch | view | raw | blame | history
src/blas_kernels.cu 31 ●●●●● patch | view | raw | blame | history
src/convolutional_layer.c 8 ●●●●● patch | view | raw | blame | history
src/cuda.c 1 ●●●● patch | view | raw | blame | history
src/darknet.c 2 ●●● patch | view | raw | blame | history
src/data.c 10 ●●●●● patch | view | raw | blame | history
src/detector.c 88 ●●●● patch | view | raw | blame | history
src/layer.h 1 ●●●● patch | view | raw | blame | history
src/network.c 21 ●●●● patch | view | raw | blame | history
src/network_kernels.cu 3 ●●●● patch | view | raw | blame | history
src/parser.c 59 ●●●●● patch | view | raw | blame | history
src/region_layer.c 110 ●●●● patch | view | raw | blame | history
src/region_layer.h 1 ●●●● patch | view | raw | blame | history
src/reorg_layer.c 13 ●●●● patch | view | raw | blame | history
src/route_layer.c 34 ●●●●● patch | view | raw | blame | history
src/route_layer.h 1 ●●●● patch | view | raw | blame | history
src/tree.c 10 ●●●●● patch | view | raw | blame | history
src/tree.h 1 ●●●● patch | view | raw | blame | history
src/utils.c 15 ●●●●● patch | view | raw | blame | history
src/utils.h 1 ●●●● patch | view | raw | blame | history
src/batchnorm_layer.c
@@ -166,10 +166,10 @@
        fast_mean_gpu(l.output_gpu, l.batch, l.out_c, l.out_h*l.out_w, l.mean_gpu);
        fast_variance_gpu(l.output_gpu, l.mean_gpu, l.batch, l.out_c, l.out_h*l.out_w, l.variance_gpu);
        scal_ongpu(l.out_c, .95, l.rolling_mean_gpu, 1);
        axpy_ongpu(l.out_c, .05, l.mean_gpu, 1, l.rolling_mean_gpu, 1);
        scal_ongpu(l.out_c, .95, l.rolling_variance_gpu, 1);
        axpy_ongpu(l.out_c, .05, l.variance_gpu, 1, l.rolling_variance_gpu, 1);
        scal_ongpu(l.out_c, .99, l.rolling_mean_gpu, 1);
        axpy_ongpu(l.out_c, .01, l.mean_gpu, 1, l.rolling_mean_gpu, 1);
        scal_ongpu(l.out_c, .99, l.rolling_variance_gpu, 1);
        axpy_ongpu(l.out_c, .01, l.variance_gpu, 1, l.rolling_variance_gpu, 1);
        copy_ongpu(l.outputs*l.batch, l.output_gpu, 1, l.x_gpu, 1);
        normalize_gpu(l.output_gpu, l.mean_gpu, l.variance_gpu, l.batch, l.out_c, l.out_h*l.out_w);
src/blas.c
@@ -6,7 +6,7 @@
#include <stdlib.h>
#include <string.h>
void reorg(float *x, int size, int layers, int batch, int forward)
void flatten(float *x, int size, int layers, int batch, int forward)
{
    float *swap = calloc(size*layers*batch, sizeof(float));
    int i,c,b;
@@ -189,12 +189,12 @@
        if(input[i] > largest) largest = input[i];
    }
    for(i = 0; i < n; ++i){
        sum += exp(input[i]/temp-largest/temp);
        float e = exp(input[i]/temp - largest/temp);
        sum += e;
        output[i] = e;
    }
    if(sum) sum = largest/temp+log(sum);
    else sum = largest-100;
    for(i = 0; i < n; ++i){
        output[i] = exp(input[i]/temp-sum);
        output[i] /= sum;
    }
}
src/blas.h
@@ -1,6 +1,6 @@
#ifndef BLAS_H
#define BLAS_H
void reorg(float *x, int size, int layers, int batch, int forward);
void flatten(float *x, int size, int layers, int batch, int forward);
void pm(int M, int N, float *A);
float *random_matrix(int rows, int cols);
void time_random_matrix(int TA, int TB, int m, int k, int n);
@@ -80,5 +80,7 @@
void softmax_gpu(float *input, int n, int offset, int groups, float temp, float *output);
void adam_gpu(int n, float *x, float *m, float *v, float B1, float B2, float rate, float eps, int t);
void flatten_ongpu(float *x, int spatial, int layers, int batch, int forward, float *out);
#endif
#endif
src/blas_kernels.cu
@@ -543,6 +543,30 @@
    check_error(cudaPeekAtLastError());
}
__global__ void flatten_kernel(int N, float *x, int spatial, int layers, int batch, int forward, float *out)
{
    int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    if(i >= N) return;
    int in_s = i%spatial;
    i = i/spatial;
    int in_c = i%layers;
    i = i/layers;
    int b = i;
    int i1 = b*layers*spatial + in_c*spatial + in_s;
    int i2 = b*layers*spatial + in_s*layers +  in_c;
    if (forward) out[i2] = x[i1];
    else out[i1] = x[i2];
}
extern "C" void flatten_ongpu(float *x, int spatial, int layers, int batch, int forward, float *out)
{
    int size = spatial*batch*layers;
    flatten_kernel<<<cuda_gridsize(size), BLOCK>>>(size, x, spatial, layers, batch, forward, out);
    check_error(cudaPeekAtLastError());
}
extern "C" void reorg_ongpu(float *x, int w, int h, int c, int batch, int stride, int forward, float *out)
{
    int size = w*h*c*batch;
@@ -718,11 +742,12 @@
        largest = (val>largest) ? val : largest;
    }
    for(i = 0; i < n; ++i){
        sum += exp(input[i]/temp-largest/temp);
        float e = exp(input[i]/temp - largest/temp);
        sum += e;
        output[i] = e;
    }
    sum = (sum != 0) ? largest/temp+log(sum) : largest-100;
    for(i = 0; i < n; ++i){
        output[i] = exp(input[i]/temp-sum);
        output[i] /= sum;
    }
}
src/convolutional_layer.c
@@ -368,6 +368,14 @@
    l->delta_gpu =     cuda_make_array(l->delta, l->batch*out_h*out_w*l->n);
    l->output_gpu =    cuda_make_array(l->output, l->batch*out_h*out_w*l->n);
    if(l->batch_normalize){
        cuda_free(l->x_gpu);
        cuda_free(l->x_norm_gpu);
        l->x_gpu = cuda_make_array(l->output, l->batch*l->outputs);
        l->x_norm_gpu = cuda_make_array(l->output, l->batch*l->outputs);
    }
#ifdef CUDNN
    cudnn_convolutional_setup(l);
#endif
src/cuda.c
@@ -26,6 +26,7 @@
void check_error(cudaError_t status)
{
    //cudaDeviceSynchronize();
    cudaError_t status2 = cudaGetLastError();
    if (status != cudaSuccess)
    {   
src/darknet.c
@@ -127,7 +127,7 @@
    network net = parse_network_cfg(cfgfile);
    int oldn = net.layers[net.n - 2].n;
    int c = net.layers[net.n - 2].c;
    net.layers[net.n - 2].n = 7879;
    net.layers[net.n - 2].n = 9372;
    net.layers[net.n - 2].biases += 5;
    net.layers[net.n - 2].weights += 5*c;
    if(weightfile){
src/data.c
@@ -171,6 +171,13 @@
{
    int i;
    for(i = 0; i < n; ++i){
        if(boxes[i].x == 0 && boxes[i].y == 0) {
            boxes[i].x = 999999;
            boxes[i].y = 999999;
            boxes[i].w = 999999;
            boxes[i].h = 999999;
            continue;
        }
        boxes[i].left   = boxes[i].left  * sx - dx;
        boxes[i].right  = boxes[i].right * sx - dx;
        boxes[i].top    = boxes[i].top   * sy - dy;
@@ -289,6 +296,7 @@
    find_replace(path, "images", "labels", labelpath);
    find_replace(labelpath, "JPEGImages", "labels", labelpath);
    find_replace(labelpath, "raw", "labels", labelpath);
    find_replace(labelpath, ".jpg", ".txt", labelpath);
    find_replace(labelpath, ".png", ".txt", labelpath);
    find_replace(labelpath, ".JPG", ".txt", labelpath);
@@ -309,7 +317,7 @@
        h =  boxes[i].h;
        id = boxes[i].id;
        if (w < .01 || h < .01) continue;
        if ((w < .01 || h < .01)) continue;
        truth[i*5+0] = x;
        truth[i*5+1] = y;
src/detector.c
@@ -75,8 +75,27 @@
    pthread_t load_thread = load_data(args);
    clock_t time;
    int count = 0;
    //while(i*imgs < N*120){
    while(get_current_batch(net) < net.max_batches){
        if(l.random && count++%10 == 0){
            printf("Resizing\n");
            int dim = (rand() % 10 + 10) * 32;
            //int dim = (rand() % 4 + 16) * 32;
            printf("%d\n", dim);
            args.w = dim;
            args.h = dim;
            pthread_join(load_thread, 0);
            train = buffer;
            free_data(train);
            load_thread = load_data(args);
            for(i = 0; i < ngpus; ++i){
                resize_network(nets + i, dim, dim);
            }
            net = nets[0];
        }
        time=clock();
        pthread_join(load_thread, 0);
        train = buffer;
@@ -117,13 +136,15 @@
        i = get_current_batch(net);
        printf("%d: %f, %f avg, %f rate, %lf seconds, %d images\n", get_current_batch(net), loss, avg_loss, get_current_rate(net), sec(clock()-time), i*imgs);
        if(i%1000==0 || (i < 1000 && i%100 == 0)){
        if(i%100==0 || (i < 1000 && i%100 == 0)){
            if(ngpus != 1) sync_nets(nets, ngpus, 0);
            char buff[256];
            sprintf(buff, "%s/%s_%d.weights", backup_directory, base, i);
            save_weights(net, buff);
        }
        free_data(train);
    }
    if(ngpus != 1) sync_nets(nets, ngpus, 0);
    char buff[256];
    sprintf(buff, "%s/%s_final.weights", backup_directory, base);
    save_weights(net, buff);
@@ -183,6 +204,29 @@
    }
}
void print_imagenet_detections(FILE *fp, int id, box *boxes, float **probs, int total, int classes, int w, int h, int *map)
{
    int i, j;
    for(i = 0; i < total; ++i){
        float xmin = boxes[i].x - boxes[i].w/2.;
        float xmax = boxes[i].x + boxes[i].w/2.;
        float ymin = boxes[i].y - boxes[i].h/2.;
        float ymax = boxes[i].y + boxes[i].h/2.;
        if (xmin < 0) xmin = 0;
        if (ymin < 0) ymin = 0;
        if (xmax > w) xmax = w;
        if (ymax > h) ymax = h;
        for(j = 0; j < classes; ++j){
            int class = j;
            if (map) class = map[j];
            if (probs[i][class]) fprintf(fp, "%d %d %f %f %f %f %f\n", id, j+1, probs[i][class],
                    xmin, ymin, xmax, ymax);
        }
    }
}
void validate_detector(char *datacfg, char *cfgfile, char *weightfile)
{
    list *options = read_data_cfg(datacfg);
@@ -190,15 +234,25 @@
    char *name_list = option_find_str(options, "names", "data/names.list");
    char *prefix = option_find_str(options, "results", "results");
    char **names = get_labels(name_list);
    char *mapf = option_find_str(options, "map", 0);
    int *map = 0;
    if (mapf) map = read_map(mapf);
    char buff[1024];
    int coco = option_find_int_quiet(options, "coco", 0);
    FILE *coco_fp = 0;
    if(coco){
    char *type = option_find_str(options, "eval", "voc");
    FILE *fp = 0;
    int coco = 0;
    int imagenet = 0;
    if(0==strcmp(type, "coco")){
        snprintf(buff, 1024, "%s/coco_results.json", prefix);
        coco_fp = fopen(buff, "w");
        fprintf(coco_fp, "[\n");
        fp = fopen(buff, "w");
        fprintf(fp, "[\n");
        coco = 1;
    } else if(0==strcmp(type, "imagenet")){
        snprintf(buff, 1024, "%s/imagenet-detection.txt", prefix);
        fp = fopen(buff, "w");
        imagenet = 1;
    }
    network net = parse_network_cfg(cfgfile);
@@ -230,10 +284,10 @@
    int i=0;
    int t;
    float thresh = .001;
    float nms = .5;
    float thresh = .005;
    float nms = .45;
    int nthreads = 2;
    int nthreads = 4;
    image *val = calloc(nthreads, sizeof(image));
    image *val_resized = calloc(nthreads, sizeof(image));
    image *buf = calloc(nthreads, sizeof(image));
@@ -274,9 +328,11 @@
            int h = val[t].h;
            get_region_boxes(l, w, h, thresh, probs, boxes, 0);
            if (nms) do_nms_sort(boxes, probs, l.w*l.h*l.n, classes, nms);
            if(coco_fp){
                print_cocos(coco_fp, path, boxes, probs, l.w*l.h*l.n, classes, w, h);
            }else{
            if (coco){
                print_cocos(fp, path, boxes, probs, l.w*l.h*l.n, classes, w, h);
            } else if (imagenet){
                print_imagenet_detections(fp, i+t-nthreads+1 + 9741, boxes, probs, l.w*l.h*l.n, 200, w, h, map);
            } else {
                print_detector_detections(fps, id, boxes, probs, l.w*l.h*l.n, classes, w, h);
            }
            free(id);
@@ -287,10 +343,10 @@
    for(j = 0; j < classes; ++j){
        fclose(fps[j]);
    }
    if(coco_fp){
        fseek(coco_fp, -2, SEEK_CUR);
        fprintf(coco_fp, "\n]\n");
        fclose(coco_fp);
    if(coco){
        fseek(fp, -2, SEEK_CUR);
        fprintf(fp, "\n]\n");
        fclose(fp);
    }
    fprintf(stderr, "Total Detection Time: %f Seconds\n", (double)(time(0) - start));
}
src/layer.h
@@ -120,6 +120,7 @@
    int random;
    float thresh;
    int classfix;
    int absolute;
    int dontload;
    int dontloadscales;
src/network.c
@@ -41,7 +41,7 @@
    net.momentum = 0;
    net.decay = 0;
    #ifdef GPU
        if(gpu_index >= 0) update_network_gpu(net);
        //if(net.gpu_index >= 0) update_network_gpu(net);
    #endif
}
@@ -60,7 +60,7 @@
            for(i = 0; i < net.num_steps; ++i){
                if(net.steps[i] > batch_num) return rate;
                rate *= net.scales[i];
                if(net.steps[i] > batch_num - 1) reset_momentum(net);
                //if(net.steps[i] > batch_num - 1 && net.scales[i] > 1) reset_momentum(net);
            }
            return rate;
        case EXP:
@@ -321,6 +321,12 @@
int resize_network(network *net, int w, int h)
{
#ifdef GPU
    cuda_set_device(net->gpu_index);
    if(gpu_index >= 0){
        cuda_free(net->workspace);
    }
#endif
    int i;
    //if(w == net->w && h == net->h) return 0;
    net->w = w;
@@ -337,6 +343,10 @@
            resize_crop_layer(&l, w, h);
        }else if(l.type == MAXPOOL){
            resize_maxpool_layer(&l, w, h);
        }else if(l.type == REGION){
            resize_region_layer(&l, w, h);
        }else if(l.type == ROUTE){
            resize_route_layer(&l, net);
        }else if(l.type == REORG){
            resize_reorg_layer(&l, w, h);
        }else if(l.type == AVGPOOL){
@@ -357,7 +367,12 @@
    }
#ifdef GPU
    if(gpu_index >= 0){
        cuda_free(net->workspace);
        if(net->input_gpu) {
            cuda_free(*net->input_gpu);
            *net->input_gpu = 0;
            cuda_free(*net->truth_gpu);
            *net->truth_gpu = 0;
        }
        net->workspace = cuda_make_array(0, (workspace_size-1)/sizeof(float)+1);
    }else {
        free(net->workspace);
src/network_kernels.cu
@@ -78,6 +78,7 @@
void update_network_gpu(network net)
{
    cuda_set_device(net.gpu_index);
    int i;
    int update_batch = net.batch*net.subdivisions;
    float rate = get_current_rate(net);
@@ -377,7 +378,7 @@
float *get_network_output_layer_gpu(network net, int i)
{
    layer l = net.layers[i];
    cuda_pull_array(l.output_gpu, l.output, l.outputs*l.batch);
    if(l.type != REGION) cuda_pull_array(l.output_gpu, l.output, l.outputs*l.batch);
    return l.output;
}
src/parser.c
@@ -2,32 +2,32 @@
#include <string.h>
#include <stdlib.h>
#include "blas.h"
#include "parser.h"
#include "assert.h"
#include "activations.h"
#include "crop_layer.h"
#include "cost_layer.h"
#include "convolutional_layer.h"
#include "activation_layer.h"
#include "normalization_layer.h"
#include "batchnorm_layer.h"
#include "connected_layer.h"
#include "rnn_layer.h"
#include "gru_layer.h"
#include "crnn_layer.h"
#include "maxpool_layer.h"
#include "reorg_layer.h"
#include "softmax_layer.h"
#include "dropout_layer.h"
#include "detection_layer.h"
#include "region_layer.h"
#include "activations.h"
#include "assert.h"
#include "avgpool_layer.h"
#include "batchnorm_layer.h"
#include "blas.h"
#include "connected_layer.h"
#include "convolutional_layer.h"
#include "cost_layer.h"
#include "crnn_layer.h"
#include "crop_layer.h"
#include "detection_layer.h"
#include "dropout_layer.h"
#include "gru_layer.h"
#include "list.h"
#include "local_layer.h"
#include "maxpool_layer.h"
#include "normalization_layer.h"
#include "option_list.h"
#include "parser.h"
#include "region_layer.h"
#include "reorg_layer.h"
#include "rnn_layer.h"
#include "route_layer.h"
#include "shortcut_layer.h"
#include "list.h"
#include "option_list.h"
#include "softmax_layer.h"
#include "utils.h"
typedef struct{
@@ -232,21 +232,6 @@
    return layer;
}
int *read_map(char *filename)
{
    int n = 0;
    int *map = 0;
    char *str;
    FILE *file = fopen(filename, "r");
    if(!file) file_error(filename);
    while((str=fgetl(file))){
        ++n;
        map = realloc(map, n*sizeof(int));
        map[n-1] = atoi(str);
    }
    return map;
}
layer parse_region(list *options, size_params params)
{
    int coords = option_find_int(options, "coords", 4);
@@ -269,6 +254,8 @@
    l.thresh = option_find_float(options, "thresh", .5);
    l.classfix = option_find_int_quiet(options, "classfix", 0);
    l.absolute = option_find_int_quiet(options, "absolute", 0);
    l.random = option_find_int_quiet(options, "random", 0);
    l.coord_scale = option_find_float(options, "coord_scale", 1);
    l.object_scale = option_find_float(options, "object_scale", 1);
src/region_layer.c
@@ -9,6 +9,8 @@
#include <string.h>
#include <stdlib.h>
#define DOABS 1
region_layer make_region_layer(int batch, int w, int h, int n, int classes, int coords)
{
    region_layer l = {0};
@@ -48,7 +50,26 @@
    return l;
}
#define DOABS 1
void resize_region_layer(layer *l, int w, int h)
{
    l->w = w;
    l->h = h;
    l->outputs = h*w*l->n*(l->classes + l->coords + 1);
    l->inputs = l->outputs;
    l->output = realloc(l->output, l->batch*l->outputs*sizeof(float));
    l->delta = realloc(l->delta, l->batch*l->outputs*sizeof(float));
#ifdef GPU
    cuda_free(l->delta_gpu);
    cuda_free(l->output_gpu);
    l->delta_gpu =     cuda_make_array(l->delta, l->batch*l->outputs);
    l->output_gpu =    cuda_make_array(l->output, l->batch*l->outputs);
#endif
}
box get_region_box(float *x, float *biases, int n, int index, int i, int j, int w, int h)
{
    box b;
@@ -125,7 +146,9 @@
    int i,j,b,t,n;
    int size = l.coords + l.classes + 1;
    memcpy(l.output, state.input, l.outputs*l.batch*sizeof(float));
    reorg(l.output, l.w*l.h, size*l.n, l.batch, 1);
    #ifndef GPU
    flatten(l.output, l.w*l.h, size*l.n, l.batch, 1);
    #endif
    for (b = 0; b < l.batch; ++b){
        for(i = 0; i < l.h*l.w*l.n; ++i){
            int index = size*i + b*l.outputs;
@@ -134,25 +157,14 @@
    }
#ifndef GPU
    if (l.softmax_tree){
#ifdef GPU
        cuda_push_array(l.output_gpu, l.output, l.batch*l.outputs);
        int i;
        int count = 5;
        for (i = 0; i < l.softmax_tree->groups; ++i) {
            int group_size = l.softmax_tree->group_size[i];
            softmax_gpu(l.output_gpu+count, group_size, l.classes + 5, l.w*l.h*l.n*l.batch, 1, l.output_gpu + count);
            count += group_size;
        }
        cuda_pull_array(l.output_gpu, l.output, l.batch*l.outputs);
#else
        for (b = 0; b < l.batch; ++b){
            for(i = 0; i < l.h*l.w*l.n; ++i){
                int index = size*i + b*l.outputs;
                softmax_tree(l.output + index + 5, 1, 0, 1, l.softmax_tree, l.output + index + 5);
            }
        }
#endif
    } else if (l.softmax){
        for (b = 0; b < l.batch; ++b){
            for(i = 0; i < l.h*l.w*l.n; ++i){
@@ -161,6 +173,7 @@
            }
        }
    }
#endif
    if(!state.train) return;
    memset(l.delta, 0, l.outputs * l.batch * sizeof(float));
    float avg_iou = 0;
@@ -172,6 +185,32 @@
    int class_count = 0;
    *(l.cost) = 0;
    for (b = 0; b < l.batch; ++b) {
        if(l.softmax_tree){
            int onlyclass = 0;
            for(t = 0; t < 30; ++t){
                box truth = float_to_box(state.truth + t*5 + b*l.truths);
                if(!truth.x) break;
                int class = state.truth[t*5 + b*l.truths + 4];
                float maxp = 0;
                int maxi = 0;
                if(truth.x > 100000 && truth.y > 100000){
                    for(n = 0; n < l.n*l.w*l.h; ++n){
                        int index = size*n + b*l.outputs + 5;
                        float p = get_hierarchy_probability(l.output + index, l.softmax_tree, class);
                        if(p > maxp){
                            maxp = p;
                            maxi = n;
                        }
                    }
                    int index = size*maxi + b*l.outputs + 5;
                    delta_region_class(l.output, l.delta, index, class, l.classes, l.softmax_tree, l.class_scale, &avg_cat);
                    ++class_count;
                    onlyclass = 1;
                    break;
                }
            }
            if(onlyclass) continue;
        }
        for (j = 0; j < l.h; ++j) {
            for (i = 0; i < l.w; ++i) {
                for (n = 0; n < l.n; ++n) {
@@ -273,7 +312,9 @@
        }
    }
    //printf("\n");
    reorg(l.delta, l.w*l.h, size*l.n, l.batch, 0);
    #ifndef GPU
    flatten(l.delta, l.w*l.h, size*l.n, l.batch, 0);
    #endif
    *(l.cost) = pow(mag_array(l.delta, l.outputs * l.batch), 2);
    printf("Region Avg IOU: %f, Class: %f, Obj: %f, No Obj: %f, Avg Recall: %f,  count: %d\n", avg_iou/count, avg_cat/class_count, avg_obj/count, avg_anyobj/(l.w*l.h*l.n*l.batch), recall/count, count);
}
@@ -308,13 +349,18 @@
                hierarchy_predictions(predictions + class_index, l.classes, l.softmax_tree, 0);
                int found = 0;
                for(j = l.classes - 1; j >= 0; --j){
                    if(!found && predictions[class_index + j] > .5){
                        found = 1;
                    } else {
                        predictions[class_index + j] = 0;
                    if(1){
                        if(!found && predictions[class_index + j] > .5){
                            found = 1;
                        } else {
                            predictions[class_index + j] = 0;
                        }
                        float prob = predictions[class_index+j];
                        probs[index][j] = (scale > thresh) ? prob : 0;
                    }else{
                        float prob = scale*predictions[class_index+j];
                        probs[index][j] = (prob > thresh) ? prob : 0;
                    }
                    float prob = predictions[class_index+j];
                    probs[index][j] = (scale > thresh) ? prob : 0;
                }
            }else{
                for(j = 0; j < l.classes; ++j){
@@ -339,6 +385,18 @@
       return;
       }
     */
    flatten_ongpu(state.input, l.h*l.w, l.n*(l.coords + l.classes + 1), l.batch, 1, l.output_gpu);
    if(l.softmax_tree){
        int i;
        int count = 5;
        for (i = 0; i < l.softmax_tree->groups; ++i) {
            int group_size = l.softmax_tree->group_size[i];
            softmax_gpu(l.output_gpu+count, group_size, l.classes + 5, l.w*l.h*l.n*l.batch, 1, l.output_gpu + count);
            count += group_size;
        }
    }else if (l.softmax){
        softmax_gpu(l.output_gpu+5, l.classes, l.classes + 5, l.w*l.h*l.n*l.batch, 1, l.output_gpu + 5);
    }
    float *in_cpu = calloc(l.batch*l.inputs, sizeof(float));
    float *truth_cpu = 0;
@@ -347,22 +405,22 @@
        truth_cpu = calloc(num_truth, sizeof(float));
        cuda_pull_array(state.truth, truth_cpu, num_truth);
    }
    cuda_pull_array(state.input, in_cpu, l.batch*l.inputs);
    cuda_pull_array(l.output_gpu, in_cpu, l.batch*l.inputs);
    network_state cpu_state = state;
    cpu_state.train = state.train;
    cpu_state.truth = truth_cpu;
    cpu_state.input = in_cpu;
    forward_region_layer(l, cpu_state);
    cuda_push_array(l.output_gpu, l.output, l.batch*l.outputs);
    cuda_push_array(l.delta_gpu, l.delta, l.batch*l.outputs);
    //cuda_push_array(l.output_gpu, l.output, l.batch*l.outputs);
    free(cpu_state.input);
    if(!state.train) return;
    cuda_push_array(l.delta_gpu, l.delta, l.batch*l.outputs);
    if(cpu_state.truth) free(cpu_state.truth);
}
void backward_region_layer_gpu(region_layer l, network_state state)
{
    axpy_ongpu(l.batch*l.outputs, 1, l.delta_gpu, 1, state.delta, 1);
    //copy_ongpu(l.batch*l.inputs, l.delta_gpu, 1, state.delta, 1);
    flatten_ongpu(l.delta_gpu, l.h*l.w, l.n*(l.coords + l.classes + 1), l.batch, 0, state.delta);
}
#endif
src/region_layer.h
@@ -10,6 +10,7 @@
void forward_region_layer(const region_layer l, network_state state);
void backward_region_layer(const region_layer l, network_state state);
void get_region_boxes(layer l, int w, int h, float thresh, float **probs, box *boxes, int only_objectness);
void resize_region_layer(layer *l, int w, int h);
#ifdef GPU
void forward_region_layer_gpu(const region_layer l, network_state state);
src/reorg_layer.c
@@ -22,6 +22,7 @@
        l.out_h = h/stride;
        l.out_c = c*(stride*stride);
    }
    l.reverse = reverse;
    fprintf(stderr, "Reorg Layer: %d x %d x %d image -> %d x %d x %d image, \n", w,h,c,l.out_w, l.out_h, l.out_c);
    l.outputs = l.out_h * l.out_w * l.out_c;
    l.inputs = h*w*c;
@@ -44,12 +45,20 @@
void resize_reorg_layer(layer *l, int w, int h)
{
    int stride = l->stride;
    int c = l->c;
    l->h = h;
    l->w = w;
    l->out_w = w*stride;
    l->out_h = h*stride;
    if(l->reverse){
        l->out_w = w*stride;
        l->out_h = h*stride;
        l->out_c = c/(stride*stride);
    }else{
        l->out_w = w/stride;
        l->out_h = h/stride;
        l->out_c = c*(stride*stride);
    }
    l->outputs = l->out_h * l->out_w * l->out_c;
    l->inputs = l->outputs;
src/route_layer.c
@@ -36,6 +36,40 @@
    return l;
}
void resize_route_layer(route_layer *l, network *net)
{
    int i;
    layer first = net->layers[l->input_layers[0]];
    l->out_w = first.out_w;
    l->out_h = first.out_h;
    l->out_c = first.out_c;
    l->outputs = first.outputs;
    l->input_sizes[0] = first.outputs;
    for(i = 1; i < l->n; ++i){
        int index = l->input_layers[i];
        layer next = net->layers[index];
        l->outputs += next.outputs;
        l->input_sizes[i] = next.outputs;
        if(next.out_w == first.out_w && next.out_h == first.out_h){
            l->out_c += next.out_c;
        }else{
            printf("%d %d, %d %d\n", next.out_w, next.out_h, first.out_w, first.out_h);
            l->out_h = l->out_w = l->out_c = 0;
        }
    }
    l->inputs = l->outputs;
    l->delta =  realloc(l->delta, l->outputs*l->batch*sizeof(float));
    l->output = realloc(l->output, l->outputs*l->batch*sizeof(float));
#ifdef GPU
    cuda_free(l->output_gpu);
    cuda_free(l->delta_gpu);
    l->output_gpu  = cuda_make_array(l->output, l->outputs*l->batch);
    l->delta_gpu   = cuda_make_array(l->delta,  l->outputs*l->batch);
#endif
}
void forward_route_layer(const route_layer l, network_state state)
{
    int i, j;
src/route_layer.h
@@ -8,6 +8,7 @@
route_layer make_route_layer(int batch, int n, int *input_layers, int *input_size);
void forward_route_layer(const route_layer l, network_state state);
void backward_route_layer(const route_layer l, network_state state);
void resize_route_layer(route_layer *l, network *net);
#ifdef GPU
void forward_route_layer_gpu(const route_layer l, network_state state);
src/tree.c
@@ -24,6 +24,16 @@
    fprintf(stderr, "Found %d leaves.\n", found);
}
float get_hierarchy_probability(float *x, tree *hier, int c)
{
    float p = 1;
    while(c >= 0){
        p = p * x[c];
        c = hier->parent[c];
    }
    return p;
}
void hierarchy_predictions(float *predictions, int n, tree *hier, int only_leaves)
{
    int j;
src/tree.h
@@ -16,5 +16,6 @@
tree *read_tree(char *filename);
void hierarchy_predictions(float *predictions, int n, tree *hier, int only_leaves);
void change_leaves(tree *t, char *leaf_list);
float get_hierarchy_probability(float *x, tree *hier, int c);
#endif
src/utils.c
@@ -9,6 +9,21 @@
#include "utils.h"
int *read_map(char *filename)
{
    int n = 0;
    int *map = 0;
    char *str;
    FILE *file = fopen(filename, "r");
    if(!file) file_error(filename);
    while((str=fgetl(file))){
        ++n;
        map = realloc(map, n*sizeof(int));
        map[n-1] = atoi(str);
    }
    return map;
}
void sorta_shuffle(void *arr, size_t n, size_t size, size_t sections)
{
    size_t i;
src/utils.h
@@ -7,6 +7,7 @@
#define SECRET_NUM -1234
#define TWO_PI 6.2831853071795864769252866
int *read_map(char *filename);
void shuffle(void *arr, size_t n, size_t size);
void sorta_shuffle(void *arr, size_t n, size_t size, size_t sections);
void free_ptrs(void **ptrs, int n);