From 809f924db2823b9e1eaf3efb9370380edc1f76ed Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Fri, 23 Jan 2015 00:38:24 +0000
Subject: [PATCH] CUDA so fast

---
 src/col2im_kernels.cu |   29 ++++++++++++++++++++++++-----
 1 files changed, 24 insertions(+), 5 deletions(-)

diff --git a/src/col2im.cl b/src/col2im_kernels.cu
similarity index 60%
rename from src/col2im.cl
rename to src/col2im_kernels.cu
index 617e818..73de9b7 100644
--- a/src/col2im.cl
+++ b/src/col2im_kernels.cu
@@ -1,6 +1,11 @@
-__kernel void col2im(__global float *data_col, int offset,
+extern "C" {
+#include "col2im.h"
+#include "cuda.h"
+}
+
+__global__ void col2im_kernel(float *data_col, int offset,
         int channels, int height, int width,
-        int ksize, int stride, int pad, __global float *data_im)
+        int ksize, int stride, int pad, float *data_im)
 {
 
     int height_col = (height - ksize) / stride + 1;
@@ -11,7 +16,9 @@
         pad = ksize/2;
     }
 
-    int id = get_global_id(0);
+    int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+    if(id >= channels*height*width) return;
+
     int index = id;
     int w = id%width + pad;
     id /= width;
@@ -25,8 +32,8 @@
     int h_start = (h-ksize+stride)/stride;
     int h_end = h/stride + 1;
 
-    int rows = channels * ksize * ksize;
-    int cols = height_col*width_col;
+    // int rows = channels * ksize * ksize;
+    // int cols = height_col*width_col;
     int col_offset = (c*ksize*ksize + h * ksize + w)*height_col*width_col;
     int h_coeff = (1-stride*ksize*height_col)*width_col;
     int w_coeff = 1-stride*height_col*width_col;
@@ -41,3 +48,15 @@
     }
     data_im[index+offset] = val;
 }
+
+
+extern "C" void col2im_ongpu(float *data_col, int offset,
+        int channels,  int height,  int width,
+        int ksize,  int stride,  int pad, float *data_im)
+{
+
+    size_t n = channels*height*width;
+
+    col2im_kernel<<<cuda_gridsize(n), BLOCK>>>(data_col, offset, channels, height, width, ksize, stride, pad, data_im);
+    check_error(cudaPeekAtLastError());
+}

--
Gitblit v1.10.0