Joseph Redmon
2014-11-21 e36182cd8c5dd5c6d0aa1f77cf5cdca87e8bb1f0
src/im2col.c
@@ -51,12 +51,23 @@
#include "opencl.h"
#include <math.h>
cl_kernel get_im2col_kernel()
cl_kernel get_im2col_pad_kernel()
{
    static int init = 0;
    static cl_kernel im2col_kernel;
    if(!init){
        im2col_kernel = get_kernel("src/im2col.cl", "im2col", 0);
        im2col_kernel = get_kernel("src/im2col.cl", "im2col_pad", 0);
        init = 1;
    }
    return im2col_kernel;
}
cl_kernel get_im2col_nopad_kernel()
{
    static int init = 0;
    static cl_kernel im2col_kernel;
    if(!init){
        im2col_kernel = get_kernel("src/im2col.cl", "im2col_nopad", 0);
        init = 1;
    }
    return im2col_kernel;
@@ -68,32 +79,34 @@
         int ksize,  int stride,  int pad, cl_mem data_col)
{
    cl_setup();
    cl_kernel im2col_kernel = get_im2col_kernel();
    cl_command_queue queue = cl.queue;
    cl_uint i = 0;
    cl.error = clSetKernelArg(im2col_kernel, i++, sizeof(data_im), (void*) &data_im);
    cl.error = clSetKernelArg(im2col_kernel, i++, sizeof(batch), (void*) &batch);
    cl.error = clSetKernelArg(im2col_kernel, i++, sizeof(channels), (void*) &channels);
    cl.error = clSetKernelArg(im2col_kernel, i++, sizeof(height), (void*) &height);
    cl.error = clSetKernelArg(im2col_kernel, i++, sizeof(width), (void*) &width);
    cl.error = clSetKernelArg(im2col_kernel, i++, sizeof(ksize), (void*) &ksize);
    cl.error = clSetKernelArg(im2col_kernel, i++, sizeof(stride), (void*) &stride);
    cl.error = clSetKernelArg(im2col_kernel, i++, sizeof(pad), (void*) &pad);
    cl.error = clSetKernelArg(im2col_kernel, i++, sizeof(data_col), (void*) &data_col);
    check_error(cl);
    int height_col = (height - ksize) / stride + 1;
    int width_col = (width - ksize) / stride + 1;
    int channels_col = channels * ksize * ksize;
    cl_kernel kernel = get_im2col_nopad_kernel();
    if (pad){
        height_col = 1 + (height-1) / stride;
        width_col = 1 + (width-1) / stride;
        kernel = get_im2col_pad_kernel();
    }
    cl_command_queue queue = cl.queue;
    cl_uint i = 0;
    cl.error = clSetKernelArg(kernel, i++, sizeof(data_im), (void*) &data_im);
    cl.error = clSetKernelArg(kernel, i++, sizeof(batch), (void*) &batch);
    cl.error = clSetKernelArg(kernel, i++, sizeof(channels), (void*) &channels);
    cl.error = clSetKernelArg(kernel, i++, sizeof(height), (void*) &height);
    cl.error = clSetKernelArg(kernel, i++, sizeof(width), (void*) &width);
    cl.error = clSetKernelArg(kernel, i++, sizeof(ksize), (void*) &ksize);
    cl.error = clSetKernelArg(kernel, i++, sizeof(stride), (void*) &stride);
    cl.error = clSetKernelArg(kernel, i++, sizeof(data_col), (void*) &data_col);
    check_error(cl);
    size_t global_size = batch*channels_col*height_col*width_col;
    clEnqueueNDRangeKernel(queue, im2col_kernel, 1, 0,
    cl.error = clEnqueueNDRangeKernel(queue, kernel, 1, 0,
            &global_size, 0, 0, 0, 0);
    check_error(cl);
}