From 1b5afb45838e603fa6780762eb8cc59246dc2d81 Mon Sep 17 00:00:00 2001
From: IlyaOvodov <b@ovdv.ru>
Date: Tue, 08 May 2018 11:09:35 +0000
Subject: [PATCH] Output improvements for detector results

When printing detector results, the output used to appear in arbitrary order,
which made the results hard to interpret. Now:

1. Text output includes each rect's coordinates (left, right, top, bottom, in
   pixels) along with its label and score.
2. Text output is sorted by the rects' left coordinates, to make it easier to
   find the corresponding rects on the image.
3. If several class probabilities exceed the threshold for one detection, the
   most probable class is written first and the coordinates are not repeated
   for the others.
4. Rects are drawn on the image in order of their best class probability, so
   the most probable rects are always on top and are not overlaid by less
   probable ones (as sketched below).
5. The most probable label for a rect is always written first.

Also:

6. The message about low GPU memory now includes the required amount.

---
 src/blas_kernels.cu |  488 ++++++++++++++++++++++++++++++++++++++++++++++++++---
 1 file changed, 453 insertions(+), 35 deletions(-)
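
For illustration only, a minimal sketch of the drawing-order idea from item 4
above: sort detections by their best class probability and draw them in that
order, so the strongest rects are painted last and stay on top. The detection
struct, best_prob() and draw_box() below are hypothetical stand-ins, not code
from this patch or from the darknet API.

    #include <stdlib.h>

    /* Hypothetical detection record: one box plus per-class probabilities. */
    typedef struct { float left, right, top, bottom; int classes; float *prob; } detection;

    static float best_prob(const detection *d)
    {
        float best = 0;
        int c;
        for (c = 0; c < d->classes; ++c) if (d->prob[c] > best) best = d->prob[c];
        return best;
    }

    /* Ascending by best class probability: the most probable rects are drawn
       last, so they end up on top of weaker, overlapping ones. */
    static int by_best_prob(const void *a, const void *b)
    {
        float pa = best_prob((const detection *)a);
        float pb = best_prob((const detection *)b);
        return (pa > pb) - (pa < pb);
    }

    void draw_detections_sorted(detection *dets, int n)
    {
        int i;
        qsort(dets, n, sizeof(detection), by_best_prob);
        for (i = 0; i < n; ++i) { /* draw_box(&dets[i]); */ }
    }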

diff --git a/src/blas_kernels.cu b/src/blas_kernels.cu
index 8f05eb9..1edbbbd 100644
--- a/src/blas_kernels.cu
+++ b/src/blas_kernels.cu
@@ -1,6 +1,7 @@
 #include "cuda_runtime.h"
 #include "curand.h"
 #include "cublas_v2.h"
+#include <assert.h>
 
 extern "C" {
 #include "blas.h"
@@ -8,13 +9,173 @@
 #include "utils.h"
 }
 
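+// scale_bias: multiply every element of filter f's output map by the per-filter scale biases[f].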
+__global__ void scale_bias_kernel(float *output, float *biases, int n, int size)
+{
+    int offset = blockIdx.x * blockDim.x + threadIdx.x;
+    int filter = blockIdx.y;
+    int batch = blockIdx.z;
+
+    if(offset < size) output[(batch*n+filter)*size + offset] *= biases[filter];
+}
+
+void scale_bias_gpu(float *output, float *biases, int batch, int n, int size)
+{
+    dim3 dimGrid((size-1)/BLOCK + 1, n, batch);
+    dim3 dimBlock(BLOCK, 1, 1);
+
+    scale_bias_kernel<<<dimGrid, dimBlock, 0, get_cuda_stream()>>>(output, biases, n, size);
+    check_error(cudaPeekAtLastError());
+}
+
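+// Gradient of the per-filter scales: block-wide reduction of delta * x_norm over batch and spatial positions.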
+__global__ void backward_scale_kernel(float *x_norm, float *delta, int batch, int n, int size, float *scale_updates)
+{
+    __shared__ float part[BLOCK];
+    int i,b;
+    int filter = blockIdx.x;
+    int p = threadIdx.x;
+    float sum = 0;
+    for(b = 0; b < batch; ++b){
+        for(i = 0; i < size; i += BLOCK){
+            int index = p + i + size*(filter + n*b);
+            sum += (p+i < size) ? delta[index]*x_norm[index] : 0;
+        }
+    }
+    part[p] = sum;
+    __syncthreads();
+    if (p == 0) {
+        for(i = 0; i < BLOCK; ++i) scale_updates[filter] += part[i];
+    }
+}
+
+void backward_scale_gpu(float *x_norm, float *delta, int batch, int n, int size, float *scale_updates)
+{
+    backward_scale_kernel<<<n, BLOCK>>>(x_norm, delta, batch, n, size, scale_updates);
+    check_error(cudaPeekAtLastError());
+}
+
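+// add_bias: add the per-filter bias biases[f] to every element of filter f's output map.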
+__global__ void add_bias_kernel(float *output, float *biases, int n, int size)
+{
+    int offset = blockIdx.x * blockDim.x + threadIdx.x;
+    int filter = blockIdx.y;
+    int batch = blockIdx.z;
+
+    if(offset < size) output[(batch*n+filter)*size + offset] += biases[filter];
+}
+
+void add_bias_gpu(float *output, float *biases, int batch, int n, int size)
+{
+    dim3 dimGrid((size-1)/BLOCK + 1, n, batch);
+    dim3 dimBlock(BLOCK, 1, 1);
+
+    add_bias_kernel<<<dimGrid, dimBlock, 0, get_cuda_stream()>>>(output, biases, n, size);
+    check_error(cudaPeekAtLastError());
+}
+
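+// Gradient of the per-filter biases: block-wide reduction of delta over batch and spatial positions.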
+__global__ void backward_bias_kernel(float *bias_updates, float *delta, int batch, int n, int size)
+{
+    __shared__ float part[BLOCK];
+    int i,b;
+    int filter = blockIdx.x;
+    int p = threadIdx.x;
+    float sum = 0;
+    for(b = 0; b < batch; ++b){
+        for(i = 0; i < size; i += BLOCK){
+            int index = p + i + size*(filter + n*b);
+            sum += (p+i < size) ? delta[index] : 0;
+        }
+    }
+    part[p] = sum;
+    __syncthreads();
+    if (p == 0) {
+        for(i = 0; i < BLOCK; ++i) bias_updates[filter] += part[i];
+    }
+}
+
+/*
+__global__ void dot_kernel(float *output, float scale, int batch, int n, int size, float *delta)
+{
+    int index = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+    int f1 = index / n;
+    int f2 = index % n;
+    if (f2 <= f1) return;
+    
+    float sum = 0;
+    float norm1 = 0;
+    float norm2 = 0;
+    int b, i;
+    for(b = 0; b <  batch; ++b){
+        for(i = 0; i < size; ++i){
+            int i1 = b * size * n + f1 * size + i;
+            int i2 = b * size * n + f2 * size + i;
+            sum += output[i1] * output[i2];
+            norm1 += output[i1] * output[i1];
+            norm2 += output[i2] * output[i2];
+        }
+    }
+    norm1 = sqrt(norm1);
+    norm2 = sqrt(norm2);
+    float norm = norm1 * norm2;
+    sum = sum / norm;
+    for(b = 0; b <  batch; ++b){
+        for(i = 0; i < size; ++i){
+            int i1 = b * size * n + f1 * size + i;
+            int i2 = b * size * n + f2 * size + i;
+            delta[i1] += - scale * sum * output[i2] / norm;
+            delta[i2] += - scale * sum * output[i1] / norm;
+        }
+    }
+}
+
+void dot_error_gpu(layer l)
+{
+    dot_kernel<<<cuda_gridsize(l.n*l.n), BLOCK>>>(l.output_gpu, l.dot, l.batch, l.n, l.out_w * l.out_h, l.delta_gpu);
+    check_error(cudaPeekAtLastError());
+}
+*/
+
+void backward_bias_gpu(float *bias_updates, float *delta, int batch, int n, int size)
+{
+    backward_bias_kernel<<<n, BLOCK>>>(bias_updates, delta, batch, n, size);
+    check_error(cudaPeekAtLastError());
+}
+
+
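+// One bias-corrected Adam step per element: x -= rate * sqrt(1-B2^t)/(1-B1^t) * m / (sqrt(v) + eps).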
+__global__ void adam_kernel(int N, float *x, float *m, float *v, float B1, float B2, float rate, float eps, int t)
+{
+    int index = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+    if (index >= N) return;
+    
+    x[index] = x[index] - (rate * sqrtf(1.F-powf(B2, t)) / (1.F-powf(B1, t)) * m[index] / (sqrtf(v[index]) + eps));
+    //if(index == 0) printf("%f %f %f %f\n", m[index], v[index], (rate * sqrtf(1.F-powf(B2, t)) / (1.F-powf(B1, t)) * m[index] / (sqrt(v[index]) + eps)));
+}
+
+extern "C" void adam_gpu(int n, float *x, float *m, float *v, float B1, float B2, float rate, float eps, int t)
+{
+    adam_kernel<<<cuda_gridsize(n), BLOCK>>>(n, x, m, v, B1, B2, rate, eps, t);
+    check_error(cudaPeekAtLastError());
+}
+
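+// Full Adam update: decay the moments m and v, fold in the weight-decayed gradient d and its square, apply the step to w, then zero d.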
+extern "C" void adam_update_gpu(float *w, float *d, float *m, float *v, float B1, float B2, float eps, float decay, float rate, int n, int batch, int t)
+{
+    scal_ongpu(n, B1, m, 1);
+    scal_ongpu(n, B2, v, 1);
+    axpy_ongpu(n, -decay*batch, w, 1, d, 1);
+
+    axpy_ongpu(n, (1 - B1), d, 1, m, 1);
+    mul_ongpu(n, d, 1, d, 1);
+    axpy_ongpu(n, (1 - B2), d, 1, v, 1);
+
+    adam_gpu(n, w, m, v, B1, B2, rate, eps, t);
+    fill_ongpu(n, 0, d, 1);
+}
+
 __global__ void normalize_kernel(int N, float *x, float *mean, float *variance, int batch, int filters, int spatial)
 {
     int index = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
     if (index >= N) return;
     int f = (index/spatial)%filters;
     
-    x[index] = (x[index] - mean[f])/(sqrt(variance[f]) + .00001f);
+    x[index] = (x[index] - mean[f])/(sqrtf(variance[f]) + .000001f);
 }
 
 __global__ void normalize_delta_kernel(int N, float *x, float *mean, float *variance, float *mean_delta, float *variance_delta, int batch, int filters, int spatial, float *delta)
@@ -23,7 +184,7 @@
     if (index >= N) return;
     int f = (index/spatial)%filters;
     
-    delta[index] = delta[index] * 1./(sqrt(variance[f]) + .00001f) + variance_delta[f] * 2. * (x[index] - mean[f]) / (spatial * batch) + mean_delta[f]/(spatial*batch);
+    delta[index] = delta[index] * 1.F/(sqrtf(variance[f]) + .000001f) + variance_delta[f] * 2. * (x[index] - mean[f]) / (spatial * batch) + mean_delta[f]/(spatial*batch);
 }
 
 extern "C" void normalize_delta_gpu(float *x, float *mean, float *variance, float *mean_delta, float *variance_delta, int batch, int filters, int spatial, float *delta)
@@ -45,7 +206,7 @@
             variance_delta[i] += delta[index]*(x[index] - mean[i]);
         }
     }
-    variance_delta[i] *= -.5 * pow(variance[i] + .00001f, (float)(-3./2.));
+    variance_delta[i] *= -.5 * powf(variance[i] + .000001f, (float)(-3./2.));
 }
 
 __global__ void accumulate_kernel(float *x, int n, int groups, float *sum)
@@ -76,13 +237,14 @@
             local[id] += (i+id < spatial) ? delta[index] : 0;
         }
     }
+    __syncthreads();
 
     if(id == 0){
         mean_delta[filter] = 0;
         for(i = 0; i < threads; ++i){
             mean_delta[filter] += local[i];
         }
-        mean_delta[filter] *= (-1./sqrt(variance[filter] + .00001f));
+        mean_delta[filter] *= (-1.F/sqrtf(variance[filter] + .000001f));
     }
 }
 
@@ -104,13 +266,14 @@
             local[id] += (i+id < spatial) ? delta[index]*(x[index] - mean[filter]) : 0;
         }
     }
+    __syncthreads();
 
     if(id == 0){
         variance_delta[filter] = 0;
         for(i = 0; i < threads; ++i){
             variance_delta[filter] += local[i];
         }
-        variance_delta[filter] *= -.5 * pow(variance[filter] + .00001f, (float)(-3./2.));
+        variance_delta[filter] *= -.5 * powf(variance[filter] + .000001f, (float)(-3./2.));
     }
 }
 
@@ -127,7 +290,7 @@
             mean_delta[i] += delta[index];
         }
     }
-    mean_delta[i] *= (-1./sqrt(variance[i] + .00001f));
+    mean_delta[i] *= (-1.F/sqrtf(variance[i] + .000001f));
 }
 
 extern "C" void mean_delta_gpu(float *delta, float *variance, int batch, int filters, int spatial, float *mean_delta)
@@ -150,7 +313,7 @@
 
 __global__ void  mean_kernel(float *x, int batch, int filters, int spatial, float *mean)
 {
-    float scale = 1./(batch * spatial);
+    float scale = 1.F/(batch * spatial);
     int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
     if (i >= filters) return;
     int j,k;
@@ -166,7 +329,7 @@
 
 __global__ void variance_kernel(float *x, float *mean, int batch, int filters, int spatial, float *variance)
 {
-    float scale = 1./(batch * spatial);
+    float scale = 1.F/(batch * spatial - 1);
     int j,k;
     int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
     if (i >= filters) return;
@@ -174,12 +337,44 @@
     for(j = 0; j < batch; ++j){
         for(k = 0; k < spatial; ++k){
             int index = j*filters*spatial + i*spatial + k;
-            variance[i] += pow((x[index] - mean[i]), 2);
+            variance[i] += powf((x[index] - mean[i]), 2);
         }
     }
     variance[i] *= scale;
 }
 
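+// Reorg (space-to-depth) index mapping between a w x h x c tensor and its (w*stride) x (h*stride) x (c/stride^2) counterpart.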
+__global__ void reorg_kernel(int N, float *x, int w, int h, int c, int batch, int stride, int forward, float *out)
+{
+    int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+    if(i >= N) return;
+    int in_index = i;
+    int in_w = i%w;
+    i = i/w;
+    int in_h = i%h;
+    i = i/h;
+    int in_c = i%c;
+    i = i/c;
+    int b = i%batch;
+
+    int out_c = c/(stride*stride);
+
+    int c2 = in_c % out_c;
+    int offset = in_c / out_c;
+    int w2 = in_w*stride + offset % stride;
+    int h2 = in_h*stride + offset / stride;
+    //printf("%d\n", offset);
+    int out_index = w2 + w*stride*(h2 + h*stride*(c2 + out_c*b));
+
+   // printf("%d %d %d\n", w2, h2, c2);
+    //printf("%d %d\n", in_index, out_index);
+    //if(out_index >= N || out_index < 0) printf("bad bad bad \n");
+
+    if(forward) out[out_index] = x[in_index];
+    else out[in_index] = x[out_index];
+    //if(forward) out[1] = x[1];
+    //else out[0] = x[0];
+}
+
 __global__ void axpy_kernel(int N, float ALPHA, float *X, int OFFX, int INCX,  float *Y, int OFFY, int INCY)
 {
     int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
@@ -189,7 +384,7 @@
 __global__ void pow_kernel(int N, float ALPHA, float *X, int INCX, float *Y, int INCY)
 {
     int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
-    if(i < N) Y[i*INCY] = pow(X[i*INCX], ALPHA);
+    if(i < N) Y[i*INCY] = powf(X[i*INCX], ALPHA);
 }
 
 __global__ void const_kernel(int N, float ALPHA, float *X, int INCX)
@@ -198,6 +393,20 @@
     if(i < N) X[i*INCX] = ALPHA;
 }
 
+__global__ void constrain_kernel(int N, float ALPHA, float *X, int INCX)
+{
+    int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+    if(i < N) X[i*INCX] = fminf(ALPHA, fmaxf(-ALPHA, X[i*INCX]));
+}
+
+__global__ void supp_kernel(int N, float ALPHA, float *X, int INCX)
+{
+    int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+    if(i < N) {
+        if((X[i*INCX] * X[i*INCX]) < (ALPHA * ALPHA)) X[i*INCX] = 0;
+    }
+}
+
 __global__ void scal_kernel(int N, float ALPHA, float *X, int INCX)
 {
     int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
@@ -232,7 +441,7 @@
 extern "C" void normalize_gpu(float *x, float *mean, float *variance, int batch, int filters, int spatial)
 {
     size_t N = batch*filters*spatial;
-    normalize_kernel<<<cuda_gridsize(N), BLOCK>>>(N, x, mean, variance, batch, filters, spatial);
+    normalize_kernel<<<cuda_gridsize(N), BLOCK, 0, get_cuda_stream()>>>(N, x, mean, variance, batch, filters, spatial);
     check_error(cudaPeekAtLastError());
 }
 
@@ -253,6 +462,7 @@
             local[id] += (i+id < spatial) ? x[index] : 0;
         }
     }
+    __syncthreads();
 
     if(id == 0){
         mean[filter] = 0;
@@ -278,28 +488,29 @@
         for(i = 0; i < spatial; i += threads){
             int index = j*spatial*filters + filter*spatial + i + id;
 
-            local[id] += (i+id < spatial) ? pow((x[index] - mean[filter]), 2) : 0;
+            local[id] += (i+id < spatial) ? powf((x[index] - mean[filter]), 2) : 0;
         }
     }
+    __syncthreads();
 
     if(id == 0){
         variance[filter] = 0;
         for(i = 0; i < threads; ++i){
             variance[filter] += local[i];
         }
-        variance[filter] /= spatial * batch;
+        variance[filter] /= (spatial * batch - 1);
     }
 }
 
 extern "C" void fast_mean_gpu(float *x, int batch, int filters, int spatial, float *mean)
 {
-    fast_mean_kernel<<<filters, BLOCK>>>(x, batch, filters, spatial, mean);
+    fast_mean_kernel<<<filters, BLOCK, 0, get_cuda_stream()>>>(x, batch, filters, spatial, mean);
     check_error(cudaPeekAtLastError());
 }
 
 extern "C" void fast_variance_gpu(float *x, float *mean, int batch, int filters, int spatial, float *variance)
 {
-    fast_variance_kernel<<<filters, BLOCK>>>(x, mean, batch, filters, spatial, variance);
+    fast_variance_kernel<<<filters, BLOCK, 0, get_cuda_stream() >>>(x, mean, batch, filters, spatial, variance);
     check_error(cudaPeekAtLastError());
 }
 
@@ -323,13 +534,13 @@
 
 extern "C" void pow_ongpu(int N, float ALPHA, float * X, int INCX, float * Y, int INCY)
 {
-    pow_kernel<<<cuda_gridsize(N), BLOCK>>>(N, ALPHA, X, INCX, Y, INCY);
+    pow_kernel<<<cuda_gridsize(N), BLOCK, 0, get_cuda_stream() >>>(N, ALPHA, X, INCX, Y, INCY);
     check_error(cudaPeekAtLastError());
 }
 
 extern "C" void axpy_ongpu_offset(int N, float ALPHA, float * X, int OFFX, int INCX, float * Y, int OFFY, int INCY)
 {
-    axpy_kernel<<<cuda_gridsize(N), BLOCK>>>(N, ALPHA, X, OFFX, INCX, Y, OFFY, INCY);
+    axpy_kernel<<<cuda_gridsize(N), BLOCK, 0, get_cuda_stream()>>>(N, ALPHA, X, OFFX, INCX, Y, OFFY, INCY);
     check_error(cudaPeekAtLastError());
 }
 
@@ -346,13 +557,44 @@
 
 extern "C" void copy_ongpu_offset(int N, float * X, int OFFX, int INCX, float * Y, int OFFY, int INCY)
 {
-    copy_kernel<<<cuda_gridsize(N), BLOCK>>>(N, X, OFFX, INCX, Y, OFFY, INCY);
+    copy_kernel<<<cuda_gridsize(N), BLOCK, 0, get_cuda_stream()>>>(N, X, OFFX, INCX, Y, OFFY, INCY);
+    check_error(cudaPeekAtLastError());
+}
+
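+// Flatten: transposes each batch item between channel-major (layers x spatial) and interleaved (spatial x layers) layouts.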
+__global__ void flatten_kernel(int N, float *x, int spatial, int layers, int batch, int forward, float *out)
+{
+    int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+    if(i >= N) return;
+    int in_s = i%spatial;
+    i = i/spatial;
+    int in_c = i%layers;
+    i = i/layers;
+    int b = i;
+
+    int i1 = b*layers*spatial + in_c*spatial + in_s;
+    int i2 = b*layers*spatial + in_s*layers +  in_c;
+
+    if (forward) out[i2] = x[i1];
+    else out[i1] = x[i2];
+}
+
+extern "C" void flatten_ongpu(float *x, int spatial, int layers, int batch, int forward, float *out)
+{
+    int size = spatial*batch*layers;
+    flatten_kernel<<<cuda_gridsize(size), BLOCK, 0, get_cuda_stream()>>>(size, x, spatial, layers, batch, forward, out);
+    check_error(cudaPeekAtLastError());
+}
+
+extern "C" void reorg_ongpu(float *x, int w, int h, int c, int batch, int stride, int forward, float *out)
+{
+    int size = w*h*c*batch;
+    reorg_kernel<<<cuda_gridsize(size), BLOCK, 0, get_cuda_stream()>>>(size, x, w, h, c, batch, stride, forward, out);
     check_error(cudaPeekAtLastError());
 }
 
 extern "C" void mask_ongpu(int N, float * X, float mask_num, float * mask)
 {
-    mask_kernel<<<cuda_gridsize(N), BLOCK>>>(N, X, mask_num, mask);
+    mask_kernel<<<cuda_gridsize(N), BLOCK, 0, get_cuda_stream() >>>(N, X, mask_num, mask);
     check_error(cudaPeekAtLastError());
 }
 
@@ -362,38 +604,214 @@
     check_error(cudaPeekAtLastError());
 }
 
+extern "C" void constrain_ongpu(int N, float ALPHA, float * X, int INCX)
+{
+    constrain_kernel<<<cuda_gridsize(N), BLOCK>>>(N, ALPHA, X, INCX);
+    check_error(cudaPeekAtLastError());
+}
+
+
 extern "C" void scal_ongpu(int N, float ALPHA, float * X, int INCX)
 {
-    scal_kernel<<<cuda_gridsize(N), BLOCK>>>(N, ALPHA, X, INCX);
+    scal_kernel<<<cuda_gridsize(N), BLOCK, 0, get_cuda_stream()>>>(N, ALPHA, X, INCX);
+    check_error(cudaPeekAtLastError());
+}
+
+extern "C" void supp_ongpu(int N, float ALPHA, float * X, int INCX)
+{
+    supp_kernel<<<cuda_gridsize(N), BLOCK>>>(N, ALPHA, X, INCX);
     check_error(cudaPeekAtLastError());
 }
 
 extern "C" void fill_ongpu(int N, float ALPHA, float * X, int INCX)
 {
-    fill_kernel<<<cuda_gridsize(N), BLOCK>>>(N, ALPHA, X, INCX);
+    fill_kernel<<<cuda_gridsize(N), BLOCK, 0, get_cuda_stream()>>>(N, ALPHA, X, INCX);
     check_error(cudaPeekAtLastError());
 }
 
-__global__ void shortcut_kernel(int size, float *out, int w, int h, int c, int batch, int sample, float *add, int stride, int c2, int min_c)
+__global__ void shortcut_kernel(int size, int minw, int minh, int minc, int stride, int sample, int batch, int w1, int h1, int c1, float *add, int w2, int h2, int c2, float *out)
 {
     int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
     if (id >= size) return;
-    int i = id % (w/sample);
-    id /= (w/sample);
-    int j = id % (h/sample);
-    id /= (h/sample);
-    int k = id % min_c;
-    id /= min_c;
-    int b = id;
-    int out_index = i*sample + w*(j*sample + h*(k + c*b));
-    int add_index = b*w*stride/sample*h*stride/sample*c2 + i*stride + w*stride/sample*(j*stride + h*stride/sample*k);
+    int i = id % minw;
+    id /= minw;
+    int j = id % minh;
+    id /= minh;
+    int k = id % minc;
+    id /= minc;
+    int b = id % batch;
+
+    int out_index = i*sample + w2*(j*sample + h2*(k + c2*b));
+    int add_index = i*stride + w1*(j*stride + h1*(k + c1*b));
     out[out_index] += add[add_index];
 }
 
-extern "C" void shortcut_gpu(float *out, int w, int h, int c, int batch, int sample, float *add, int stride, int c2)
+extern "C" void shortcut_gpu(int batch, int w1, int h1, int c1, float *add, int w2, int h2, int c2, float *out)
 {
-    int min_c = (c < c2) ? c : c2;
-    int size = batch * w/sample * h/sample * min_c;
-    shortcut_kernel<<<cuda_gridsize(size), BLOCK>>>(size, out, w, h, c, batch, sample, add, stride, c2, min_c);
+    int minw = (w1 < w2) ? w1 : w2;
+    int minh = (h1 < h2) ? h1 : h2;
+    int minc = (c1 < c2) ? c1 : c2;
+
+    int stride = w1/w2;
+    int sample = w2/w1;
+    assert(stride == h1/h2);
+    assert(sample == h2/h1);
+    if(stride < 1) stride = 1;
+    if(sample < 1) sample = 1;
+
+    int size = batch * minw * minh * minc;
+    shortcut_kernel<<<cuda_gridsize(size), BLOCK, 0, get_cuda_stream()>>>(size, minw, minh, minc, stride, sample, batch, w1, h1, c1, add, w2, h2, c2, out);
     check_error(cudaPeekAtLastError());
 }
+
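+// Smooth L1 loss per element: quadratic for |truth - pred| < 1, linear beyond; delta holds the corresponding error term used for backprop.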
+__global__ void smooth_l1_kernel(int n, float *pred, float *truth, float *delta, float *error)
+{
+    int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+    if(i < n){
+        float diff = truth[i] - pred[i];
+        float abs_val = abs(diff);
+        if(abs_val < 1) {
+            error[i] = diff * diff;
+            delta[i] = diff;
+        }
+        else {
+            error[i] = 2*abs_val - 1;
+            delta[i] = (diff < 0) ? -1 : 1;
+        }
+    }
+}
+
+extern "C" void smooth_l1_gpu(int n, float *pred, float *truth, float *delta, float *error)
+{
+    smooth_l1_kernel<<<cuda_gridsize(n), BLOCK>>>(n, pred, truth, delta, error);
+    check_error(cudaPeekAtLastError());
+}
+
+__global__ void l2_kernel(int n, float *pred, float *truth, float *delta, float *error)
+{
+    int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+    if(i < n){
+        float diff = truth[i] - pred[i];
+        error[i] = diff * diff; //I know this is technically wrong, deal with it.
+        delta[i] = diff;
+    }
+}
+
+extern "C" void l2_gpu(int n, float *pred, float *truth, float *delta, float *error)
+{
+    l2_kernel<<<cuda_gridsize(n), BLOCK>>>(n, pred, truth, delta, error);
+    check_error(cudaPeekAtLastError());
+}
+
+
+
+__global__ void weighted_sum_kernel(int n, float *a, float *b, float *s, float *c)
+{
+    int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+    if(i < n){
+        c[i] = s[i]*a[i] + (1-s[i])*(b ? b[i] : 0);
+    }
+}
+
+extern "C" void weighted_sum_gpu(float *a, float *b, float *s, int num, float *c)
+{
+    weighted_sum_kernel<<<cuda_gridsize(num), BLOCK>>>(num, a, b, s, c);
+    check_error(cudaPeekAtLastError());
+}
+
+__global__ void weighted_delta_kernel(int n, float *a, float *b, float *s, float *da, float *db, float *ds, float *dc)
+{
+    int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+    if(i < n){
+        if(da) da[i] += dc[i] * s[i];
+        db[i] += dc[i] * (1-s[i]);
+        ds[i] += dc[i] * a[i] + dc[i] * -b[i];
+    }
+}
+
+extern "C" void weighted_delta_gpu(float *a, float *b, float *s, float *da, float *db, float *ds, int num, float *dc)
+{
+    weighted_delta_kernel<<<cuda_gridsize(num), BLOCK>>>(num, a, b, s, da, db, ds, dc);
+    check_error(cudaPeekAtLastError());
+}
+
+__global__ void mult_add_into_kernel(int n, float *a, float *b, float *c)
+{
+    int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+    if(i < n){
+        c[i] += a[i]*b[i];
+    }
+}
+
+extern "C" void mult_add_into_gpu(int num, float *a, float *b, float *c)
+{
+    mult_add_into_kernel<<<cuda_gridsize(num), BLOCK>>>(num, a, b, c);
+    check_error(cudaPeekAtLastError());
+}
+
+
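+// Softmax over n inputs with temperature temp, subtracting the max for numerical stability.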
+__device__ void softmax_device(int n, float *input, float temp, float *output)
+{
+    int i;
+    float sum = 0;
+    float largest = -INFINITY;
+    for(i = 0; i < n; ++i){
+        float val = input[i];
+        largest = (val>largest) ? val : largest;
+    }
+    for(i = 0; i < n; ++i){
+        float e = exp(input[i]/temp - largest/temp);
+        sum += e;
+        output[i] = e;
+    }
+    for(i = 0; i < n; ++i){
+        output[i] /= sum;
+    }
+}
+
+__global__ void softmax_kernel(int n, int offset, int batch, float *input, float temp, float *output)
+{
+    int b = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+    if(b >= batch) return;
+    softmax_device(n, input + b*offset, temp, output + b*offset);
+}
+
+extern "C" void softmax_gpu(float *input, int n, int offset, int groups, float temp, float *output)
+{
+    int inputs = n;
+    int batch = groups;
+    softmax_kernel<<<cuda_gridsize(batch), BLOCK, 0, get_cuda_stream()>>>(inputs, offset, batch, input, temp, output);
+    check_error(cudaPeekAtLastError());
+}
+
+
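+// Nearest-neighbor upsampling by an integer stride: forward accumulates scale*x into out, backward atomically adds gradients back into x.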
+__global__ void upsample_kernel(size_t N, float *x, int w, int h, int c, int batch, int stride, int forward, float scale, float *out)
+{
+    size_t i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+    if (i >= N) return;
+    int out_index = i;
+    int out_w = i % (w*stride);
+    i = i / (w*stride);
+    int out_h = i % (h*stride);
+    i = i / (h*stride);
+    int out_c = i % c;
+    i = i / c;
+    int b = i % batch;
+
+    int in_w = out_w / stride;
+    int in_h = out_h / stride;
+    int in_c = out_c;
+
+    int in_index = b*w*h*c + in_c*w*h + in_h*w + in_w;
+
+    if (forward) out[out_index] += scale * x[in_index];
+    else atomicAdd(x + in_index, scale * out[out_index]);
+}
+
+extern "C" void upsample_gpu(float *in, int w, int h, int c, int batch, int stride, int forward, float scale, float *out)
+{
+    size_t size = w*h*c*batch*stride*stride;
+    upsample_kernel<<<cuda_gridsize(size), BLOCK>>>(size, in, w, h, c, batch, stride, forward, scale, out);
+    check_error(cudaPeekAtLastError());
+}
\ No newline at end of file

--
Gitblit v1.10.0