From 1c0fd9bb4726f28b5ccf4491b8d108b00c884ec3 Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Thu, 30 Oct 2014 06:26:41 +0000
Subject: [PATCH] im2col slightly faster

---
 src/im2col.c |   47 ++++++++++++++++++++++++++++++-----------------
 1 files changed, 30 insertions(+), 17 deletions(-)

diff --git a/src/im2col.c b/src/im2col.c
index b743e34..bfaa54c 100644
--- a/src/im2col.c
+++ b/src/im2col.c
@@ -51,12 +51,23 @@
 #include "opencl.h"
 #include <math.h>
 
-cl_kernel get_im2col_kernel()
+cl_kernel get_im2col_pad_kernel()
 {
     static int init = 0;
     static cl_kernel im2col_kernel;
     if(!init){
-        im2col_kernel = get_kernel("src/im2col.cl", "im2col", 0);
+        im2col_kernel = get_kernel("src/im2col.cl", "im2col_pad", 0);
+        init = 1;
+    }
+    return im2col_kernel;
+}
+
+cl_kernel get_im2col_nopad_kernel()
+{
+    static int init = 0;
+    static cl_kernel im2col_kernel;
+    if(!init){
+        im2col_kernel = get_kernel("src/im2col.cl", "im2col_nopad", 0);
         init = 1;
     }
     return im2col_kernel;
@@ -68,32 +79,34 @@
          int ksize,  int stride,  int pad, cl_mem data_col)
 {
     cl_setup();
-    cl_kernel im2col_kernel = get_im2col_kernel();
-    cl_command_queue queue = cl.queue;
-
-    cl_uint i = 0;
-    cl.error = clSetKernelArg(im2col_kernel, i++, sizeof(data_im), (void*) &data_im);
-    cl.error = clSetKernelArg(im2col_kernel, i++, sizeof(batch), (void*) &batch);
-    cl.error = clSetKernelArg(im2col_kernel, i++, sizeof(channels), (void*) &channels);
-    cl.error = clSetKernelArg(im2col_kernel, i++, sizeof(height), (void*) &height);
-    cl.error = clSetKernelArg(im2col_kernel, i++, sizeof(width), (void*) &width);
-    cl.error = clSetKernelArg(im2col_kernel, i++, sizeof(ksize), (void*) &ksize);
-    cl.error = clSetKernelArg(im2col_kernel, i++, sizeof(stride), (void*) &stride);
-    cl.error = clSetKernelArg(im2col_kernel, i++, sizeof(pad), (void*) &pad);
-    cl.error = clSetKernelArg(im2col_kernel, i++, sizeof(data_col), (void*) &data_col);
-    check_error(cl);
 
     int height_col = (height - ksize) / stride + 1;
     int width_col = (width - ksize) / stride + 1;
     int channels_col = channels * ksize * ksize;
+    cl_kernel kernel = get_im2col_nopad_kernel();
+
     if (pad){
         height_col = 1 + (height-1) / stride;
         width_col = 1 + (width-1) / stride;
+        kernel = get_im2col_pad_kernel();
     }
 
+    cl_command_queue queue = cl.queue;
+
+    cl_uint i = 0;
+    cl.error = clSetKernelArg(kernel, i++, sizeof(data_im), (void*) &data_im);
+    cl.error = clSetKernelArg(kernel, i++, sizeof(batch), (void*) &batch);
+    cl.error = clSetKernelArg(kernel, i++, sizeof(channels), (void*) &channels);
+    cl.error = clSetKernelArg(kernel, i++, sizeof(height), (void*) &height);
+    cl.error = clSetKernelArg(kernel, i++, sizeof(width), (void*) &width);
+    cl.error = clSetKernelArg(kernel, i++, sizeof(ksize), (void*) &ksize);
+    cl.error = clSetKernelArg(kernel, i++, sizeof(stride), (void*) &stride);
+    cl.error = clSetKernelArg(kernel, i++, sizeof(data_col), (void*) &data_col);
+    check_error(cl);
+
     size_t global_size = batch*channels_col*height_col*width_col;
 
-    clEnqueueNDRangeKernel(queue, im2col_kernel, 1, 0,
+    clEnqueueNDRangeKernel(queue, kernel, 1, 0,
             &global_size, 0, 0, 0, 0);
     check_error(cl);
 }

--
Gitblit v1.10.0