Joseph Redmon
2015-03-21 9d418102f4a44b2e2c437dec945189a646cbf3a4
using caffe's im2col, it's so much better\!
5 files modified
365 ■■■■■ changed files
Makefile 2 ●●● patch | view | raw | blame | history
src/col2im_kernels.cu 144 ●●●●● patch | view | raw | blame | history
src/convolutional_kernels.cu 34 ●●●● patch | view | raw | blame | history
src/im2col_kernels.cu 183 ●●●●● patch | view | raw | blame | history
src/imagenet.c 2 ●●● patch | view | raw | blame | history
Makefile
@@ -8,7 +8,7 @@
CC=gcc
NVCC=nvcc
OPTS=-O0
OPTS=-O3
LDFLAGS=`pkg-config --libs opencv` -lm -pthread -lstdc++
COMMON=`pkg-config --cflags opencv` -I/usr/local/cuda/include/
CFLAGS=-Wall -Wfatal-errors
src/col2im_kernels.cu
@@ -3,60 +3,112 @@
#include "cuda.h"
}
__global__ void col2im_kernel(float *data_col,
        int channels, int height, int width,
        int ksize, int stride, int pad, float *data_im)
{
// src: https://github.com/BVLC/caffe/blob/master/src/caffe/util/im2col.cu
// You may also want to read: https://github.com/BVLC/caffe/blob/master/LICENSE
    int height_col = (height - ksize) / stride + 1;
    int width_col = (width - ksize) / stride + 1;
    if (pad){
        height_col = 1 + (height-1) / stride;
        width_col = 1 + (width-1) / stride;
        pad = ksize/2;
    }
    int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    if(id >= channels*height*width) return;
    int index = id;
    int w = id%width + pad;
    id /= width;
    int h = id%height + pad;
    id /= height;
    int c = id%channels;
    int w_start = (w-ksize+stride)/stride;
    int w_end = w/stride + 1;
    int h_start = (h-ksize+stride)/stride;
    int h_end = h/stride + 1;
    // int rows = channels * ksize * ksize;
    // int cols = height_col*width_col;
    int col_offset = (c*ksize*ksize + h * ksize + w)*height_col*width_col;
    int h_coeff = (1-stride*ksize*height_col)*width_col;
    int w_coeff = 1-stride*height_col*width_col;
    float val = 0;
    int h_col, w_col;
    for(h_col = h_start; h_col < h_end; ++h_col){
        for(w_col = w_start; w_col < w_end; ++w_col){
            int col_index = col_offset +h_col*h_coeff + w_col*w_coeff;
            float part = (w_col < 0 || h_col < 0 || h_col >= height_col || w_col >= width_col) ? 0 : data_col[col_index];
            val += part;
__global__ void col2im_gpu_kernel(const int n, const float* data_col,
        const int height, const int width, const int ksize,
        const int pad,
        const int stride,
        const int height_col, const int width_col,
        float *data_im) {
    int index = blockIdx.x*blockDim.x+threadIdx.x;
    for(; index < n; index += blockDim.x*gridDim.x){
        float val = 0;
        int w = index % width + pad;
        int h = (index / width) % height + pad;
        int c = index / (width * height);
        // compute the start and end of the output
        int w_col_start = (w < ksize) ? 0 : (w - ksize) / stride + 1;
        int w_col_end = min(w / stride + 1, width_col);
        int h_col_start = (h < ksize) ? 0 : (h - ksize) / stride + 1;
        int h_col_end = min(h / stride + 1, height_col);
        // equivalent implementation
        int offset =
            (c * ksize * ksize + h * ksize + w) * height_col * width_col;
        int coeff_h_col = (1 - stride * ksize * height_col) * width_col;
        int coeff_w_col = (1 - stride * height_col * width_col);
        for (int h_col = h_col_start; h_col < h_col_end; ++h_col) {
            for (int w_col = w_col_start; w_col < w_col_end; ++w_col) {
                val += data_col[offset + h_col * coeff_h_col + w_col * coeff_w_col];
            }
        }
        data_im[index] = val;
    }
    data_im[index] = val;
}
void col2im_ongpu(float *im,
        int channels, int height, int width,
        int ksize, int stride, int pad, float *data_col){
    // We are going to launch channels * height_col * width_col kernels, each
    // kernel responsible for copying a single-channel grid.
    pad = pad ? ksize/2 : 0;
    int height_col = (height + 2 * pad - ksize) / stride + 1;
    int width_col = (width + 2 * pad - ksize) / stride + 1;
    int num_kernels = channels * height * width;
    col2im_gpu_kernel<<<(num_kernels+BLOCK-1)/BLOCK,
        BLOCK>>>(
                num_kernels, data_col, height, width, ksize, pad,
                stride, height_col,
                width_col, im);
}
/*
   __global__ void col2im_kernel(float *data_col,
   int channels, int height, int width,
   int ksize, int stride, int pad, float *data_im)
   {
   int height_col = (height - ksize) / stride + 1;
   int width_col = (width - ksize) / stride + 1;
   if (pad){
   height_col = 1 + (height-1) / stride;
   width_col = 1 + (width-1) / stride;
   pad = ksize/2;
   }
   int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
   if(id >= channels*height*width) return;
   int index = id;
   int w = id%width + pad;
   id /= width;
   int h = id%height + pad;
   id /= height;
   int c = id%channels;
   int w_start = (w-ksize+stride)/stride;
   int w_end = w/stride + 1;
   int h_start = (h-ksize+stride)/stride;
   int h_end = h/stride + 1;
// int rows = channels * ksize * ksize;
// int cols = height_col*width_col;
int col_offset = (c*ksize*ksize + h * ksize + w)*height_col*width_col;
int h_coeff = (1-stride*ksize*height_col)*width_col;
int w_coeff = 1-stride*height_col*width_col;
float val = 0;
int h_col, w_col;
for(h_col = h_start; h_col < h_end; ++h_col){
for(w_col = w_start; w_col < w_end; ++w_col){
int col_index = col_offset +h_col*h_coeff + w_col*w_coeff;
float part = (w_col < 0 || h_col < 0 || h_col >= height_col || w_col >= width_col) ? 0 : data_col[col_index];
val += part;
}
}
data_im[index] = val;
}
extern "C" void col2im_ongpu(float *data_col,
        int channels,  int height,  int width,
        int ksize,  int stride,  int pad, float *data_im)
int channels,  int height,  int width,
int ksize,  int stride,  int pad, float *data_im)
{
    size_t n = channels*height*width;
size_t n = channels*height*width;
    col2im_kernel<<<cuda_gridsize(n), BLOCK>>>(data_col, channels, height, width, ksize, stride, pad, data_im);
    check_error(cudaPeekAtLastError());
col2im_kernel<<<cuda_gridsize(n), BLOCK>>>(data_col, channels, height, width, ksize, stride, pad, data_im);
check_error(cudaPeekAtLastError());
}
 */
src/convolutional_kernels.cu
@@ -56,7 +56,7 @@
extern "C" void forward_convolutional_layer_gpu(convolutional_layer layer, network_state state)
{
clock_t time = clock();
//clock_t time = clock();
    int i;
    int m = layer.n;
    int k = layer.size*layer.size*layer.c;
@@ -64,31 +64,31 @@
        convolutional_out_width(layer);
    bias_output_gpu(layer.output_gpu, layer.biases_gpu, layer.batch, layer.n, n);
cudaDeviceSynchronize();
printf("bias %f\n", sec(clock() - time));
time = clock();
//cudaDeviceSynchronize();
//printf("bias %f\n", sec(clock() - time));
//time = clock();
float imt=0;
float gemt = 0;
//float imt=0;
//float gemt = 0;
    for(i = 0; i < layer.batch; ++i){
time = clock();
//time = clock();
        im2col_ongpu(state.input + i*layer.c*layer.h*layer.w, layer.c,  layer.h,  layer.w,  layer.size,  layer.stride, layer.pad, layer.col_image_gpu);
cudaDeviceSynchronize();
imt += sec(clock()-time);
time = clock();
//cudaDeviceSynchronize();
//imt += sec(clock()-time);
//time = clock();
        float * a = layer.filters_gpu;
        float * b = layer.col_image_gpu;
        float * c = layer.output_gpu;
        gemm_ongpu(0,0,m,n,k,1.,a,k,b,n,1.,c+i*m*n,n);
cudaDeviceSynchronize();
gemt += sec(clock()-time);
time = clock();
//cudaDeviceSynchronize();
//gemt += sec(clock()-time);
//time = clock();
    }
    activate_array_ongpu(layer.output_gpu, m*n*layer.batch, layer.activation);
cudaDeviceSynchronize();
printf("activate %f\n", sec(clock() - time));
printf("im2col %f\n", imt);
printf("gemm %f\n", gemt);
//cudaDeviceSynchronize();
//printf("activate %f\n", sec(clock() - time));
//printf("im2col %f\n", imt);
//printf("gemm %f\n", gemt);
}
extern "C" void backward_convolutional_layer_gpu(convolutional_layer layer, network_state state)
src/im2col_kernels.cu
@@ -3,77 +3,127 @@
#include "cuda.h"
}
__global__ void im2col_pad_kernel(float *im,
     int channels,  int height,  int width,
     int ksize,  int stride, float *data_col)
{
    int c,h,w;
    int height_col = 1 + (height-1) / stride;
    int width_col = 1 + (width-1) / stride;
    int channels_col = channels * ksize * ksize;
// src: https://github.com/BVLC/caffe/blob/master/src/caffe/util/im2col.cu
// You may also want to read: https://github.com/BVLC/caffe/blob/master/LICENSE
    int pad = ksize/2;
    int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    int col_size = height_col*width_col*channels_col;
    if (id >= col_size) return;
    int col_index = id;
    w = id % width_col;
    id /= width_col;
    h = id % height_col;
    id /= height_col;
    c = id % channels_col;
    id /= channels_col;
    int w_offset = c % ksize;
    int h_offset = (c / ksize) % ksize;
    int im_channel = c / ksize / ksize;
    int im_row = h_offset + h * stride - pad;
    int im_col = w_offset + w * stride - pad;
    int im_index = im_col + width*(im_row + height*im_channel);
    float val = (im_row < 0 || im_col < 0 || im_row >= height || im_col >= width) ? 0 : im[im_index];
    data_col[col_index] = val;
__global__ void im2col_gpu_kernel(const int n, const float* data_im,
        const int height, const int width, const int ksize,
        const int pad,
        const int stride,
        const int height_col, const int width_col,
        float *data_col) {
    int index = blockIdx.x*blockDim.x+threadIdx.x;
    for(; index < n; index += blockDim.x*gridDim.x){
        int w_out = index % width_col;
        int h_index = index / width_col;
        int h_out = h_index % height_col;
        int channel_in = h_index / height_col;
        int channel_out = channel_in * ksize * ksize;
        int h_in = h_out * stride - pad;
        int w_in = w_out * stride - pad;
        float* data_col_ptr = data_col;
        data_col_ptr += (channel_out * height_col + h_out) * width_col + w_out;
        const float* data_im_ptr = data_im;
        data_im_ptr += (channel_in * height + h_in) * width + w_in;
        for (int i = 0; i < ksize; ++i) {
            for (int j = 0; j < ksize; ++j) {
                int h = h_in + i;
                int w = w_in + j;
                *data_col_ptr = (h >= 0 && w >= 0 && h < height && w < width) ?
                    data_im_ptr[i * width + j] : 0;
                data_col_ptr += height_col * width_col;
            }
        }
    }
}
__global__ void im2col_nopad_kernel(float *im,
        int channels,  int height,  int width,
        int ksize,  int stride, float *data_col)
{
    int c,h,w;
    int height_col = (height - ksize) / stride + 1;
    int width_col = (width - ksize) / stride + 1;
    int channels_col = channels * ksize * ksize;
    int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    int col_size = height_col*width_col*channels_col;
    if (id >= col_size) return;
    int col_index = id;
    w = id % width_col;
    id /= width_col;
    h = id % height_col;
    id /= height_col;
    c = id % channels_col;
    id /= channels_col;
    int w_offset = c % ksize;
    int h_offset = (c / ksize) % ksize;
    int im_channel = c / ksize / ksize;
    int im_row = h_offset + h * stride;
    int im_col = w_offset + w * stride;
    int im_index = im_col + width*(im_row + height*im_channel);
    float val = (im_row < 0 || im_col < 0 || im_row >= height || im_col >= width) ? 0 : im[im_index];
    data_col[col_index] = val;
void im2col_ongpu(float *im,
         int channels, int height, int width,
         int ksize, int stride, int pad, float *data_col){
    // We are going to launch channels * height_col * width_col kernels, each
    // kernel responsible for copying a single-channel grid.
    pad = pad ? ksize/2 : 0;
    int height_col = (height + 2 * pad - ksize) / stride + 1;
    int width_col = (width + 2 * pad - ksize) / stride + 1;
    int num_kernels = channels * height_col * width_col;
    im2col_gpu_kernel<<<(num_kernels+BLOCK-1)/BLOCK,
        BLOCK>>>(
                num_kernels, im, height, width, ksize, pad,
                stride, height_col,
                width_col, data_col);
}
/*
   __global__ void im2col_pad_kernel(float *im,
   int channels,  int height,  int width,
   int ksize,  int stride, float *data_col)
   {
   int c,h,w;
   int height_col = 1 + (height-1) / stride;
   int width_col = 1 + (width-1) / stride;
   int channels_col = channels * ksize * ksize;
extern "C" void im2col_ongpu(float *im,
        int channels,  int height,  int width,
        int ksize,  int stride,  int pad, float *data_col)
   int pad = ksize/2;
   int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
   int col_size = height_col*width_col*channels_col;
   if (id >= col_size) return;
   int col_index = id;
   w = id % width_col;
   id /= width_col;
   h = id % height_col;
   id /= height_col;
   c = id % channels_col;
   id /= channels_col;
   int w_offset = c % ksize;
   int h_offset = (c / ksize) % ksize;
   int im_channel = c / ksize / ksize;
   int im_row = h_offset + h * stride - pad;
   int im_col = w_offset + w * stride - pad;
   int im_index = im_col + width*(im_row + height*im_channel);
   float val = (im_row < 0 || im_col < 0 || im_row >= height || im_col >= width) ? 0 : im[im_index];
   data_col[col_index] = val;
   }
   __global__ void im2col_nopad_kernel(float *im,
   int channels,  int height,  int width,
   int ksize,  int stride, float *data_col)
   {
   int c,h,w;
   int height_col = (height - ksize) / stride + 1;
   int width_col = (width - ksize) / stride + 1;
   int channels_col = channels * ksize * ksize;
   int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
   int col_size = height_col*width_col*channels_col;
   if (id >= col_size) return;
   int col_index = id;
   w = id % width_col;
   id /= width_col;
   h = id % height_col;
   id /= height_col;
   c = id % channels_col;
   id /= channels_col;
   int w_offset = c % ksize;
   int h_offset = (c / ksize) % ksize;
   int im_channel = c / ksize / ksize;
   int im_row = h_offset + h * stride;
   int im_col = w_offset + w * stride;
   int im_index = im_col + width*(im_row + height*im_channel);
   float val = (im_row < 0 || im_col < 0 || im_row >= height || im_col >= width) ? 0 : im[im_index];
   data_col[col_index] = val;
   }
   extern "C" void im2col_ongpu(float *im,
   int channels,  int height,  int width,
int ksize,  int stride,  int pad, float *data_col)
{
    int height_col = (height - ksize) / stride + 1;
@@ -91,3 +141,4 @@
    else im2col_nopad_kernel<<<cuda_gridsize(n),BLOCK>>>(im,  channels, height, width, ksize, stride, data_col);
    check_error(cudaPeekAtLastError());
}
*/
src/imagenet.c
@@ -13,7 +13,7 @@
        load_weights(&net, weightfile);
    }
    printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
    int imgs = 128;
    int imgs = 1024;
    int i = net.seen/imgs;
    char **labels = get_labels("/home/pjreddie/data/imagenet/cls.labels.list");
    list *plist = get_paths("/data/imagenet/cls.train.list");