Joseph Redmon
2015-11-14 0cd2379e2ccbad07bad3f88f8dc564776605802d
some changes
10 files modified
4 files added
872 ■■■■■ changed files
Makefile 2 ●●● patch | view | raw | blame | history
src/coco.c 8 ●●●● patch | view | raw | blame | history
src/coco_kernels.cu 109 ●●●●● patch | view | raw | blame | history
src/data.c 4 ●●● patch | view | raw | blame | history
src/layer.h 3 ●●●● patch | view | raw | blame | history
src/local_kernels.cu 226 ●●●●● patch | view | raw | blame | history
src/local_layer.c 275 ●●●●● patch | view | raw | blame | history
src/local_layer.h 31 ●●●●● patch | view | raw | blame | history
src/network.c 9 ●●●●● patch | view | raw | blame | history
src/network_kernels.cu 7 ●●●●● patch | view | raw | blame | history
src/nightmare.c 7 ●●●●● patch | view | raw | blame | history
src/parser.c 50 ●●●●● patch | view | raw | blame | history
src/yolo.c 26 ●●●●● patch | view | raw | blame | history
src/yolo_kernels.cu 115 ●●●● patch | view | raw | blame | history
Makefile
@@ -34,7 +34,7 @@
LDFLAGS+= -L/usr/local/cuda/lib64 -lcuda -lcudart -lcublas -lcurand
endif
OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o darknet.o detection_layer.o imagenet.o captcha.o route_layer.o writing.o box.o nightmare.o normalization_layer.o avgpool_layer.o coco.o dice.o yolo.o layer.o compare.o classifier.o
OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o darknet.o detection_layer.o imagenet.o captcha.o route_layer.o writing.o box.o nightmare.o normalization_layer.o avgpool_layer.o coco.o dice.o yolo.o layer.o compare.o classifier.o local_layer.o
ifeq ($(GPU), 1) 
OBJ+=convolutional_kernels.o deconvolutional_kernels.o activation_kernels.o im2col_kernels.o col2im_kernels.o blas_kernels.o crop_layer_kernels.o dropout_layer_kernels.o maxpool_layer_kernels.o softmax_layer_kernels.o network_kernels.o avgpool_layer_kernels.o yolo_kernels.o
endif
src/coco.c
@@ -15,7 +15,7 @@
int coco_ids[] = {1,2,3,4,5,6,7,8,9,10,11,13,14,15,16,17,18,19,20,21,22,23,24,25,27,28,31,32,33,34,35,36,37,38,39,40,41,42,43,44,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,67,70,72,73,74,75,76,77,78,79,80,81,82,84,85,86,87,88,89,90};
void draw_coco(image im, int num, float thresh, box *boxes, float **probs, char *label)
void draw_coco(image im, int num, float thresh, box *boxes, float **probs)
{
    int classes = 80;
    int i;
@@ -38,7 +38,6 @@
            draw_box_width(im, left, top, right, bot, width, red, green, blue);
        }
    }
    show_image(im, label);
}
void train_coco(char *cfgfile, char *weightfile)
@@ -215,7 +214,7 @@
    int i=0;
    int t;
    float thresh = .001;
    float thresh = .01;
    int nms = 1;
    float iou_thresh = .5;
@@ -393,7 +392,8 @@
        float *predictions = network_predict(net, X);
        printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time));
        convert_coco_detections(predictions, l.classes, l.n, l.sqrt, l.side, 1, 1, thresh, probs, boxes, 0);
        draw_coco(im, l.side*l.side*l.n, thresh, boxes, probs, "predictions");
        draw_coco(im, l.side*l.side*l.n, thresh, boxes, probs);
        show_image(im, "predictions");
        show_image(sized, "resized");
        free_image(im);
src/coco_kernels.cu
New file
@@ -0,0 +1,109 @@
extern "C" {
#include "network.h"
#include "detection_layer.h"
#include "cost_layer.h"
#include "utils.h"
#include "parser.h"
#include "box.h"
#include "image.h"
}
#ifdef OPENCV
#include "opencv2/highgui/highgui.hpp"
#include "opencv2/imgproc/imgproc.hpp"
extern "C" image ipl_to_image(IplImage* src);
extern "C" void convert_coco_detections(float *predictions, int classes, int num, int square, int side, int w, int h, float thresh, float **probs, box *boxes, int only_objectness);
extern "C" void draw_coco(image im, int num, float thresh, box *boxes, float **probs);
static float **probs;
static box *boxes;
static network net;
static image in   ;
static image in_s ;
static image det  ;
static image det_s;
static image disp ;
static cv::VideoCapture cap(0);
void *fetch_in_thread(void *ptr)
{
    cv::Mat frame_m;
    cap >> frame_m;
    IplImage frame = frame_m;
    in = ipl_to_image(&frame);
    rgbgr_image(in);
    in_s = resize_image(in, net.w, net.h);
    return 0;
}
void *detect_in_thread(void *ptr)
{
    float nms = .4;
    float thresh = .2;
    detection_layer l = net.layers[net.n-1];
    float *X = det_s.data;
    float *predictions = network_predict(net, X);
    free_image(det_s);
    convert_coco_detections(predictions, l.classes, l.n, l.sqrt, l.side, 1, 1, thresh, probs, boxes, 0);
    if (nms > 0) do_nms(boxes, probs, l.side*l.side*l.n, l.classes, nms);
    printf("\033[2J");
    printf("\033[1;1H");
    printf("\nObjects:\n\n");
    draw_coco(det, l.side*l.side*l.n, thresh, boxes, probs);
    return 0;
}
extern "C" void demo_coco(char *cfgfile, char *weightfile, float thresh)
{
    printf("YOLO demo\n");
    net = parse_network_cfg(cfgfile);
    if(weightfile){
        load_weights(&net, weightfile);
    }
    set_batch_network(&net, 1);
    srand(2222222);
    if(!cap.isOpened()) error("Couldn't connect to webcam.\n");
    detection_layer l = net.layers[net.n-1];
    int j;
    boxes = (box *)calloc(l.side*l.side*l.n, sizeof(box));
    probs = (float **)calloc(l.side*l.side*l.n, sizeof(float *));
    for(j = 0; j < l.side*l.side*l.n; ++j) probs[j] = (float *)calloc(l.classes, sizeof(float *));
    pthread_t fetch_thread;
    pthread_t detect_thread;
    fetch_in_thread(0);
    det = in;
    det_s = in_s;
    fetch_in_thread(0);
    detect_in_thread(0);
    disp = det;
    det = in;
    det_s = in_s;
    while(1){
        if(pthread_create(&fetch_thread, 0, fetch_in_thread, 0)) error("Thread creation failed");
        if(pthread_create(&detect_thread, 0, detect_in_thread, 0)) error("Thread creation failed");
        show_image(disp, "YOLO");
        free_image(disp);
        cvWaitKey(1);
        pthread_join(fetch_thread, 0);
        pthread_join(detect_thread, 0);
        disp  = det;
        det   = in;
        det_s = in_s;
    }
}
#else
extern "C" void demo_coco(char *cfgfile, char *weightfile, float thresh){
    fprintf(stderr, "YOLO-COCO demo needs OpenCV for webcam images.\n");
}
#endif
src/data.c
@@ -574,9 +574,7 @@
    pthread_t thread;
    struct load_args *ptr = calloc(1, sizeof(struct load_args));
    *ptr = args;
    if(pthread_create(&thread, 0, load_thread, ptr)) {
        error("Thread creation failed");
    }
    if(pthread_create(&thread, 0, load_thread, ptr)) error("Thread creation failed");
    return thread;
}
src/layer.h
@@ -15,7 +15,8 @@
    ROUTE,
    COST,
    NORMALIZATION,
    AVGPOOL
    AVGPOOL,
    LOCAL
} LAYER_TYPE;
typedef enum{
src/local_kernels.cu
New file
@@ -0,0 +1,226 @@
extern "C" {
#include "local_layer.h"
#include "gemm.h"
#include "blas.h"
#include "im2col.h"
#include "col2im.h"
#include "utils.h"
#include "cuda.h"
}
__global__ void scale_bias_kernel(float *output, float *biases, int n, int size)
{
    int offset = blockIdx.x * blockDim.x + threadIdx.x;
    int filter = blockIdx.y;
    int batch = blockIdx.z;
    if(offset < size) output[(batch*n+filter)*size + offset] *= biases[filter];
}
void scale_bias_gpu(float *output, float *biases, int batch, int n, int size)
{
    dim3 dimGrid((size-1)/BLOCK + 1, n, batch);
    dim3 dimBlock(BLOCK, 1, 1);
    scale_bias_kernel<<<dimGrid, dimBlock>>>(output, biases, n, size);
    check_error(cudaPeekAtLastError());
}
__global__ void backward_scale_kernel(float *x_norm, float *delta, int batch, int n, int size, float *scale_updates)
{
    __shared__ float part[BLOCK];
    int i,b;
    int filter = blockIdx.x;
    int p = threadIdx.x;
    float sum = 0;
    for(b = 0; b < batch; ++b){
        for(i = 0; i < size; i += BLOCK){
            int index = p + i + size*(filter + n*b);
            sum += (p+i < size) ? delta[index]*x_norm[index] : 0;
        }
    }
    part[p] = sum;
    __syncthreads();
    if (p == 0) {
        for(i = 0; i < BLOCK; ++i) scale_updates[filter] += part[i];
    }
}
void backward_scale_gpu(float *x_norm, float *delta, int batch, int n, int size, float *scale_updates)
{
    backward_scale_kernel<<<n, BLOCK>>>(x_norm, delta, batch, n, size, scale_updates);
    check_error(cudaPeekAtLastError());
}
__global__ void add_bias_kernel(float *output, float *biases, int n, int size)
{
    int offset = blockIdx.x * blockDim.x + threadIdx.x;
    int filter = blockIdx.y;
    int batch = blockIdx.z;
    if(offset < size) output[(batch*n+filter)*size + offset] += biases[filter];
}
void add_bias_gpu(float *output, float *biases, int batch, int n, int size)
{
    dim3 dimGrid((size-1)/BLOCK + 1, n, batch);
    dim3 dimBlock(BLOCK, 1, 1);
    add_bias_kernel<<<dimGrid, dimBlock>>>(output, biases, n, size);
    check_error(cudaPeekAtLastError());
}
__global__ void backward_bias_kernel(float *bias_updates, float *delta, int batch, int n, int size)
{
    __shared__ float part[BLOCK];
    int i,b;
    int filter = blockIdx.x;
    int p = threadIdx.x;
    float sum = 0;
    for(b = 0; b < batch; ++b){
        for(i = 0; i < size; i += BLOCK){
            int index = p + i + size*(filter + n*b);
            sum += (p+i < size) ? delta[index] : 0;
        }
    }
    part[p] = sum;
    __syncthreads();
    if (p == 0) {
        for(i = 0; i < BLOCK; ++i) bias_updates[filter] += part[i];
    }
}
void backward_bias_gpu(float *bias_updates, float *delta, int batch, int n, int size)
{
    backward_bias_kernel<<<n, BLOCK>>>(bias_updates, delta, batch, n, size);
    check_error(cudaPeekAtLastError());
}
void forward_local_layer_gpu(local_layer l, network_state state)
{
    int i;
    int m = l.n;
    int k = l.size*l.size*l.c;
    int n = local_out_height(l)*
        local_out_width(l);
    fill_ongpu(l.outputs*l.batch, 0, l.output_gpu, 1);
    for(i = 0; i < l.batch; ++i){
        im2col_ongpu(state.input + i*l.c*l.h*l.w, l.c,  l.h,  l.w,  l.size,  l.stride, l.pad, l.col_image_gpu);
        float * a = l.filters_gpu;
        float * b = l.col_image_gpu;
        float * c = l.output_gpu;
        gemm_ongpu(0,0,m,n,k,1.,a,k,b,n,1.,c+i*m*n,n);
    }
    if(l.batch_normalize){
        if(state.train){
            fast_mean_gpu(l.output_gpu, l.batch, l.n, l.out_h*l.out_w, l.spatial_mean_gpu, l.mean_gpu);
            fast_variance_gpu(l.output_gpu, l.mean_gpu, l.batch, l.n, l.out_h*l.out_w, l.spatial_variance_gpu, l.variance_gpu);
            scal_ongpu(l.n, .95, l.rolling_mean_gpu, 1);
            axpy_ongpu(l.n, .05, l.mean_gpu, 1, l.rolling_mean_gpu, 1);
            scal_ongpu(l.n, .95, l.rolling_variance_gpu, 1);
            axpy_ongpu(l.n, .05, l.variance_gpu, 1, l.rolling_variance_gpu, 1);
            // cuda_pull_array(l.variance_gpu, l.mean, l.n);
            // printf("%f\n", l.mean[0]);
            copy_ongpu(l.outputs*l.batch, l.output_gpu, 1, l.x_gpu, 1);
            normalize_gpu(l.output_gpu, l.mean_gpu, l.variance_gpu, l.batch, l.n, l.out_h*l.out_w);
            copy_ongpu(l.outputs*l.batch, l.output_gpu, 1, l.x_norm_gpu, 1);
        } else {
            normalize_gpu(l.output_gpu, l.rolling_mean_gpu, l.rolling_variance_gpu, l.batch, l.n, l.out_h*l.out_w);
        }
        scale_bias_gpu(l.output_gpu, l.scales_gpu, l.batch, l.n, l.out_h*l.out_w);
    }
    add_bias_gpu(l.output_gpu, l.biases_gpu, l.batch, l.n, n);
    activate_array_ongpu(l.output_gpu, m*n*l.batch, l.activation);
}
void backward_local_layer_gpu(local_layer l, network_state state)
{
    int i;
    int m = l.n;
    int n = l.size*l.size*l.c;
    int k = local_out_height(l)*
        local_out_width(l);
    gradient_array_ongpu(l.output_gpu, m*k*l.batch, l.activation, l.delta_gpu);
    backward_bias_gpu(l.bias_updates_gpu, l.delta_gpu, l.batch, l.n, k);
    if(l.batch_normalize){
        backward_scale_gpu(l.x_norm_gpu, l.delta_gpu, l.batch, l.n, l.out_w*l.out_h, l.scale_updates_gpu);
        scale_bias_gpu(l.delta_gpu, l.scales_gpu, l.batch, l.n, l.out_h*l.out_w);
        fast_mean_delta_gpu(l.delta_gpu, l.variance_gpu, l.batch, l.n, l.out_w*l.out_h, l.spatial_mean_delta_gpu, l.mean_delta_gpu);
        fast_variance_delta_gpu(l.x_gpu, l.delta_gpu, l.mean_gpu, l.variance_gpu, l.batch, l.n, l.out_w*l.out_h, l.spatial_variance_delta_gpu, l.variance_delta_gpu);
        normalize_delta_gpu(l.x_gpu, l.mean_gpu, l.variance_gpu, l.mean_delta_gpu, l.variance_delta_gpu, l.batch, l.n, l.out_w*l.out_h, l.delta_gpu);
    }
    for(i = 0; i < l.batch; ++i){
        float * a = l.delta_gpu;
        float * b = l.col_image_gpu;
        float * c = l.filter_updates_gpu;
        im2col_ongpu(state.input + i*l.c*l.h*l.w, l.c,  l.h,  l.w,  l.size,  l.stride, l.pad, l.col_image_gpu);
        gemm_ongpu(0,1,m,n,k,1,a + i*m*k,k,b,k,1,c,n);
        if(state.delta){
            float * a = l.filters_gpu;
            float * b = l.delta_gpu;
            float * c = l.col_image_gpu;
            gemm_ongpu(1,0,n,k,m,1,a,n,b + i*k*m,k,0,c,k);
            col2im_ongpu(l.col_image_gpu, l.c,  l.h,  l.w,  l.size,  l.stride, l.pad, state.delta + i*l.c*l.h*l.w);
        }
    }
}
void pull_local_layer(local_layer layer)
{
    cuda_pull_array(layer.filters_gpu, layer.filters, layer.c*layer.n*layer.size*layer.size);
    cuda_pull_array(layer.biases_gpu, layer.biases, layer.n);
    cuda_pull_array(layer.filter_updates_gpu, layer.filter_updates, layer.c*layer.n*layer.size*layer.size);
    cuda_pull_array(layer.bias_updates_gpu, layer.bias_updates, layer.n);
    if (layer.batch_normalize){
        cuda_pull_array(layer.scales_gpu, layer.scales, layer.n);
        cuda_pull_array(layer.rolling_mean_gpu, layer.rolling_mean, layer.n);
        cuda_pull_array(layer.rolling_variance_gpu, layer.rolling_variance, layer.n);
    }
}
void push_local_layer(local_layer layer)
{
    cuda_push_array(layer.filters_gpu, layer.filters, layer.c*layer.n*layer.size*layer.size);
    cuda_push_array(layer.biases_gpu, layer.biases, layer.n);
    cuda_push_array(layer.filter_updates_gpu, layer.filter_updates, layer.c*layer.n*layer.size*layer.size);
    cuda_push_array(layer.bias_updates_gpu, layer.bias_updates, layer.n);
    if (layer.batch_normalize){
        cuda_push_array(layer.scales_gpu, layer.scales, layer.n);
        cuda_push_array(layer.rolling_mean_gpu, layer.rolling_mean, layer.n);
        cuda_push_array(layer.rolling_variance_gpu, layer.rolling_variance, layer.n);
    }
}
void update_local_layer_gpu(local_layer layer, int batch, float learning_rate, float momentum, float decay)
{
    int size = layer.size*layer.size*layer.c*layer.n;
    axpy_ongpu(layer.n, learning_rate/batch, layer.bias_updates_gpu, 1, layer.biases_gpu, 1);
    scal_ongpu(layer.n, momentum, layer.bias_updates_gpu, 1);
    axpy_ongpu(layer.n, learning_rate/batch, layer.scale_updates_gpu, 1, layer.scales_gpu, 1);
    scal_ongpu(layer.n, momentum, layer.scale_updates_gpu, 1);
    axpy_ongpu(size, -decay*batch, layer.filters_gpu, 1, layer.filter_updates_gpu, 1);
    axpy_ongpu(size, learning_rate/batch, layer.filter_updates_gpu, 1, layer.filters_gpu, 1);
    scal_ongpu(size, momentum, layer.filter_updates_gpu, 1);
}
src/local_layer.c
New file
@@ -0,0 +1,275 @@
#include "local_layer.h"
#include "utils.h"
#include "im2col.h"
#include "col2im.h"
#include "blas.h"
#include "gemm.h"
#include <stdio.h>
#include <time.h>
int local_out_height(local_layer l)
{
    int h = l.h;
    if (!l.pad) h -= l.size;
    else h -= 1;
    return h/l.stride + 1;
}
int local_out_width(local_layer l)
{
    int w = l.w;
    if (!l.pad) w -= l.size;
    else w -= 1;
    return w/l.stride + 1;
}
local_layer make_local_layer(int batch, int h, int w, int c, int n, int size, int stride, int pad, ACTIVATION activation)
{
    int i;
    local_layer l = {0};
    l.type = LOCAL;
    l.h = h;
    l.w = w;
    l.c = c;
    l.n = n;
    l.batch = batch;
    l.stride = stride;
    l.size = size;
    l.pad = pad;
    int out_h = local_out_height(l);
    int out_w = local_out_width(l);
    int locations = out_h*out_w;
    l.out_h = out_h;
    l.out_w = out_w;
    l.out_c = n;
    l.outputs = l.out_h * l.out_w * l.out_c;
    l.inputs = l.w * l.h * l.c;
    l.filters = calloc(c*n*size*size*locations, sizeof(float));
    l.filter_updates = calloc(c*n*size*size*locations, sizeof(float));
    l.biases = calloc(l.outputs, sizeof(float));
    l.bias_updates = calloc(l.outputs, sizeof(float));
    // float scale = 1./sqrt(size*size*c);
    float scale = sqrt(2./(size*size*c));
    for(i = 0; i < c*n*size*size; ++i) l.filters[i] = 2*scale*rand_uniform() - scale;
    l.col_image = calloc(out_h*out_w*size*size*c, sizeof(float));
    l.output = calloc(l.batch*out_h * out_w * n, sizeof(float));
    l.delta  = calloc(l.batch*out_h * out_w * n, sizeof(float));
#ifdef GPU
    l.filters_gpu = cuda_make_array(l.filters, c*n*size*size*locations);
    l.filter_updates_gpu = cuda_make_array(l.filter_updates, c*n*size*size*locations);
    l.biases_gpu = cuda_make_array(l.biases, l.outputs);
    l.bias_updates_gpu = cuda_make_array(l.bias_updates, l.outputs);
    l.col_image_gpu = cuda_make_array(l.col_image, out_h*out_w*size*size*c);
    l.delta_gpu = cuda_make_array(l.delta, l.batch*out_h*out_w*n);
    l.output_gpu = cuda_make_array(l.output, l.batch*out_h*out_w*n);
#endif
    l.activation = activation;
    fprintf(stderr, "Local Layer: %d x %d x %d image, %d filters -> %d x %d x %d image\n", h,w,c,n, out_h, out_w, n);
    return l;
}
void forward_local_layer(const local_layer l, network_state state)
{
    int out_h = local_out_height(l);
    int out_w = local_out_width(l);
    int i, j;
    int locations = out_h * out_w;
    for(i = 0; i < l.batch; ++i){
        copy_cpu(l.outputs, l.biases, 1, l.output + i*l.outputs, 1);
    }
    for(i = 0; i < l.batch; ++i){
        float *input = state.input + i*l.w*l.h*l.c;
        im2col_cpu(input, l.c, l.h, l.w,
                l.size, l.stride, l.pad, l.col_image);
        float *output = l.output + i*l.outputs;
        for(j = 0; j < locations; ++j){
            float *a = l.filters + j*l.size*l.size*l.c*l.n;
            float *b = l.col_image + j;
            float *c = output + j;
            int m = l.n;
            int n = 1;
            int k = l.size*l.size*l.c;
            gemm(0,0,m,n,k,1,a,k,b,locations,1,c,locations);
        }
    }
    activate_array(l.output, l.outputs*l.batch, l.activation);
}
void backward_local_layer(local_layer l, network_state state)
{
    int i, j;
    int locations = l.out_w*l.out_h;
    gradient_array(l.output, l.outputs*l.batch, l.activation, l.delta);
    for(i = 0; i < l.batch; ++i){
        axpy_cpu(l.outputs, 1, l.delta + i*l.outputs, 1, l.bias_updates, 1);
    }
    for(i = 0; i < l.batch; ++i){
        float *input = state.input + i*l.w*l.h*l.c;
        im2col_cpu(input, l.c, l.h, l.w,
                l.size, l.stride, l.pad, l.col_image);
        for(j = 0; j < locations; ++j){
            float *a = l.delta + i*l.outputs + j;
            float *b = l.col_image + j;
            float *c = l.filter_updates + j*l.size*l.size*l.c*l.n;
            int m = l.n;
            int n = l.size*l.size*l.c;
            int k = 1;
            gemm(0,1,m,n,k,1,a,locations,b,locations,1,c,n);
        }
        if(state.delta){
            for(j = 0; j < locations; ++j){
                float *a = l.filters + j*l.size*l.size*l.c*l.n;
                float *b = l.delta + i*l.outputs + j;
                float *c = l.col_image + j;
                int m = l.size*l.size*l.c;
                int n = 1;
                int k = l.n;
                gemm(1,0,m,n,k,1,a,m,b,locations,0,c,locations);
            }
            col2im_cpu(l.col_image, l.c,  l.h,  l.w,  l.size,  l.stride, l.pad, state.delta+i*l.c*l.h*l.w);
        }
    }
}
void update_local_layer(local_layer l, int batch, float learning_rate, float momentum, float decay)
{
    int locations = l.out_w*l.out_h;
    int size = l.size*l.size*l.c*l.n*locations;
    axpy_cpu(l.outputs, learning_rate/batch, l.bias_updates, 1, l.biases, 1);
    scal_cpu(l.outputs, momentum, l.bias_updates, 1);
    axpy_cpu(size, -decay*batch, l.filters, 1, l.filter_updates, 1);
    axpy_cpu(size, learning_rate/batch, l.filter_updates, 1, l.filters, 1);
    scal_cpu(size, momentum, l.filter_updates, 1);
}
#ifdef GPU
void forward_local_layer_gpu(const local_layer l, network_state state)
{
    int out_h = local_out_height(l);
    int out_w = local_out_width(l);
    int i, j;
    int locations = out_h * out_w;
    for(i = 0; i < l.batch; ++i){
        copy_ongpu(l.outputs, l.biases_gpu, 1, l.output_gpu + i*l.outputs, 1);
    }
    for(i = 0; i < l.batch; ++i){
        float *input = state.input + i*l.w*l.h*l.c;
        im2col_ongpu(input, l.c, l.h, l.w,
                l.size, l.stride, l.pad, l.col_image_gpu);
        float *output = l.output_gpu + i*l.outputs;
        for(j = 0; j < locations; ++j){
            float *a = l.filters_gpu + j*l.size*l.size*l.c*l.n;
            float *b = l.col_image_gpu + j;
            float *c = output + j;
            int m = l.n;
            int n = 1;
            int k = l.size*l.size*l.c;
            gemm_ongpu(0,0,m,n,k,1,a,k,b,locations,1,c,locations);
        }
    }
    activate_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation);
}
void backward_local_layer_gpu(local_layer l, network_state state)
{
    int i, j;
    int locations = l.out_w*l.out_h;
    gradient_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation, l.delta_gpu);
    for(i = 0; i < l.batch; ++i){
        axpy_ongpu(l.outputs, 1, l.delta_gpu + i*l.outputs, 1, l.bias_updates_gpu, 1);
    }
    for(i = 0; i < l.batch; ++i){
        float *input = state.input + i*l.w*l.h*l.c;
        im2col_ongpu(input, l.c, l.h, l.w,
                l.size, l.stride, l.pad, l.col_image_gpu);
        for(j = 0; j < locations; ++j){
            float *a = l.delta_gpu + i*l.outputs + j;
            float *b = l.col_image_gpu + j;
            float *c = l.filter_updates_gpu + j*l.size*l.size*l.c*l.n;
            int m = l.n;
            int n = l.size*l.size*l.c;
            int k = 1;
            gemm_ongpu(0,1,m,n,k,1,a,locations,b,locations,1,c,n);
        }
        if(state.delta){
            for(j = 0; j < locations; ++j){
                float *a = l.filters_gpu + j*l.size*l.size*l.c*l.n;
                float *b = l.delta_gpu + i*l.outputs + j;
                float *c = l.col_image_gpu + j;
                int m = l.size*l.size*l.c;
                int n = 1;
                int k = l.n;
                gemm_ongpu(1,0,m,n,k,1,a,m,b,locations,0,c,locations);
            }
            col2im_ongpu(l.col_image_gpu, l.c,  l.h,  l.w,  l.size,  l.stride, l.pad, state.delta+i*l.c*l.h*l.w);
        }
    }
}
void update_local_layer_gpu(local_layer l, int batch, float learning_rate, float momentum, float decay)
{
    int locations = l.out_w*l.out_h;
    int size = l.size*l.size*l.c*l.n*locations;
    axpy_ongpu(l.outputs, learning_rate/batch, l.bias_updates_gpu, 1, l.biases_gpu, 1);
    scal_ongpu(l.outputs, momentum, l.bias_updates_gpu, 1);
    axpy_ongpu(size, -decay*batch, l.filters_gpu, 1, l.filter_updates_gpu, 1);
    axpy_ongpu(size, learning_rate/batch, l.filter_updates_gpu, 1, l.filters_gpu, 1);
    scal_ongpu(size, momentum, l.filter_updates_gpu, 1);
}
void pull_local_layer(local_layer l)
{
    int locations = l.out_w*l.out_h;
    int size = l.size*l.size*l.c*l.n*locations;
    cuda_pull_array(l.filters_gpu, l.filters, size);
    cuda_pull_array(l.biases_gpu, l.biases, l.outputs);
}
void push_local_layer(local_layer l)
{
    int locations = l.out_w*l.out_h;
    int size = l.size*l.size*l.c*l.n*locations;
    cuda_push_array(l.filters_gpu, l.filters, size);
    cuda_push_array(l.biases_gpu, l.biases, l.outputs);
}
#endif
src/local_layer.h
New file
@@ -0,0 +1,31 @@
#ifndef LOCAL_LAYER_H
#define LOCAL_LAYER_H
#include "cuda.h"
#include "params.h"
#include "image.h"
#include "activations.h"
#include "layer.h"
typedef layer local_layer;
#ifdef GPU
void forward_local_layer_gpu(local_layer layer, network_state state);
void backward_local_layer_gpu(local_layer layer, network_state state);
void update_local_layer_gpu(local_layer layer, int batch, float learning_rate, float momentum, float decay);
void push_local_layer(local_layer layer);
void pull_local_layer(local_layer layer);
#endif
local_layer make_local_layer(int batch, int h, int w, int c, int n, int size, int stride, int pad, ACTIVATION activation);
void forward_local_layer(const local_layer layer, network_state state);
void backward_local_layer(local_layer layer, network_state state);
void update_local_layer(local_layer layer, int batch, float learning_rate, float momentum, float decay);
void bias_output(float *output, float *biases, int batch, int n, int size);
void backward_bias(float *bias_updates, float *delta, int batch, int n, int size);
#endif
src/network.c
@@ -8,6 +8,7 @@
#include "crop_layer.h"
#include "connected_layer.h"
#include "local_layer.h"
#include "convolutional_layer.h"
#include "deconvolutional_layer.h"
#include "detection_layer.h"
@@ -59,6 +60,8 @@
    switch(a){
        case CONVOLUTIONAL:
            return "convolutional";
        case LOCAL:
            return "local";
        case DECONVOLUTIONAL:
            return "deconvolutional";
        case CONNECTED:
@@ -112,6 +115,8 @@
            forward_convolutional_layer(l, state);
        } else if(l.type == DECONVOLUTIONAL){
            forward_deconvolutional_layer(l, state);
        } else if(l.type == LOCAL){
            forward_local_layer(l, state);
        } else if(l.type == NORMALIZATION){
            forward_normalization_layer(l, state);
        } else if(l.type == DETECTION){
@@ -150,6 +155,8 @@
            update_deconvolutional_layer(l, rate, net.momentum, net.decay);
        } else if(l.type == CONNECTED){
            update_connected_layer(l, update_batch, rate, net.momentum, net.decay);
        } else if(l.type == LOCAL){
            update_local_layer(l, update_batch, rate, net.momentum, net.decay);
        }
    }
}
@@ -219,6 +226,8 @@
            if(i != 0) backward_softmax_layer(l, state);
        } else if(l.type == CONNECTED){
            backward_connected_layer(l, state);
        } else if(l.type == LOCAL){
            backward_local_layer(l, state);
        } else if(l.type == COST){
            backward_cost_layer(l, state);
        } else if(l.type == ROUTE){
src/network_kernels.cu
@@ -19,6 +19,7 @@
#include "avgpool_layer.h"
#include "normalization_layer.h"
#include "cost_layer.h"
#include "local_layer.h"
#include "softmax_layer.h"
#include "dropout_layer.h"
#include "route_layer.h"
@@ -41,6 +42,8 @@
            forward_convolutional_layer_gpu(l, state);
        } else if(l.type == DECONVOLUTIONAL){
            forward_deconvolutional_layer_gpu(l, state);
        } else if(l.type == LOCAL){
            forward_local_layer_gpu(l, state);
        } else if(l.type == DETECTION){
            forward_detection_layer_gpu(l, state);
        } else if(l.type == CONNECTED){
@@ -85,6 +88,8 @@
            backward_convolutional_layer_gpu(l, state);
        } else if(l.type == DECONVOLUTIONAL){
            backward_deconvolutional_layer_gpu(l, state);
        } else if(l.type == LOCAL){
            backward_local_layer_gpu(l, state);
        } else if(l.type == MAXPOOL){
            if(i != 0) backward_maxpool_layer_gpu(l, state);
        } else if(l.type == AVGPOOL){
@@ -120,6 +125,8 @@
            update_deconvolutional_layer_gpu(l, rate, net.momentum, net.decay);
        } else if(l.type == CONNECTED){
            update_connected_layer_gpu(l, update_batch, rate, net.momentum, net.decay);
        } else if(l.type == LOCAL){
            update_local_layer_gpu(l, update_batch, rate, net.momentum, net.decay);
        }
    }
}
src/nightmare.c
@@ -25,7 +25,7 @@
    }
}
void optimize_picture(network *net, image orig, int max_layer, float scale, float rate, float thresh)
void optimize_picture(network *net, image orig, int max_layer, float scale, float rate, float thresh, int norm)
{
    scale_image(orig, 2);
    translate_image(orig, -1);
@@ -85,7 +85,7 @@
    //rate = rate / abs_mean(out.data, out.w*out.h*out.c);
    normalize_array(out.data, out.w*out.h*out.c);
    if(norm) normalize_array(out.data, out.w*out.h*out.c);
    axpy_cpu(orig.w*orig.h*orig.c, rate, out.data, 1, orig.data, 1);
    /*
@@ -123,6 +123,7 @@
    int max_layer = atoi(argv[5]);
    int range = find_int_arg(argc, argv, "-range", 1);
    int norm = find_int_arg(argc, argv, "-norm", 1);
    int rounds = find_int_arg(argc, argv, "-rounds", 1);
    int iters = find_int_arg(argc, argv, "-iters", 10);
    int octaves = find_int_arg(argc, argv, "-octaves", 4);
@@ -160,7 +161,7 @@
            fflush(stderr);
            int layer = max_layer + rand()%range - range/2;
            int octave = rand()%octaves;
            optimize_picture(&net, im, layer, 1/pow(1.33333333, octave), rate, thresh);
            optimize_picture(&net, im, layer, 1/pow(1.33333333, octave), rate, thresh, norm);
        }
        fprintf(stderr, "done\n");
        if(0){
src/parser.c
@@ -15,6 +15,7 @@
#include "dropout_layer.h"
#include "detection_layer.h"
#include "avgpool_layer.h"
#include "local_layer.h"
#include "route_layer.h"
#include "list.h"
#include "option_list.h"
@@ -27,6 +28,7 @@
int is_network(section *s);
int is_convolutional(section *s);
int is_local(section *s);
int is_deconvolutional(section *s);
int is_connected(section *s);
int is_maxpool(section *s);
@@ -107,6 +109,27 @@
    return layer;
}
local_layer parse_local(list *options, size_params params)
{
    int n = option_find_int(options, "filters",1);
    int size = option_find_int(options, "size",1);
    int stride = option_find_int(options, "stride",1);
    int pad = option_find_int(options, "pad",0);
    char *activation_s = option_find_str(options, "activation", "logistic");
    ACTIVATION activation = get_activation(activation_s);
    int batch,h,w,c;
    h = params.h;
    w = params.w;
    c = params.c;
    batch=params.batch;
    if(!(h && w && c)) error("Layer before local layer must output image.");
    local_layer layer = make_local_layer(batch,h,w,c,n,size,stride,pad,activation);
    return layer;
}
convolutional_layer parse_convolutional(list *options, size_params params)
{
    int n = option_find_int(options, "filters",1);
@@ -402,6 +425,8 @@
        layer l = {0};
        if(is_convolutional(s)){
            l = parse_convolutional(options, params);
        }else if(is_local(s)){
            l = parse_local(options, params);
        }else if(is_deconvolutional(s)){
            l = parse_deconvolutional(options, params);
        }else if(is_connected(s)){
@@ -465,6 +490,10 @@
{
    return (strcmp(s->type, "[detection]")==0);
}
int is_local(section *s)
{
    return (strcmp(s->type, "[local]")==0);
}
int is_deconvolutional(section *s)
{
    return (strcmp(s->type, "[deconv]")==0
@@ -626,6 +655,16 @@
#endif
            fwrite(l.biases, sizeof(float), l.outputs, fp);
            fwrite(l.weights, sizeof(float), l.outputs*l.inputs, fp);
        } if(l.type == LOCAL){
#ifdef GPU
            if(gpu_index >= 0){
                pull_local_layer(l);
            }
#endif
            int locations = l.out_w*l.out_h;
            int size = l.size*l.size*l.c*l.n*locations;
            fwrite(l.biases, sizeof(float), l.outputs, fp);
            fwrite(l.filters, sizeof(float), size, fp);
        }
    }
    fclose(fp);
@@ -686,6 +725,17 @@
            }
#endif
        }
        if(l.type == LOCAL){
            int locations = l.out_w*l.out_h;
            int size = l.size*l.size*l.c*l.n*locations;
            fread(l.biases, sizeof(float), l.outputs, fp);
            fread(l.filters, sizeof(float), size, fp);
#ifdef GPU
            if(gpu_index >= 0){
                push_local_layer(l);
            }
#endif
        }
    }
    fprintf(stderr, "Done!\n");
    fclose(fp);
src/yolo.c
@@ -11,7 +11,7 @@
char *voc_names[] = {"aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"};
void draw_yolo(image im, int num, float thresh, box *boxes, float **probs, char *label)
void draw_yolo(image im, int num, float thresh, box *boxes, float **probs)
{
    int classes = 20;
    int i;
@@ -20,8 +20,10 @@
        int class = max_index(probs[i], classes);
        float prob = probs[i][class];
        if(prob > thresh){
            int width = pow(prob, 1./2.)*10;
            printf("%f %s\n", prob, voc_names[class]);
            int width = pow(prob, 1./2.)*10+1;
            //width = 8;
            printf("%s: %.2f\n", voc_names[class], prob);
            class = class * 7 % 20;
            float red = get_color(0,class,classes);
            float green = get_color(1,class,classes);
            float blue = get_color(2,class,classes);
@@ -41,7 +43,6 @@
            draw_box_width(im, left, top, right, bot, width, red, green, blue);
        }
    }
    show_image(im, label);
}
void train_yolo(char *cfgfile, char *weightfile)
@@ -97,21 +98,13 @@
        printf("Loaded: %lf seconds\n", sec(clock()-time));
        /*
           image im = float_to_image(net.w, net.h, 3, train.X.vals[113]);
           image copy = copy_image(im);
           draw_yolo(copy, train.y.vals[113], 7, "truth");
           cvWaitKey(0);
           free_image(copy);
         */
        time=clock();
        float loss = train_network(net, train);
        if (avg_loss < 0) avg_loss = loss;
        avg_loss = avg_loss*.9 + loss*.1;
        printf("%d: %f, %f avg, %f rate, %lf seconds, %d images\n", i, loss, avg_loss, get_current_rate(net), sec(clock()-time), i*imgs);
        if(i%1000==0){
        if(i%1000==0 || i == 600){
            char buff[256];
            sprintf(buff, "%s/%s_%d.weights", backup_directory, base, i);
            save_weights(net, buff);
@@ -183,8 +176,8 @@
    srand(time(0));
    char *base = "results/comp4_det_test_";
    //list *plist = get_paths("data/voc.2007.test");
    list *plist = get_paths("data/voc.2012.test");
    list *plist = get_paths("data/voc.2007.test");
    //list *plist = get_paths("data/voc.2012.test");
    char **paths = (char **)list_to_array(plist);
    layer l = net.layers[net.n-1];
@@ -384,7 +377,8 @@
        printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time));
        convert_yolo_detections(predictions, l.classes, l.n, l.sqrt, l.side, 1, 1, thresh, probs, boxes, 0);
        if (nms) do_nms_sort(boxes, probs, l.side*l.side*l.n, l.classes, nms);
        draw_yolo(im, l.side*l.side*l.n, thresh, boxes, probs, "predictions");
        draw_yolo(im, l.side*l.side*l.n, thresh, boxes, probs);
        show_image(im, "predictions");
        show_image(sized, "resized");
        free_image(im);
src/yolo_kernels.cu
@@ -6,6 +6,7 @@
#include "parser.h"
#include "box.h"
#include "image.h"
#include <sys/time.h>
}
#ifdef OPENCV
@@ -13,48 +14,108 @@
#include "opencv2/imgproc/imgproc.hpp"
extern "C" image ipl_to_image(IplImage* src);
extern "C" void convert_yolo_detections(float *predictions, int classes, int num, int square, int side, int w, int h, float thresh, float **probs, box *boxes, int only_objectness);
extern "C" void draw_yolo(image im, int num, float thresh, box *boxes, float **probs, char *label);
extern "C" void draw_yolo(image im, int num, float thresh, box *boxes, float **probs);
extern "C" void demo_yolo(char *cfgfile, char *weightfile, float thresh)
static float **probs;
static box *boxes;
static network net;
static image in   ;
static image in_s ;
static image det  ;
static image det_s;
static image disp ;
static cv::VideoCapture cap;
static float fps = 0;
void *fetch_in_thread(void *ptr)
{
    network net = parse_network_cfg(cfgfile);
    if(weightfile){
        load_weights(&net, weightfile);
    }
    detection_layer l = net.layers[net.n-1];
    cv::VideoCapture cap(0);
    set_batch_network(&net, 1);
    srand(2222222);
    float nms = .4;
    int j;
    box *boxes = (box *)calloc(l.side*l.side*l.n, sizeof(box));
    float **probs = (float **)calloc(l.side*l.side*l.n, sizeof(float *));
    for(j = 0; j < l.side*l.side*l.n; ++j) probs[j] = (float *)calloc(l.classes, sizeof(float *));
    while(1){
        cv::Mat frame_m;
        cap >> frame_m;
        IplImage frame = frame_m;
        image im = ipl_to_image(&frame);
        rgbgr_image(im);
    in = ipl_to_image(&frame);
    rgbgr_image(in);
    in_s = resize_image(in, net.w, net.h);
    return 0;
}
        image sized = resize_image(im, net.w, net.h);
        float *X = sized.data;
void *detect_in_thread(void *ptr)
{
    float nms = .4;
    float thresh = .2;
    detection_layer l = net.layers[net.n-1];
    float *X = det_s.data;
        float *predictions = network_predict(net, X);
    free_image(det_s);
        convert_yolo_detections(predictions, l.classes, l.n, l.sqrt, l.side, 1, 1, thresh, probs, boxes, 0);
        if (nms > 0) do_nms(boxes, probs, l.side*l.side*l.n, l.classes, nms);
        printf("\033[2J");
        printf("\033[1;1H");
        printf("\nObjects:\n\n");
        draw_yolo(im, l.side*l.side*l.n, thresh, boxes, probs, "predictions");
    printf("\nFPS:%.0f\n",fps);
    printf("Objects:\n\n");
    draw_yolo(det, l.side*l.side*l.n, thresh, boxes, probs);
    return 0;
}
        free_image(im);
        free_image(sized);
extern "C" void demo_yolo(char *cfgfile, char *weightfile, float thresh)
{
    printf("YOLO demo\n");
    net = parse_network_cfg(cfgfile);
    if(weightfile){
        load_weights(&net, weightfile);
    }
    set_batch_network(&net, 1);
    srand(2222222);
    cv::VideoCapture cam(0);
    cap = cam;
    if(!cap.isOpened()) error("Couldn't connect to webcam.\n");
    detection_layer l = net.layers[net.n-1];
    int j;
    boxes = (box *)calloc(l.side*l.side*l.n, sizeof(box));
    probs = (float **)calloc(l.side*l.side*l.n, sizeof(float *));
    for(j = 0; j < l.side*l.side*l.n; ++j) probs[j] = (float *)calloc(l.classes, sizeof(float *));
    pthread_t fetch_thread;
    pthread_t detect_thread;
    fetch_in_thread(0);
    det = in;
    det_s = in_s;
    fetch_in_thread(0);
    detect_in_thread(0);
    disp = det;
    det = in;
    det_s = in_s;
    while(1){
        struct timeval tval_before, tval_after, tval_result;
        gettimeofday(&tval_before, NULL);
        if(pthread_create(&fetch_thread, 0, fetch_in_thread, 0)) error("Thread creation failed");
        if(pthread_create(&detect_thread, 0, detect_in_thread, 0)) error("Thread creation failed");
        show_image(disp, "YOLO");
        free_image(disp);
        cvWaitKey(1);
        pthread_join(fetch_thread, 0);
        pthread_join(detect_thread, 0);
        disp  = det;
        det   = in;
        det_s = in_s;
        gettimeofday(&tval_after, NULL);
        timersub(&tval_after, &tval_before, &tval_result);
        float curr = 1000000.f/((long int)tval_result.tv_usec);
        fps = .9*fps + .1*curr;
    }
}
#else
extern "C" void demo_yolo(char *cfgfile, char *weightfile, float thresh){}
extern "C" void demo_yolo(char *cfgfile, char *weightfile, float thresh){
    fprintf(stderr, "YOLO demo needs OpenCV for webcam images.\n");
}
#endif