From 8f1b4e0962857d402f9d017fcbf387ef0eceb7c4 Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Thu, 01 Sep 2016 23:48:41 +0000
Subject: [PATCH] updates and things

---
 src/im2col_kernels.cu |   95 ++---------------------------------------------
 1 files changed, 4 insertions(+), 91 deletions(-)

diff --git a/src/im2col_kernels.cu b/src/im2col_kernels.cu
index c2dd780..d42d600 100644
--- a/src/im2col_kernels.cu
+++ b/src/im2col_kernels.cu
@@ -33,8 +33,12 @@
             for (int j = 0; j < ksize; ++j) {
                 int h = h_in + i;
                 int w = w_in + j;
+
                 *data_col_ptr = (h >= 0 && w >= 0 && h < height && w < width) ?
                     data_im_ptr[i * width + j] : 0;
+
+                //*data_col_ptr = data_im_ptr[ii * width + jj];
+
                 data_col_ptr += height_col * width_col;
             }
         }
@@ -46,7 +50,6 @@
          int ksize, int stride, int pad, float *data_col){
     // We are going to launch channels * height_col * width_col kernels, each
     // kernel responsible for copying a single-channel grid.
-    pad = pad ? ksize/2 : 0;
     int height_col = (height + 2 * pad - ksize) / stride + 1;
     int width_col = (width + 2 * pad - ksize) / stride + 1;
     int num_kernels = channels * height_col * width_col;
@@ -56,93 +59,3 @@
                 stride, height_col,
                 width_col, data_col);
 }
-/*
-   __global__ void im2col_pad_kernel(float *im,
-   int channels,  int height,  int width,
-   int ksize,  int stride, float *data_col)
-   {
-   int c,h,w;
-   int height_col = 1 + (height-1) / stride;
-   int width_col = 1 + (width-1) / stride;
-   int channels_col = channels * ksize * ksize;
-
-   int pad = ksize/2;
-
-   int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
-   int col_size = height_col*width_col*channels_col;
-   if (id >= col_size) return;
-
-   int col_index = id;
-   w = id % width_col;
-   id /= width_col;
-   h = id % height_col;
-   id /= height_col;
-   c = id % channels_col;
-   id /= channels_col;
-
-   int w_offset = c % ksize;
-   int h_offset = (c / ksize) % ksize;
-   int im_channel = c / ksize / ksize;
-   int im_row = h_offset + h * stride - pad;
-   int im_col = w_offset + w * stride - pad;
-
-   int im_index = im_col + width*(im_row + height*im_channel);
-   float val = (im_row < 0 || im_col < 0 || im_row >= height || im_col >= width) ? 0 : im[im_index];
-
-   data_col[col_index] = val;
-   }
-
-   __global__ void im2col_nopad_kernel(float *im,
-   int channels,  int height,  int width,
-   int ksize,  int stride, float *data_col)
-   {
-   int c,h,w;
-   int height_col = (height - ksize) / stride + 1;
-   int width_col = (width - ksize) / stride + 1;
-   int channels_col = channels * ksize * ksize;
-
-   int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
-   int col_size = height_col*width_col*channels_col;
-   if (id >= col_size) return;
-
-   int col_index = id;
-   w = id % width_col;
-   id /= width_col;
-   h = id % height_col;
-   id /= height_col;
-   c = id % channels_col;
-   id /= channels_col;
-
-   int w_offset = c % ksize;
-   int h_offset = (c / ksize) % ksize;
-   int im_channel = c / ksize / ksize;
-   int im_row = h_offset + h * stride;
-   int im_col = w_offset + w * stride;
-
-   int im_index = im_col + width*(im_row + height*im_channel);
-   float val = (im_row < 0 || im_col < 0 || im_row >= height || im_col >= width) ? 0 : im[im_index];
-
-   data_col[col_index] = val;
-   }
-
-   extern "C" void im2col_ongpu(float *im,
-   int channels,  int height,  int width,
-int ksize,  int stride,  int pad, float *data_col)
-{
-
-    int height_col = (height - ksize) / stride + 1;
-    int width_col = (width - ksize) / stride + 1;
-    int channels_col = channels * ksize * ksize;
-
-    if (pad){
-        height_col = 1 + (height-1) / stride;
-        width_col = 1 + (width-1) / stride;
-    }
-
-    size_t n = channels_col*height_col*width_col;
-
-    if(pad)im2col_pad_kernel<<<cuda_gridsize(n),BLOCK>>>(im,  channels, height, width, ksize, stride, data_col);
-    else im2col_nopad_kernel<<<cuda_gridsize(n),BLOCK>>>(im,  channels, height, width, ksize, stride, data_col);
-    check_error(cudaPeekAtLastError());
-}
-*/

--
Gitblit v1.10.0