Joseph Redmon
2015-07-13 8561e49b5a2876e9a522b2dedfa99f19d5738154
add avgpool layer
7 files modified
3 files added
204 ■■■■■ changed files
Makefile 4 ●●●● patch | view | raw | blame | history
src/avgpool_layer.c 66 ●●●●● patch | view | raw | blame | history
src/avgpool_layer.h 23 ●●●●● patch | view | raw | blame | history
src/avgpool_layer_kernels.cu 57 ●●●●● patch | view | raw | blame | history
src/imagenet.c 6 ●●●● patch | view | raw | blame | history
src/layer.h 3 ●●●● patch | view | raw | blame | history
src/network.c 10 ●●●●● patch | view | raw | blame | history
src/network_kernels.cu 5 ●●●●● patch | view | raw | blame | history
src/normalization_layer.c 8 ●●●● patch | view | raw | blame | history
src/parser.c 22 ●●●●● patch | view | raw | blame | history
Makefile
@@ -34,9 +34,9 @@
LDFLAGS+= -L/usr/local/cuda/lib64 -lcuda -lcudart -lcublas -lcurand
endif
OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o darknet.o detection_layer.o imagenet.o captcha.o detection.o route_layer.o writing.o box.o nightmare.o normalization_layer.o
OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o darknet.o detection_layer.o imagenet.o captcha.o detection.o route_layer.o writing.o box.o nightmare.o normalization_layer.o avgpool_layer.o
ifeq ($(GPU), 1) 
OBJ+=convolutional_kernels.o deconvolutional_kernels.o activation_kernels.o im2col_kernels.o col2im_kernels.o blas_kernels.o crop_layer_kernels.o dropout_layer_kernels.o maxpool_layer_kernels.o softmax_layer_kernels.o network_kernels.o
OBJ+=convolutional_kernels.o deconvolutional_kernels.o activation_kernels.o im2col_kernels.o col2im_kernels.o blas_kernels.o crop_layer_kernels.o dropout_layer_kernels.o maxpool_layer_kernels.o softmax_layer_kernels.o network_kernels.o avgpool_layer_kernels.o
endif
OBJS = $(addprefix $(OBJDIR), $(OBJ))
src/avgpool_layer.c
New file
@@ -0,0 +1,66 @@
#include "avgpool_layer.h"
#include "cuda.h"
#include <stdio.h>
avgpool_layer make_avgpool_layer(int batch, int w, int h, int c)
{
    fprintf(stderr, "Avgpool Layer: %d x %d x %d image\n", w,h,c);
    avgpool_layer l = {0};
    l.type = AVGPOOL;
    l.batch = batch;
    l.h = h;
    l.w = w;
    l.c = c;
    l.out_w = 1;
    l.out_h = 1;
    l.out_c = c;
    l.outputs = l.out_c;
    l.inputs = h*w*c;
    int output_size = l.outputs * batch;
    l.output =  calloc(output_size, sizeof(float));
    l.delta =   calloc(output_size, sizeof(float));
    #ifdef GPU
    l.output_gpu  = cuda_make_array(l.output, output_size);
    l.delta_gpu   = cuda_make_array(l.delta, output_size);
    #endif
    return l;
}
void resize_avgpool_layer(avgpool_layer *l, int w, int h)
{
    l->h = h;
    l->w = w;
}
void forward_avgpool_layer(const avgpool_layer l, network_state state)
{
    int b,i,k;
    for(b = 0; b < l.batch; ++b){
        for(k = 0; k < l.c; ++k){
            int out_index = k + b*l.c;
            l.output[out_index] = 0;
            for(i = 0; i < l.h*l.w; ++i){
                int in_index = i + l.h*l.w*(k + b*l.c);
                l.output[out_index] += state.input[in_index];
            }
            l.output[out_index] /= l.h*l.w;
        }
    }
}
void backward_avgpool_layer(const avgpool_layer l, network_state state)
{
    int b,i,k;
    for(b = 0; b < l.batch; ++b){
        for(k = 0; k < l.c; ++k){
            int out_index = k + b*l.c;
            for(i = 0; i < l.h*l.w; ++i){
                int in_index = i + l.h*l.w*(k + b*l.c);
                state.delta[in_index] = l.delta[out_index] / (l.h*l.w);
            }
        }
    }
}
src/avgpool_layer.h
New file
@@ -0,0 +1,23 @@
#ifndef AVGPOOL_LAYER_H
#define AVGPOOL_LAYER_H
#include "image.h"
#include "params.h"
#include "cuda.h"
#include "layer.h"
typedef layer avgpool_layer;
image get_avgpool_image(avgpool_layer l);
avgpool_layer make_avgpool_layer(int batch, int w, int h, int c);
void resize_avgpool_layer(avgpool_layer *l, int w, int h);
void forward_avgpool_layer(const avgpool_layer l, network_state state);
void backward_avgpool_layer(const avgpool_layer l, network_state state);
#ifdef GPU
void forward_avgpool_layer_gpu(avgpool_layer l, network_state state);
void backward_avgpool_layer_gpu(avgpool_layer l, network_state state);
#endif
#endif
src/avgpool_layer_kernels.cu
New file
@@ -0,0 +1,57 @@
extern "C" {
#include "avgpool_layer.h"
#include "cuda.h"
}
__global__ void forward_avgpool_layer_kernel(int n, int w, int h, int c, float *input, float *output)
{
    int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    if(id >= n) return;
    int k = id % c;
    id /= c;
    int b = id;
    int i;
    int out_index = (k + c*b);
    output[out_index] = 0;
    for(i = 0; i < w*h; ++i){
        int in_index = i + h*w*(k + b*c);
        output[out_index] += input[in_index];
    }
    output[out_index] /= w*h;
}
__global__ void backward_avgpool_layer_kernel(int n, int w, int h, int c, float *in_delta, float *out_delta)
{
    int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    if(id >= n) return;
    int k = id % c;
    id /= c;
    int b = id;
    int i;
    int out_index = (k + c*b);
    for(i = 0; i < w*h; ++i){
        int in_index = i + h*w*(k + b*c);
        in_delta[in_index] = out_delta[out_index] / (w*h);
    }
}
extern "C" void forward_avgpool_layer_gpu(avgpool_layer layer, network_state state)
{
    size_t n = layer.c*layer.batch;
    forward_avgpool_layer_kernel<<<cuda_gridsize(n), BLOCK>>>(n, layer.w, layer.h, layer.c, state.input, layer.output_gpu);
    check_error(cudaPeekAtLastError());
}
extern "C" void backward_avgpool_layer_gpu(avgpool_layer layer, network_state state)
{
    size_t n = layer.c*layer.batch;
    backward_avgpool_layer_kernel<<<cuda_gridsize(n), BLOCK>>>(n, layer.w, layer.h, layer.c, state.delta, layer.delta_gpu);
    check_error(cudaPeekAtLastError());
}
src/imagenet.c
@@ -25,7 +25,7 @@
    pthread_t load_thread;
    data train;
    data buffer;
    load_thread = load_data_thread(paths, imgs, plist->size, labels, 1000, 256, 256, &buffer);
    load_thread = load_data_thread(paths, imgs, plist->size, labels, 1000, net.w, net.h, &buffer);
    while(1){
        ++i;
        time=clock();
@@ -38,7 +38,7 @@
        cvWaitKey(0);
        */
        load_thread = load_data_thread(paths, imgs, plist->size, labels, 1000, 256, 256, &buffer);
        load_thread = load_data_thread(paths, imgs, plist->size, labels, 1000, net.w, net.h, &buffer);
        printf("Loaded: %lf seconds\n", sec(clock()-time));
        time=clock();
        float loss = train_network(net, train);
@@ -47,7 +47,7 @@
        avg_loss = avg_loss*.9 + loss*.1;
        printf("%d: %f, %f avg, %lf seconds, %d images\n", i, loss, avg_loss, sec(clock()-time), net.seen);
        free_data(train);
        if((i % 30000) == 0) net.learning_rate *= .1;
        if((i % 20000) == 0) net.learning_rate *= .1;
        if(i%1000==0){
            char buff[256];
            sprintf(buff, "/home/pjreddie/imagenet_backup/%s_%d.weights",base, i);
src/layer.h
@@ -14,7 +14,8 @@
    CROP,
    ROUTE,
    COST,
    NORMALIZATION
    NORMALIZATION,
    AVGPOOL
} LAYER_TYPE;
typedef enum{
src/network.c
@@ -12,6 +12,7 @@
#include "detection_layer.h"
#include "normalization_layer.h"
#include "maxpool_layer.h"
#include "avgpool_layer.h"
#include "cost_layer.h"
#include "softmax_layer.h"
#include "dropout_layer.h"
@@ -28,6 +29,8 @@
            return "connected";
        case MAXPOOL:
            return "maxpool";
        case AVGPOOL:
            return "avgpool";
        case SOFTMAX:
            return "softmax";
        case DETECTION:
@@ -83,6 +86,8 @@
            forward_softmax_layer(l, state);
        } else if(l.type == MAXPOOL){
            forward_maxpool_layer(l, state);
        } else if(l.type == AVGPOOL){
            forward_avgpool_layer(l, state);
        } else if(l.type == DROPOUT){
            forward_dropout_layer(l, state);
        } else if(l.type == ROUTE){
@@ -156,6 +161,8 @@
            backward_normalization_layer(l, state);
        } else if(l.type == MAXPOOL){
            if(i != 0) backward_maxpool_layer(l, state);
        } else if(l.type == AVGPOOL){
            backward_avgpool_layer(l, state);
        } else if(l.type == DROPOUT){
            backward_dropout_layer(l, state);
        } else if(l.type == DETECTION){
@@ -273,6 +280,9 @@
            resize_convolutional_layer(&l, w, h);
        }else if(l.type == MAXPOOL){
            resize_maxpool_layer(&l, w, h);
        }else if(l.type == AVGPOOL){
            resize_avgpool_layer(&l, w, h);
            break;
        }else if(l.type == NORMALIZATION){
            resize_normalization_layer(&l, w, h);
        }else{
src/network_kernels.cu
@@ -15,6 +15,7 @@
#include "convolutional_layer.h"
#include "deconvolutional_layer.h"
#include "maxpool_layer.h"
#include "avgpool_layer.h"
#include "normalization_layer.h"
#include "cost_layer.h"
#include "softmax_layer.h"
@@ -49,6 +50,8 @@
            forward_normalization_layer_gpu(l, state);
        } else if(l.type == MAXPOOL){
            forward_maxpool_layer_gpu(l, state);
        } else if(l.type == AVGPOOL){
            forward_avgpool_layer_gpu(l, state);
        } else if(l.type == DROPOUT){
            forward_dropout_layer_gpu(l, state);
        } else if(l.type == ROUTE){
@@ -79,6 +82,8 @@
            backward_deconvolutional_layer_gpu(l, state);
        } else if(l.type == MAXPOOL){
            if(i != 0) backward_maxpool_layer_gpu(l, state);
        } else if(l.type == AVGPOOL){
            if(i != 0) backward_avgpool_layer_gpu(l, state);
        } else if(l.type == DROPOUT){
            backward_dropout_layer_gpu(l, state);
        } else if(l.type == DETECTION){
src/normalization_layer.c
@@ -40,10 +40,10 @@
    layer->out_w = w;
    layer->inputs = w*h*c;
    layer->outputs = layer->inputs;
    layer->output = realloc(layer->output, h * w * layer->c * layer->batch * sizeof(float));
    layer->delta = realloc(layer->delta, h * w * layer->c * layer->batch * sizeof(float));
    layer->squared = realloc(layer->squared, h * w * layer->c * layer->batch * sizeof(float));
    layer->norms = realloc(layer->norms, h * w * layer->c * layer->batch * sizeof(float));
    layer->output = realloc(layer->output, h * w * c * batch * sizeof(float));
    layer->delta = realloc(layer->delta, h * w * c * batch * sizeof(float));
    layer->squared = realloc(layer->squared, h * w * c * batch * sizeof(float));
    layer->norms = realloc(layer->norms, h * w * c * batch * sizeof(float));
#ifdef GPU
    cuda_free(layer->output_gpu);
    cuda_free(layer->delta_gpu); 
src/parser.c
@@ -14,6 +14,7 @@
#include "softmax_layer.h"
#include "dropout_layer.h"
#include "detection_layer.h"
#include "avgpool_layer.h"
#include "route_layer.h"
#include "list.h"
#include "option_list.h"
@@ -29,6 +30,7 @@
int is_deconvolutional(section *s);
int is_connected(section *s);
int is_maxpool(section *s);
int is_avgpool(section *s);
int is_dropout(section *s);
int is_softmax(section *s);
int is_normalization(section *s);
@@ -214,6 +216,19 @@
    return layer;
}
avgpool_layer parse_avgpool(list *options, size_params params)
{
    int batch,w,h,c;
    w = params.w;
    h = params.h;
    c = params.c;
    batch=params.batch;
    if(!(h && w && c)) error("Layer before avgpool layer must output image.");
    avgpool_layer layer = make_avgpool_layer(batch,w,h,c);
    return layer;
}
dropout_layer parse_dropout(list *options, size_params params)
{
    float probability = option_find_float(options, "probability", .5);
@@ -333,6 +348,8 @@
            l = parse_normalization(options, params);
        }else if(is_maxpool(s)){
            l = parse_maxpool(options, params);
        }else if(is_avgpool(s)){
            l = parse_avgpool(options, params);
        }else if(is_route(s)){
            l = parse_route(options, params, net);
        }else if(is_dropout(s)){
@@ -402,6 +419,11 @@
    return (strcmp(s->type, "[max]")==0
            || strcmp(s->type, "[maxpool]")==0);
}
int is_avgpool(section *s)
{
    return (strcmp(s->type, "[avg]")==0
            || strcmp(s->type, "[avgpool]")==0);
}
int is_dropout(section *s)
{
    return (strcmp(s->type, "[dropout]")==0);