From c7b10ceadb1a78e7480d281444a31ae2a7dc1b05 Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Fri, 06 May 2016 23:25:16 +0000
Subject: [PATCH] so much need to commit

Add per-filter bias/scale forward and backward kernels, an element-wise
constrain (clamp) kernel, and weighted_sum/weighted_delta/mult_add_into
kernels to the GPU BLAS helpers.

---
 src/blas_kernels.cu |  189 +++++++++++++++++++++++++++++++++++++++++++++++
 1 files changed, 189 insertions(+), 0 deletions(-)

diff --git a/src/blas_kernels.cu b/src/blas_kernels.cu
index 98366f8..ac537d8 100644
--- a/src/blas_kernels.cu
+++ b/src/blas_kernels.cu
@@ -9,6 +9,137 @@
 #include "utils.h"
 }
 
+/* scale_bias_kernel: multiply every spatial element of each (batch, filter)
+ * slice of `output` by that filter's scale in `biases` (the parameter holds
+ * per-filter scales here). One thread per spatial position via
+ * blockIdx.x/threadIdx.x; blockIdx.y is the filter, blockIdx.z is the batch.
+ * BLOCK is assumed to be the 1-D thread-block size defined in cuda.h. */
+__global__ void scale_bias_kernel(float *output, float *biases, int n, int size)
+{
+    int offset = blockIdx.x * blockDim.x + threadIdx.x;
+    int filter = blockIdx.y;
+    int batch = blockIdx.z;
+
+    if(offset < size) output[(batch*n+filter)*size + offset] *= biases[filter];
+}
+
+void scale_bias_gpu(float *output, float *biases, int batch, int n, int size)
+{
+    dim3 dimGrid((size-1)/BLOCK + 1, n, batch);
+    dim3 dimBlock(BLOCK, 1, 1);
+
+    scale_bias_kernel<<<dimGrid, dimBlock>>>(output, biases, n, size);
+    check_error(cudaPeekAtLastError());
+}
+
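+/* backward_scale_kernel: one thread block per filter. Each of the BLOCK
+ * threads strides across the spatial positions of every batch element,
+ * accumulating delta * x_norm into a private sum; the per-thread partial
+ * sums are staged in shared memory and reduced serially by thread 0 into
+ * scale_updates[filter]. */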
+__global__ void backward_scale_kernel(float *x_norm, float *delta, int batch, int n, int size, float *scale_updates)
+{
+    __shared__ float part[BLOCK];
+    int i,b;
+    int filter = blockIdx.x;
+    int p = threadIdx.x;
+    float sum = 0;
+    for(b = 0; b < batch; ++b){
+        for(i = 0; i < size; i += BLOCK){
+            int index = p + i + size*(filter + n*b);
+            sum += (p+i < size) ? delta[index]*x_norm[index] : 0;
+        }
+    }
+    part[p] = sum;
+    __syncthreads();
+    if (p == 0) {
+        for(i = 0; i < BLOCK; ++i) scale_updates[filter] += part[i];
+    }
+}
+
+void backward_scale_gpu(float *x_norm, float *delta, int batch, int n, int size, float *scale_updates)
+{
+    backward_scale_kernel<<<n, BLOCK>>>(x_norm, delta, batch, n, size, scale_updates);
+    check_error(cudaPeekAtLastError());
+}
+
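+/* add_bias_kernel / add_bias_gpu: same launch layout as scale_bias above,
+ * but adds the per-filter bias instead of multiplying by a per-filter scale.
+ * A typical call from a layer's forward pass might look like the line below;
+ * the layer fields are an assumption about the caller, not part of this file:
+ *   add_bias_gpu(l.output_gpu, l.biases_gpu, l.batch, l.n, l.out_h*l.out_w);
+ */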
+__global__ void add_bias_kernel(float *output, float *biases, int n, int size)
+{
+    int offset = blockIdx.x * blockDim.x + threadIdx.x;
+    int filter = blockIdx.y;
+    int batch = blockIdx.z;
+
+    if(offset < size) output[(batch*n+filter)*size + offset] += biases[filter];
+}
+
+void add_bias_gpu(float *output, float *biases, int batch, int n, int size)
+{
+    dim3 dimGrid((size-1)/BLOCK + 1, n, batch);
+    dim3 dimBlock(BLOCK, 1, 1);
+
+    add_bias_kernel<<<dimGrid, dimBlock>>>(output, biases, n, size);
+    check_error(cudaPeekAtLastError());
+}
+
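+/* backward_bias_kernel: same block-per-filter shared-memory reduction as
+ * backward_scale_kernel, but it sums the raw deltas to produce the
+ * per-filter bias gradient in bias_updates[filter]. */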
+__global__ void backward_bias_kernel(float *bias_updates, float *delta, int batch, int n, int size)
+{
+    __shared__ float part[BLOCK];
+    int i,b;
+    int filter = blockIdx.x;
+    int p = threadIdx.x;
+    float sum = 0;
+    for(b = 0; b < batch; ++b){
+        for(i = 0; i < size; i += BLOCK){
+            int index = p + i + size*(filter + n*b);
+            sum += (p+i < size) ? delta[index] : 0;
+        }
+    }
+    part[p] = sum;
+    __syncthreads();
+    if (p == 0) {
+        for(i = 0; i < BLOCK; ++i) bias_updates[filter] += part[i];
+    }
+}
+
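+/* The dot_kernel / dot_error_gpu pair below is left commented out; it
+ * appears to be an experimental penalty on the normalized dot product
+ * (cosine similarity) between pairs of filter outputs. */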
+/*
+__global__ void dot_kernel(float *output, float scale, int batch, int n, int size, float *delta)
+{
+    int index = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+    int f1 = index / n;
+    int f2 = index % n;
+    if (f2 <= f1) return;
+    
+    float sum = 0;
+    float norm1 = 0;
+    float norm2 = 0;
+    int b, i;
+    for(b = 0; b <  batch; ++b){
+        for(i = 0; i < size; ++i){
+            int i1 = b * size * n + f1 * size + i;
+            int i2 = b * size * n + f2 * size + i;
+            sum += output[i1] * output[i2];
+            norm1 += output[i1] * output[i1];
+            norm2 += output[i2] * output[i2];
+        }
+    }
+    norm1 = sqrt(norm1);
+    norm2 = sqrt(norm2);
+    float norm = norm1 * norm2;
+    sum = sum / norm;
+    for(b = 0; b <  batch; ++b){
+        for(i = 0; i < size; ++i){
+            int i1 = b * size * n + f1 * size + i;
+            int i2 = b * size * n + f2 * size + i;
+            delta[i1] += - scale * sum * output[i2] / norm;
+            delta[i2] += - scale * sum * output[i1] / norm;
+        }
+    }
+}
+
+void dot_error_gpu(layer l)
+{
+    dot_kernel<<<cuda_gridsize(l.n*l.n), BLOCK>>>(l.output_gpu, l.dot, l.batch, l.n, l.out_w * l.out_h, l.delta_gpu);
+    check_error(cudaPeekAtLastError());
+}
+*/
+
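+/* backward_bias_gpu: host wrapper that launches one block per filter. A
+ * typical call from a layer's backward pass might look like the line below;
+ * the layer fields are an assumption about the caller, not part of this file:
+ *   backward_bias_gpu(l.bias_updates_gpu, l.delta_gpu, l.batch, l.n, l.out_h*l.out_w);
+ */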
+void backward_bias_gpu(float *bias_updates, float *delta, int batch, int n, int size)
+{
+    backward_bias_kernel<<<n, BLOCK>>>(bias_updates, delta, batch, n, size);
+    check_error(cudaPeekAtLastError());
+}
+
+
 __global__ void normalize_kernel(int N, float *x, float *mean, float *variance, int batch, int filters, int spatial)
 {
     int index = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
@@ -199,6 +330,12 @@
     if(i < N) X[i*INCX] = ALPHA;
 }
 
+/* constrain_kernel: clamp every INCX-strided element of X to [-ALPHA, ALPHA]. */
+__global__ void constrain_kernel(int N, float ALPHA, float *X, int INCX)
+{
+    int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+    if(i < N) X[i*INCX] = fminf(ALPHA, fmaxf(-ALPHA, X[i*INCX]));
+}
+
 __global__ void scal_kernel(int N, float ALPHA, float *X, int INCX)
 {
     int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
@@ -363,6 +500,13 @@
     check_error(cudaPeekAtLastError());
 }
 
+/* constrain_ongpu: host-side wrapper that clamps N strided elements of X on
+ * the GPU to [-ALPHA, ALPHA], e.g. for clipping deltas in place (the
+ * clipping use case is an assumption about callers, not this file). */
+extern "C" void constrain_ongpu(int N, float ALPHA, float * X, int INCX)
+{
+    constrain_kernel<<<cuda_gridsize(N), BLOCK>>>(N, ALPHA, X, INCX);
+    check_error(cudaPeekAtLastError());
+}
+
+
 extern "C" void scal_ongpu(int N, float ALPHA, float * X, int INCX)
 {
     scal_kernel<<<cuda_gridsize(N), BLOCK>>>(N, ALPHA, X, INCX);
@@ -448,3 +592,48 @@
     l2_kernel<<<cuda_gridsize(n), BLOCK>>>(n, pred, truth, delta, error);
     check_error(cudaPeekAtLastError());
 }
+
+
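+/* weighted_sum_kernel: element-wise mix c = s*a + (1-s)*b, where s is a
+ * per-element mixing weight; b may be NULL, in which case it is treated as
+ * zero. */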
+__global__ void weighted_sum_kernel(int n, float *a, float *b, float *s, float *c)
+{
+    int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+    if(i < n){
+        c[i] = s[i]*a[i] + (1-s[i])*(b ? b[i] : 0);
+    }
+}
+
+extern "C" void weighted_sum_gpu(float *a, float *b, float *s, int num, float *c)
+{
+    weighted_sum_kernel<<<cuda_gridsize(num), BLOCK>>>(num, a, b, s, c);
+    check_error(cudaPeekAtLastError());
+}
+
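+/* weighted_delta_kernel: backward pass of weighted_sum_kernel. With
+ * c = s*a + (1-s)*b, the gradients are dc/da = s, dc/db = 1-s and
+ * dc/ds = a-b, so the incoming delta dc is routed accordingly. Note that
+ * only da is optional here; b, db and ds are assumed non-NULL. */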
+__global__ void weighted_delta_kernel(int n, float *a, float *b, float *s, float *da, float *db, float *ds, float *dc)
+{
+    int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+    if(i < n){
+        if(da) da[i] += dc[i] * s[i];
+        db[i] += dc[i] * (1-s[i]);
+        ds[i] += dc[i] * (a[i] - b[i]);
+    }
+}
+
+extern "C" void weighted_delta_gpu(float *a, float *b, float *s, float *da, float *db, float *ds, int num, float *dc)
+{
+    weighted_delta_kernel<<<cuda_gridsize(num), BLOCK>>>(num, a, b, s, da, db, ds, dc);
+    check_error(cudaPeekAtLastError());
+}
+
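+/* mult_add_into_kernel: element-wise multiply-accumulate, c[i] += a[i]*b[i]. */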
+__global__ void mult_add_into_kernel(int n, float *a, float *b, float *c)
+{
+    int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+    if(i < n){
+        c[i] += a[i]*b[i];
+    }
+}
+
+extern "C" void mult_add_into_gpu(int num, float *a, float *b, float *c)
+{
+    mult_add_into_kernel<<<cuda_gridsize(num), BLOCK>>>(num, a, b, c);
+    check_error(cudaPeekAtLastError());
+}

--
Gitblit v1.10.0