From a392bbd0c957a00e3782c96e7ced84a29ff9dd88 Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Tue, 15 Mar 2016 05:33:02 +0000
Subject: [PATCH] Play along w/ alphago

---
 src/blas_kernels.cu |  291 +++++++++++++++++++++++++++++++++++++++++----------------
 1 file changed, 208 insertions(+), 83 deletions(-)

diff --git a/src/blas_kernels.cu b/src/blas_kernels.cu
index 4da31d1..98366f8 100644
--- a/src/blas_kernels.cu
+++ b/src/blas_kernels.cu
@@ -1,6 +1,7 @@
 #include "cuda_runtime.h"
 #include "curand.h"
 #include "cublas_v2.h"
+#include <assert.h>
 
 extern "C" {
 #include "blas.h"
@@ -14,7 +15,7 @@
     if (index >= N) return;
     int f = (index/spatial)%filters;
     
-    x[index] = (x[index] - mean[f])/(sqrt(variance[f]) + .00001f);
+    x[index] = (x[index] - mean[f])/(sqrt(variance[f]) + .000001f);
 }
 
 __global__ void normalize_delta_kernel(int N, float *x, float *mean, float *variance, float *mean_delta, float *variance_delta, int batch, int filters, int spatial, float *delta)
@@ -23,7 +24,7 @@
     if (index >= N) return;
     int f = (index/spatial)%filters;
     
-    delta[index] = delta[index] * 1./(sqrt(variance[f]) + .00001f) + variance_delta[f] * 2. * (x[index] - mean[f]) / (spatial * batch) + mean_delta[f]/(spatial*batch);
+    delta[index] = delta[index] * 1./(sqrt(variance[f]) + .000001f) + variance_delta[f] * 2. * (x[index] - mean[f]) / (spatial * batch) + mean_delta[f]/(spatial*batch);
 }
 
 extern "C" void normalize_delta_gpu(float *x, float *mean, float *variance, float *mean_delta, float *variance_delta, int batch, int filters, int spatial, float *delta)
@@ -45,29 +46,7 @@
             variance_delta[i] += delta[index]*(x[index] - mean[i]);
         }
     }
-    variance_delta[i] *= -.5 * pow(variance[i] + .00001f, (float)(-3./2.));
-}
-
-__global__ void spatial_variance_delta_kernel(float *x, float *delta, float *mean, float *variance, int batch, int filters, int spatial, float *spatial_variance_delta)
-{
-    int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
-    if (i >= batch*filters) return;
-    int f = i%filters;
-    int b = i/filters;
-
-    int k;
-    spatial_variance_delta[i] = 0;
-    for (k = 0; k < spatial; ++k) {
-        int index = b*filters*spatial + f*spatial + k;
-        spatial_variance_delta[i] += delta[index]*(x[index] - mean[f]);
-    }
-    spatial_variance_delta[i] *= -.5 * pow(variance[f] + .00001f, (float)(-3./2.));
-}
-
-extern "C" void variance_delta_gpu(float *x, float *delta, float *mean, float *variance, int batch, int filters, int spatial, float *variance_delta)
-{
-    variance_delta_kernel<<<cuda_gridsize(filters), BLOCK>>>(x, delta, mean, variance, batch, filters, spatial, variance_delta);
-    check_error(cudaPeekAtLastError());
+    variance_delta[i] *= -.5 * pow(variance[i] + .000001f, (float)(-3./2.));
 }
 
 __global__ void accumulate_kernel(float *x, int n, int groups, float *sum)
@@ -81,38 +60,62 @@
     }
 }
 
-extern "C" void fast_variance_delta_gpu(float *x, float *delta, float *mean, float *variance, int batch, int filters, int spatial, float *spatial_variance_delta, float *variance_delta)
+__global__ void fast_mean_delta_kernel(float *delta, float *variance, int batch, int filters, int spatial, float *mean_delta)
 {
-    spatial_variance_delta_kernel<<<cuda_gridsize(filters*batch), BLOCK>>>(x, delta, mean, variance, batch, filters, spatial, spatial_variance_delta);
-    check_error(cudaPeekAtLastError());
-    accumulate_kernel<<<cuda_gridsize(filters), BLOCK>>>(spatial_variance_delta, batch, filters, variance_delta);
-    check_error(cudaPeekAtLastError());
-}
+    const int threads = BLOCK;
+    __shared__ float local[threads];
 
-__global__ void spatial_mean_delta_kernel(float *delta, float *variance, int batch, int filters, int spatial, float *spatial_mean_delta)
-{
-    int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
-    if (i >= batch*filters) return;
-    int f = i%filters;
-    int b = i/filters;
+    int id = threadIdx.x;
+    local[id] = 0;
 
-    int k;
-    spatial_mean_delta[i] = 0;
-    for (k = 0; k < spatial; ++k) {
-        int index = b*filters*spatial + f*spatial + k;
-        spatial_mean_delta[i] += delta[index];
+    int filter = blockIdx.x;
+
+    int i, j;
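+    // each thread accumulates a strided partial sum of delta for this block's filter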
+    for(j = 0; j < batch; ++j){
+        for(i = 0; i < spatial; i += threads){
+            int index = j*spatial*filters + filter*spatial + i + id;
+            local[id] += (i+id < spatial) ? delta[index] : 0;
+        }
     }
-    spatial_mean_delta[i] *= (-1./sqrt(variance[f] + .00001f));
+
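+    // thread 0 folds the per-thread partials together and scales by -1/sqrt(variance + eps)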
+    if(id == 0){
+        mean_delta[filter] = 0;
+        for(i = 0; i < threads; ++i){
+            mean_delta[filter] += local[i];
+        }
+        mean_delta[filter] *= (-1./sqrt(variance[filter] + .000001f));
+    }
 }
 
-extern "C" void fast_mean_delta_gpu(float *delta, float *variance, int batch, int filters, int spatial, float *spatial_mean_delta, float *mean_delta)
+__global__ void  fast_variance_delta_kernel(float *x, float *delta, float *mean, float *variance, int batch, int filters, int spatial, float *variance_delta)
 {
-    spatial_mean_delta_kernel<<<cuda_gridsize(filters*batch), BLOCK>>>(delta, variance, batch, filters, spatial, spatial_mean_delta);
-    check_error(cudaPeekAtLastError());
-    accumulate_kernel<<<cuda_gridsize(filters), BLOCK>>>(spatial_mean_delta, batch, filters, mean_delta);
-    check_error(cudaPeekAtLastError());
+    const int threads = BLOCK;
+    __shared__ float local[threads];
+
+    int id = threadIdx.x;
+    local[id] = 0;
+
+    int filter = blockIdx.x;
+
+    int i, j;
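+    // per-thread partial sums of delta * (x - mean) for this filter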
+    for(j = 0; j < batch; ++j){
+        for(i = 0; i < spatial; i += threads){
+            int index = j*spatial*filters + filter*spatial + i + id;
+
+            local[id] += (i+id < spatial) ? delta[index]*(x[index] - mean[filter]) : 0;
+        }
+    }
+
+    if(id == 0){
+        variance_delta[filter] = 0;
+        for(i = 0; i < threads; ++i){
+            variance_delta[filter] += local[i];
+        }
+        variance_delta[filter] *= -.5 * pow(variance[filter] + .000001f, (float)(-3./2.));
+    }
 }
 
+
 __global__ void mean_delta_kernel(float *delta, float *variance, int batch, int filters, int spatial, float *mean_delta)
 {
     int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
@@ -125,7 +128,7 @@
             mean_delta[i] += delta[index];
         }
     }
-    mean_delta[i] *= (-1./sqrt(variance[i] + .00001f));
+    mean_delta[i] *= (-1./sqrt(variance[i] + .000001f));
 }
 
 extern "C" void mean_delta_gpu(float *delta, float *variance, int batch, int filters, int spatial, float *mean_delta)
@@ -134,6 +137,18 @@
     check_error(cudaPeekAtLastError());
 }
 
+extern "C" void fast_mean_delta_gpu(float *delta, float *variance, int batch, int filters, int spatial, float *mean_delta)
+{
+    fast_mean_delta_kernel<<<filters, BLOCK>>>(delta, variance, batch, filters, spatial, mean_delta);
+    check_error(cudaPeekAtLastError());
+}
+
+extern "C" void fast_variance_delta_gpu(float *x, float *delta, float *mean, float *variance, int batch, int filters, int spatial, float *variance_delta)
+{
+    fast_variance_delta_kernel<<<filters, BLOCK>>>(x, delta, mean, variance, batch, filters, spatial, variance_delta);
+    check_error(cudaPeekAtLastError());
+}
+
 __global__ void  mean_kernel(float *x, int batch, int filters, int spatial, float *mean)
 {
     float scale = 1./(batch * spatial);
@@ -150,26 +165,9 @@
     mean[i] *= scale;
 }
 
-__global__ void spatial_variance_kernel(float *x, float *mean, int batch, int filters, int spatial, float *variance)
-{
-    float scale = 1./(spatial*batch-1);
-    int k;
-    int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
-    if (i >= batch*filters) return;
-    int f = i%filters;
-    int b = i/filters;
-
-    variance[i] = 0;
-    for(k = 0; k < spatial; ++k){
-        int index = b*filters*spatial + f*spatial + k;
-        variance[i] += pow((x[index] - mean[f]), 2);
-    }
-    variance[i] *= scale;
-}
-
 __global__ void variance_kernel(float *x, float *mean, int batch, int filters, int spatial, float *variance)
 {
-    float scale = 1./(batch * spatial);
+    float scale = 1./(batch * spatial - 1);
     int j,k;
     int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
     if (i >= filters) return;
@@ -231,6 +229,7 @@
     if(i < N) Y[i*INCY] *= X[i*INCX];
 }
 
+
 extern "C" void normalize_gpu(float *x, float *mean, float *variance, int batch, int filters, int spatial)
 {
     size_t N = batch*filters*spatial;
@@ -238,28 +237,80 @@
     check_error(cudaPeekAtLastError());
 }
 
+__global__ void  fast_mean_kernel(float *x, int batch, int filters, int spatial, float *mean)
+{
+    const int threads = BLOCK;
+    __shared__ float local[threads];
+
+    int id = threadIdx.x;
+    local[id] = 0;
+
+    int filter = blockIdx.x;
+
+    int i, j;
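+    // one block per filter: threads stride across batch*spatial and accumulate partial sums in shared memory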
+    for(j = 0; j < batch; ++j){
+        for(i = 0; i < spatial; i += threads){
+            int index = j*spatial*filters + filter*spatial + i + id;
+            local[id] += (i+id < spatial) ? x[index] : 0;
+        }
+    }
+
+    if(id == 0){
+        mean[filter] = 0;
+        for(i = 0; i < threads; ++i){
+            mean[filter] += local[i];
+        }
+        mean[filter] /= spatial * batch;
+    }
+}
+
+__global__ void  fast_variance_kernel(float *x, float *mean, int batch, int filters, int spatial, float *variance)
+{
+    const int threads = BLOCK;
+    __shared__ float local[threads];
+
+    int id = threadIdx.x;
+    local[id] = 0;
+
+    int filter = blockIdx.x;
+
+    int i, j;
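+    // partial sums of squared deviations from the filter mean; normalized by (spatial*batch - 1) below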
+    for(j = 0; j < batch; ++j){
+        for(i = 0; i < spatial; i += threads){
+            int index = j*spatial*filters + filter*spatial + i + id;
+
+            local[id] += (i+id < spatial) ? pow((x[index] - mean[filter]), 2) : 0;
+        }
+    }
+
+    if(id == 0){
+        variance[filter] = 0;
+        for(i = 0; i < threads; ++i){
+            variance[filter] += local[i];
+        }
+        variance[filter] /= (spatial * batch - 1);
+    }
+}
+
+extern "C" void fast_mean_gpu(float *x, int batch, int filters, int spatial, float *mean)
+{
+    fast_mean_kernel<<<filters, BLOCK>>>(x, batch, filters, spatial, mean);
+    check_error(cudaPeekAtLastError());
+}
+
+extern "C" void fast_variance_gpu(float *x, float *mean, int batch, int filters, int spatial, float *variance)
+{
+    fast_variance_kernel<<<filters, BLOCK>>>(x, mean, batch, filters, spatial, variance);
+    check_error(cudaPeekAtLastError());
+}
+
+
 extern "C" void mean_gpu(float *x, int batch, int filters, int spatial, float *mean)
 {
     mean_kernel<<<cuda_gridsize(filters), BLOCK>>>(x, batch, filters, spatial, mean);
     check_error(cudaPeekAtLastError());
 }
 
-extern "C" void fast_mean_gpu(float *x, int batch, int filters, int spatial, float *spatial_mean, float *mean)
-{
-    mean_kernel<<<cuda_gridsize(filters*batch), BLOCK>>>(x, 1, filters*batch, spatial, spatial_mean);
-    check_error(cudaPeekAtLastError());
-    mean_kernel<<<cuda_gridsize(filters), BLOCK>>>(spatial_mean, batch, filters, 1, mean);
-    check_error(cudaPeekAtLastError());
-}
-
-extern "C" void fast_variance_gpu(float *x, float *mean, int batch, int filters, int spatial, float *spatial_variance, float *variance)
-{
-    spatial_variance_kernel<<<cuda_gridsize(batch*filters), BLOCK>>>(x, mean, batch, filters, spatial, spatial_variance);
-    check_error(cudaPeekAtLastError());
-    accumulate_kernel<<<cuda_gridsize(filters), BLOCK>>>(spatial_variance, batch, filters, variance);
-    check_error(cudaPeekAtLastError());
-}
-
 extern "C" void variance_gpu(float *x, float *mean, int batch, int filters, int spatial, float *variance)
 {
     variance_kernel<<<cuda_gridsize(filters), BLOCK>>>(x, mean, batch, filters, spatial, variance);
@@ -323,3 +374,77 @@
     fill_kernel<<<cuda_gridsize(N), BLOCK>>>(N, ALPHA, X, INCX);
     check_error(cudaPeekAtLastError());
 }
+
+__global__ void shortcut_kernel(int size, int minw, int minh, int minc, int stride, int sample, int batch, int w1, int h1, int c1, float *add, int w2, int h2, int c2, float *out)
+{
+    int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+    if (id >= size) return;
+    int i = id % minw;
+    id /= minw;
+    int j = id % minh;
+    id /= minh;
+    int k = id % minc;
+    id /= minc;
+    int b = id % batch;
+
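+    // map the decomposed position into each tensor's own layout; sample/stride absorb the size difference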
+    int out_index = i*sample + w2*(j*sample + h2*(k + c2*b));
+    int add_index = i*stride + w1*(j*stride + h1*(k + c1*b));
+    out[out_index] += add[add_index];
+}
+
+extern "C" void shortcut_gpu(int batch, int w1, int h1, int c1, float *add, int w2, int h2, int c2, float *out)
+{
+    int minw = (w1 < w2) ? w1 : w2;
+    int minh = (h1 < h2) ? h1 : h2;
+    int minc = (c1 < c2) ? c1 : c2;
+
+    int stride = w1/w2;
+    int sample = w2/w1;
+    assert(stride == h1/h2);
+    assert(sample == h2/h1);
+    if(stride < 1) stride = 1;
+    if(sample < 1) sample = 1;
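+    // stride and sample are the integer size ratios between the two tensors (each at least 1)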
+
+    int size = batch * minw * minh * minc;
+    shortcut_kernel<<<cuda_gridsize(size), BLOCK>>>(size, minw, minh, minc, stride, sample, batch, w1, h1, c1, add, w2, h2, c2, out);
+    check_error(cudaPeekAtLastError());
+}
+
+__global__ void smooth_l1_kernel(int n, float *pred, float *truth, float *delta, float *error)
+{
+    int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+    if(i < n){
+        float diff = truth[i] - pred[i];
+        float abs_val = abs(diff);
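+        // quadratic penalty when |diff| < 1, linear beyond (smooth L1)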
+        if(abs_val < 1) {
+            error[i] = diff * diff;
+            delta[i] = diff;
+        }
+        else {
+            error[i] = 2*abs_val - 1;
+            delta[i] = (diff < 0) ? -1 : 1;
+        }
+    }
+}
+
+extern "C" void smooth_l1_gpu(int n, float *pred, float *truth, float *delta, float *error)
+{
+    smooth_l1_kernel<<<cuda_gridsize(n), BLOCK>>>(n, pred, truth, delta, error);
+    check_error(cudaPeekAtLastError());
+}
+
+__global__ void l2_kernel(int n, float *pred, float *truth, float *delta, float *error)
+{
+    int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+    if(i < n){
+        float diff = truth[i] - pred[i];
+        error[i] = diff * diff; //I know this is technically wrong, deal with it.
+        delta[i] = diff;
+    }
+}
+
+extern "C" void l2_gpu(int n, float *pred, float *truth, float *delta, float *error)
+{
+    l2_kernel<<<cuda_gridsize(n), BLOCK>>>(n, pred, truth, delta, error);
+    check_error(cudaPeekAtLastError());
+}
