From d50ebc7fdf6543faab8c8b02d30730a9991f02b6 Mon Sep 17 00:00:00 2001
From: AlexeyAB <alexeyab84@gmail.com>
Date: Tue, 06 Dec 2016 11:41:18 +0000
Subject: [PATCH] Fixed command line examples

---
 src/blas_kernels.cu |  130 +++++++++++++++++++++++++++++++++++++++++++
 1 files changed, 129 insertions(+), 1 deletions(-)

diff --git a/src/blas_kernels.cu b/src/blas_kernels.cu
index ac537d8..d940176 100644
--- a/src/blas_kernels.cu
+++ b/src/blas_kernels.cu
@@ -140,6 +140,21 @@
 }
 
 
+__global__ void adam_kernel(int N, float *x, float *m, float *v, float B1, float B2, float rate, float eps, int t)
+{
+    int index = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+    if (index >= N) return;
+    
+    x[index] = x[index] - (rate * sqrt(1.-pow(B2, t)) / (1.-pow(B1, t)) * m[index] / (sqrt(v[index]) + eps));
+    //if(index == 0) printf("%f %f %f %f\n", m[index], v[index], (rate * sqrt(1.-pow(B2, t)) / (1.-pow(B1, t)) * m[index] / (sqrt(v[index]) + eps)));
+}
+
+extern "C" void adam_gpu(int n, float *x, float *m, float *v, float B1, float B2, float rate, float eps, int t)
+{
+    adam_kernel<<<cuda_gridsize(n), BLOCK>>>(n, x, m, v, B1, B2, rate, eps, t);
+    check_error(cudaPeekAtLastError());
+}
+
 __global__ void normalize_kernel(int N, float *x, float *mean, float *variance, int batch, int filters, int spatial)
 {
     int index = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
@@ -312,6 +327,38 @@
     variance[i] *= scale;
 }
 
+__global__ void reorg_kernel(int N, float *x, int w, int h, int c, int batch, int stride, int forward, float *out)
+{
+    int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+    if(i >= N) return;
+    int in_index = i;
+    int in_w = i%w;
+    i = i/w;
+    int in_h = i%h;
+    i = i/h;
+    int in_c = i%c;
+    i = i/c;
+    int b = i%batch;
+
+    int out_c = c/(stride*stride);
+
+    int c2 = in_c % out_c;
+    int offset = in_c / out_c;
+    int w2 = in_w*stride + offset % stride;
+    int h2 = in_h*stride + offset / stride;
+    //printf("%d\n", offset);
+    int out_index = w2 + w*stride*(h2 + h*stride*(c2 + out_c*b));
+
+   // printf("%d %d %d\n", w2, h2, c2);
+    //printf("%d %d\n", in_index, out_index);
+    //if(out_index >= N || out_index < 0) printf("bad bad bad \n");
+
+    if(forward) out[out_index] = x[in_index];
+    else out[in_index] = x[out_index];
+    //if(forward) out[1] = x[1];
+    //else out[0] = x[0];
+}
+
 __global__ void axpy_kernel(int N, float ALPHA, float *X, int OFFX, int INCX,  float *Y, int OFFY, int INCY)
 {
     int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
@@ -333,7 +380,15 @@
 __global__ void constrain_kernel(int N, float ALPHA, float *X, int INCX)
 {
     int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
-    if(i < N) X[i*INCX] = min(ALPHA, max(-ALPHA, X[i*INCX]));
+    if(i < N) X[i*INCX] = fminf(ALPHA, fmaxf(-ALPHA, X[i*INCX]));
+}
+
+__global__ void supp_kernel(int N, float ALPHA, float *X, int INCX)
+{
+    int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+    if(i < N) {
+        if((X[i*INCX] * X[i*INCX]) < (ALPHA * ALPHA)) X[i*INCX] = 0;
+    }
 }
 
 __global__ void scal_kernel(int N, float ALPHA, float *X, int INCX)
@@ -488,6 +543,37 @@
     check_error(cudaPeekAtLastError());
 }
 
+__global__ void flatten_kernel(int N, float *x, int spatial, int layers, int batch, int forward, float *out)
+{
+    int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+    if(i >= N) return;
+    int in_s = i%spatial;
+    i = i/spatial;
+    int in_c = i%layers;
+    i = i/layers;
+    int b = i;
+
+    int i1 = b*layers*spatial + in_c*spatial + in_s;
+    int i2 = b*layers*spatial + in_s*layers +  in_c;
+
+    if (forward) out[i2] = x[i1];
+    else out[i1] = x[i2];
+}
+
+extern "C" void flatten_ongpu(float *x, int spatial, int layers, int batch, int forward, float *out)
+{
+    int size = spatial*batch*layers;
+    flatten_kernel<<<cuda_gridsize(size), BLOCK>>>(size, x, spatial, layers, batch, forward, out);
+    check_error(cudaPeekAtLastError());
+}
+
+extern "C" void reorg_ongpu(float *x, int w, int h, int c, int batch, int stride, int forward, float *out)
+{
+    int size = w*h*c*batch;
+    reorg_kernel<<<cuda_gridsize(size), BLOCK>>>(size, x, w, h, c, batch, stride, forward, out);
+    check_error(cudaPeekAtLastError());
+}
+
 extern "C" void mask_ongpu(int N, float * X, float mask_num, float * mask)
 {
     mask_kernel<<<cuda_gridsize(N), BLOCK>>>(N, X, mask_num, mask);
@@ -513,6 +599,12 @@
     check_error(cudaPeekAtLastError());
 }
 
+extern "C" void supp_ongpu(int N, float ALPHA, float * X, int INCX)
+{
+    supp_kernel<<<cuda_gridsize(N), BLOCK>>>(N, ALPHA, X, INCX);
+    check_error(cudaPeekAtLastError());
+}
+
 extern "C" void fill_ongpu(int N, float ALPHA, float * X, int INCX)
 {
     fill_kernel<<<cuda_gridsize(N), BLOCK>>>(N, ALPHA, X, INCX);
@@ -594,6 +686,7 @@
 }
 
 
+
 __global__ void weighted_sum_kernel(int n, float *a, float *b, float *s, float *c)
 {
     int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
@@ -637,3 +730,38 @@
     mult_add_into_kernel<<<cuda_gridsize(num), BLOCK>>>(num, a, b, c);
     check_error(cudaPeekAtLastError());
 }
+
+
+__device__ void softmax_device(int n, float *input, float temp, float *output)
+{
+    int i;
+    float sum = 0;
+    float largest = -INFINITY;
+    for(i = 0; i < n; ++i){
+        int val = input[i];
+        largest = (val>largest) ? val : largest;
+    }
+    for(i = 0; i < n; ++i){
+        float e = exp(input[i]/temp - largest/temp);
+        sum += e;
+        output[i] = e;
+    }
+    for(i = 0; i < n; ++i){
+        output[i] /= sum;
+    }
+}
+
+__global__ void softmax_kernel(int n, int offset, int batch, float *input, float temp, float *output)
+{
+    int b = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+    if(b >= batch) return;
+    softmax_device(n, input + b*offset, temp, output + b*offset);
+}
+
+extern "C" void softmax_gpu(float *input, int n, int offset, int groups, float temp, float *output)
+{
+    int inputs = n;
+    int batch = groups;
+    softmax_kernel<<<cuda_gridsize(batch), BLOCK>>>(inputs, offset, batch, input, temp, output);
+    check_error(cudaPeekAtLastError());
+}

--
Gitblit v1.10.0