From 73f7aacf35ec9b1d0f9de9ddf38af0889f213e99 Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Tue, 20 Sep 2016 18:34:49 +0000
Subject: [PATCH] better multigpu
---
src/blas_kernels.cu | 56 +++++++++++++++++++++++++++++++++++++++++++++++++++++++-
1 files changed, 55 insertions(+), 1 deletions(-)
diff --git a/src/blas_kernels.cu b/src/blas_kernels.cu
index ac537d8..0391e2e 100644
--- a/src/blas_kernels.cu
+++ b/src/blas_kernels.cu
@@ -312,6 +312,38 @@
variance[i] *= scale;
}
+__global__ void reorg_kernel(int N, float *x, int w, int h, int c, int batch, int stride, int forward, float *out)
+{
+ int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+ if(i >= N) return;
+ int in_index = i;
+ int in_w = i%w;
+ i = i/w;
+ int in_h = i%h;
+ i = i/h;
+ int in_c = i%c;
+ i = i/c;
+ int b = i%batch;
+
+ int out_c = c/(stride*stride);
+
+ int c2 = in_c % out_c;
+ int offset = in_c / out_c;
+ int w2 = in_w*stride + offset % stride;
+ int h2 = in_h*stride + offset / stride;
+ //printf("%d\n", offset);
+ int out_index = w2 + w*stride*(h2 + h*stride*(c2 + out_c*b));
+
+ // printf("%d %d %d\n", w2, h2, c2);
+ //printf("%d %d\n", in_index, out_index);
+ //if(out_index >= N || out_index < 0) printf("bad bad bad \n");
+
+ if(forward) out[out_index] = x[in_index];
+ else out[in_index] = x[out_index];
+ //if(forward) out[1] = x[1];
+ //else out[0] = x[0];
+}
+
__global__ void axpy_kernel(int N, float ALPHA, float *X, int OFFX, int INCX, float *Y, int OFFY, int INCY)
{
int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
@@ -333,7 +365,15 @@
__global__ void constrain_kernel(int N, float ALPHA, float *X, int INCX)
{
int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
- if(i < N) X[i*INCX] = min(ALPHA, max(-ALPHA, X[i*INCX]));
+ if(i < N) X[i*INCX] = fminf(ALPHA, fmaxf(-ALPHA, X[i*INCX]));
+}
+
+__global__ void supp_kernel(int N, float ALPHA, float *X, int INCX)
+{
+ int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+ if(i < N) {
+ if((X[i*INCX] * X[i*INCX]) < (ALPHA * ALPHA)) X[i*INCX] = 0;
+ }
}
__global__ void scal_kernel(int N, float ALPHA, float *X, int INCX)
@@ -488,6 +528,13 @@
check_error(cudaPeekAtLastError());
}
+extern "C" void reorg_ongpu(float *x, int w, int h, int c, int batch, int stride, int forward, float *out)
+{
+ int size = w*h*c*batch;
+ reorg_kernel<<<cuda_gridsize(size), BLOCK>>>(size, x, w, h, c, batch, stride, forward, out);
+ check_error(cudaPeekAtLastError());
+}
+
extern "C" void mask_ongpu(int N, float * X, float mask_num, float * mask)
{
mask_kernel<<<cuda_gridsize(N), BLOCK>>>(N, X, mask_num, mask);
@@ -513,6 +560,12 @@
check_error(cudaPeekAtLastError());
}
+extern "C" void supp_ongpu(int N, float ALPHA, float * X, int INCX)
+{
+ supp_kernel<<<cuda_gridsize(N), BLOCK>>>(N, ALPHA, X, INCX);
+ check_error(cudaPeekAtLastError());
+}
+
extern "C" void fill_ongpu(int N, float ALPHA, float * X, int INCX)
{
fill_kernel<<<cuda_gridsize(N), BLOCK>>>(N, ALPHA, X, INCX);
@@ -594,6 +647,7 @@
}
+
__global__ void weighted_sum_kernel(int n, float *a, float *b, float *s, float *c)
{
int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
--
Gitblit v1.10.0