Joseph Redmon
2015-02-09 979d02126b1a597361934f86f50eeda31ff083fe
Generalizing conv layer so deconv is easier
4 files modified
120 ■■■■■ changed files
src/convolutional_kernels.cu 45 ●●●● patch | view | raw | blame | history
src/convolutional_layer.c 42 ●●●● patch | view | raw | blame | history
src/convolutional_layer.h 11 ●●●●● patch | view | raw | blame | history
src/darknet.c 22 ●●●●● patch | view | raw | blame | history
src/convolutional_kernels.cu
@@ -8,7 +8,7 @@
#include "cuda.h"
}
__global__ void bias(int n, int size, float *biases, float *output)
__global__ void bias_output_kernel(float *output, float *biases, int n, int size)
{
    int offset = blockIdx.x * blockDim.x + threadIdx.x;
    int filter = blockIdx.y;
@@ -17,18 +17,16 @@
    if(offset < size) output[(batch*n+filter)*size + offset] = biases[filter];
}
extern "C" void bias_output_gpu(const convolutional_layer layer)
extern "C" void bias_output_gpu(float *output, float *biases, int batch, int n, int size)
{
    int size = convolutional_out_height(layer)*convolutional_out_width(layer);
    dim3 dimBlock(BLOCK, 1, 1);
    dim3 dimGrid((size-1)/BLOCK + 1, layer.n, layer.batch);
    dim3 dimGrid((size-1)/BLOCK + 1, n, batch);
    bias<<<dimGrid, dimBlock>>>(layer.n, size, layer.biases_gpu, layer.output_gpu);
    bias_output_kernel<<<dimGrid, dimBlock>>>(output, biases, n, size);
    check_error(cudaPeekAtLastError());
}
__global__ void learn_bias(int batch, int n, int size, float *delta, float *bias_updates, float scale)
__global__ void backward_bias_kernel(float *bias_updates, float *delta, int batch, int n, int size, float scale)
{
    __shared__ float part[BLOCK];
    int i,b;
@@ -48,36 +46,14 @@
    }
}
extern "C" void learn_bias_convolutional_layer_ongpu(convolutional_layer layer)
extern "C" void backward_bias_gpu(float *bias_updates, float *delta, int batch, int n, int size)
{
    int size = convolutional_out_height(layer)*convolutional_out_width(layer);
    float alpha = 1./layer.batch;
    float alpha = 1./batch;
    learn_bias<<<layer.n, BLOCK>>>(layer.batch, layer.n, size, layer.delta_gpu, layer.bias_updates_gpu, alpha);
    backward_bias_kernel<<<n, BLOCK>>>(bias_updates, delta, batch, n, size, alpha);
    check_error(cudaPeekAtLastError());
}
extern "C" void test_learn_bias(convolutional_layer l)
{
    int i;
    int size = convolutional_out_height(l) * convolutional_out_width(l);
    for(i = 0; i < size*l.batch*l.n; ++i){
        l.delta[i] = rand_uniform();
    }
    for(i = 0; i < l.n; ++i){
        l.bias_updates[i] = rand_uniform();
    }
    cuda_push_array(l.delta_gpu, l.delta, size*l.batch*l.n);
    cuda_push_array(l.bias_updates_gpu, l.bias_updates, l.n);
    float *gpu = (float *) calloc(l.n, sizeof(float));
    cuda_pull_array(l.bias_updates_gpu, gpu, l.n);
    for(i = 0; i < l.n; ++i) printf("%.9g %.9g\n", l.bias_updates[i], gpu[i]);
    learn_bias_convolutional_layer_ongpu(l);
    learn_bias_convolutional_layer(l);
    cuda_pull_array(l.bias_updates_gpu, gpu, l.n);
    for(i = 0; i < l.n; ++i) printf("%.9g %.9g\n", l.bias_updates[i], gpu[i]);
}
extern "C" void forward_convolutional_layer_gpu(convolutional_layer layer, float *in)
{
    int i;
@@ -86,7 +62,7 @@
    int n = convolutional_out_height(layer)*
        convolutional_out_width(layer);
    bias_output_gpu(layer);
    bias_output_gpu(layer.output_gpu, layer.biases_gpu, layer.batch, layer.n, n);
    for(i = 0; i < layer.batch; ++i){
        im2col_ongpu(in, i*layer.c*layer.h*layer.w, layer.c,  layer.h,  layer.w,  layer.size,  layer.stride, layer.pad, layer.col_image_gpu);
@@ -106,8 +82,9 @@
    int n = layer.size*layer.size*layer.c;
    int k = convolutional_out_height(layer)*
        convolutional_out_width(layer);
    gradient_array_ongpu(layer.output_gpu, m*k*layer.batch, layer.activation, layer.delta_gpu);
    learn_bias_convolutional_layer_ongpu(layer);
    backward_bias_gpu(layer.bias_updates_gpu, layer.delta_gpu, layer.batch, layer.n, k);
    if(delta_gpu) scal_ongpu(layer.batch*layer.h*layer.w*layer.c, 0, delta_gpu, 1);
src/convolutional_layer.c
@@ -111,27 +111,37 @@
                                layer->batch*out_h * out_w * layer->n*sizeof(float));
}
void bias_output(const convolutional_layer layer)
void bias_output(float *output, float *biases, int batch, int n, int size)
{
    int i,j,b;
    int out_h = convolutional_out_height(layer);
    int out_w = convolutional_out_width(layer);
    for(b = 0; b < layer.batch; ++b){
        for(i = 0; i < layer.n; ++i){
            for(j = 0; j < out_h*out_w; ++j){
                layer.output[(b*layer.n + i)*out_h*out_w + j] = layer.biases[i];
    for(b = 0; b < batch; ++b){
        for(i = 0; i < n; ++i){
            for(j = 0; j < size; ++j){
                output[(b*n + i)*size + j] = biases[i];
            }
        }
    }
}
void backward_bias(float *bias_updates, float *delta, int batch, int n, int size)
{
    float alpha = 1./batch;
    int i,b;
    for(b = 0; b < batch; ++b){
        for(i = 0; i < n; ++i){
            bias_updates[i] += alpha * sum_array(delta+size*(i+b*n), size);
        }
    }
}
void forward_convolutional_layer(const convolutional_layer layer, float *in)
{
    int out_h = convolutional_out_height(layer);
    int out_w = convolutional_out_width(layer);
    int i;
    bias_output(layer);
    bias_output(layer.output, layer.biases, layer.batch, layer.n, out_h*out_w);
    int m = layer.n;
    int k = layer.size*layer.size*layer.c;
@@ -151,19 +161,6 @@
    activate_array(layer.output, m*n*layer.batch, layer.activation);
}
void learn_bias_convolutional_layer(convolutional_layer layer)
{
    float alpha = 1./layer.batch;
    int i,b;
    int size = convolutional_out_height(layer)
        *convolutional_out_width(layer);
    for(b = 0; b < layer.batch; ++b){
        for(i = 0; i < layer.n; ++i){
            layer.bias_updates[i] += alpha * sum_array(layer.delta+size*(i+b*layer.n), size);
        }
    }
}
void backward_convolutional_layer(convolutional_layer layer, float *in, float *delta)
{
    float alpha = 1./layer.batch;
@@ -174,8 +171,7 @@
        convolutional_out_width(layer);
    gradient_array(layer.output, m*k*layer.batch, layer.activation, layer.delta);
    learn_bias_convolutional_layer(layer);
    backward_bias(layer.bias_updates, layer.delta, layer.batch, layer.n, k);
    if(delta) memset(delta, 0, layer.batch*layer.h*layer.w*layer.c*sizeof(float));
src/convolutional_layer.h
@@ -45,10 +45,12 @@
void forward_convolutional_layer_gpu(convolutional_layer layer, float * in);
void backward_convolutional_layer_gpu(convolutional_layer layer, float * in, float * delta_gpu);
void update_convolutional_layer_gpu(convolutional_layer layer);
void push_convolutional_layer(convolutional_layer layer);
void pull_convolutional_layer(convolutional_layer layer);
void learn_bias_convolutional_layer_ongpu(convolutional_layer layer);
void bias_output_gpu(const convolutional_layer layer);
void bias_output_gpu(float *output, float *biases, int batch, int n, int size);
void backward_bias_gpu(float *bias_updates, float *delta, int batch, int n, int size);
#endif
convolutional_layer *make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, int pad, ACTIVATION activation, float learning_rate, float momentum, float decay);
@@ -59,14 +61,15 @@
void backward_convolutional_layer(convolutional_layer layer, float *in, float *delta);
void bias_output(const convolutional_layer layer);
void bias_output(float *output, float *biases, int batch, int n, int size);
void backward_bias(float *bias_updates, float *delta, int batch, int n, int size);
image get_convolutional_image(convolutional_layer layer);
image get_convolutional_delta(convolutional_layer layer);
image get_convolutional_filter(convolutional_layer layer, int i);
int convolutional_out_height(convolutional_layer layer);
int convolutional_out_width(convolutional_layer layer);
void learn_bias_convolutional_layer(convolutional_layer layer);
#endif
src/darknet.c
@@ -225,8 +225,7 @@
void train_imagenet(char *cfgfile, char *weightfile)
{
    float avg_loss = -1;
    // TODO
    srand(0);
    srand(time(0));
    char *base = basename(cfgfile);
    printf("%s\n", base);
    network net = parse_network_cfg(cfgfile);
@@ -585,25 +584,6 @@
    cvWaitKey(0);
}
#ifdef GPU
void test_convolutional_layer()
{
    network net = parse_network_cfg("cfg/nist_conv.cfg");
    int size = get_network_input_size(net);
    float *in = calloc(size, sizeof(float));
    int i;
    for(i = 0; i < size; ++i) in[i] = rand_normal();
    convolutional_layer layer = *(convolutional_layer *)net.layers[0];
    int out_size = convolutional_out_height(layer)*convolutional_out_width(layer)*layer.batch;
    cuda_compare(layer.output_gpu, layer.output, out_size, "nothing");
    cuda_compare(layer.biases_gpu, layer.biases, layer.n, "biases");
    cuda_compare(layer.filters_gpu, layer.filters, layer.n*layer.size*layer.size*layer.c, "filters");
    bias_output(layer);
    bias_output_gpu(layer);
    cuda_compare(layer.output_gpu, layer.output, out_size, "biased output");
}
#endif
void test_correct_nist()
{
    network net = parse_network_cfg("cfg/nist_conv.cfg");