From 845ab7579685b6702c92c1088ec11e71bde51f3c Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Fri, 05 Aug 2016 22:27:07 +0000
Subject: [PATCH] some more stuff

---
 src/blas_kernels.cu |   39 +++++++++++++++++++++++++++++++++++++++
 1 files changed, 39 insertions(+), 0 deletions(-)

diff --git a/src/blas_kernels.cu b/src/blas_kernels.cu
index ac537d8..3f7f1f9 100644
--- a/src/blas_kernels.cu
+++ b/src/blas_kernels.cu
@@ -312,6 +312,38 @@
     variance[i] *= scale;
 }
 
+__global__ void reorg_kernel(int N, float *x, int w, int h, int c, int batch, int stride, int forward, float *out)
+{
+    int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+    if(i >= N) return;
+    int in_index = i;
+    int in_w = i%w;
+    i = i/w;
+    int in_h = i%h;
+    i = i/h;
+    int in_c = i%c;
+    i = i/c;
+    int b = i%batch;
+
+    int out_c = c/(stride*stride);
+
+    int c2 = in_c % out_c;
+    int offset = in_c / out_c;
+    int w2 = in_w*stride + offset % stride;
+    int h2 = in_h*stride + offset / stride;
+    //printf("%d\n", offset);
+    int out_index = w2 + w*stride*(h2 + h*stride*(c2 + out_c*b));
+
+   // printf("%d %d %d\n", w2, h2, c2);
+    //printf("%d %d\n", in_index, out_index);
+    //if(out_index >= N || out_index < 0) printf("bad bad bad \n");
+
+    if(forward) out[out_index] = x[in_index];
+    else out[in_index] = x[out_index];
+    //if(forward) out[1] = x[1];
+    //else out[0] = x[0];
+}
+
 __global__ void axpy_kernel(int N, float ALPHA, float *X, int OFFX, int INCX,  float *Y, int OFFY, int INCY)
 {
     int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
@@ -488,6 +520,13 @@
     check_error(cudaPeekAtLastError());
 }
 
+extern "C" void reorg_ongpu(float *x, int w, int h, int c, int batch, int stride, int forward, float *out)
+{
+    int size = w*h*c*batch;
+    reorg_kernel<<<cuda_gridsize(size), BLOCK>>>(size, x, w, h, c, batch, stride, forward, out);
+    check_error(cudaPeekAtLastError());
+}
+
 extern "C" void mask_ongpu(int N, float * X, float mask_num, float * mask)
 {
     mask_kernel<<<cuda_gridsize(N), BLOCK>>>(N, X, mask_num, mask);

--
Gitblit v1.10.0