From ae43c2bc32fbb838bfebeeaf2c2b058ccab5c83c Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@burninator.cs.washington.edu>
Date: Thu, 23 Jun 2016 05:31:14 +0000
Subject: [PATCH] hi
---
src/blas_kernels.cu | 585 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 581 insertions(+), 4 deletions(-)
diff --git a/src/blas_kernels.cu b/src/blas_kernels.cu
index d6f7143..ac537d8 100644
--- a/src/blas_kernels.cu
+++ b/src/blas_kernels.cu
@@ -1,6 +1,315 @@
+#include "cuda_runtime.h"
+#include "curand.h"
+#include "cublas_v2.h"
+#include <assert.h>
+
extern "C" {
#include "blas.h"
#include "cuda.h"
+#include "utils.h"
+}
+
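+// Multiply every output element by its filter's scale term in `biases`.
+// Grid layout: x spans the spatial extent, blockIdx.y picks the filter,
+// blockIdx.z the batch item; output is stored as [batch][filter][spatial].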
+__global__ void scale_bias_kernel(float *output, float *biases, int n, int size)
+{
+ int offset = blockIdx.x * blockDim.x + threadIdx.x;
+ int filter = blockIdx.y;
+ int batch = blockIdx.z;
+
+ if(offset < size) output[(batch*n+filter)*size + offset] *= biases[filter];
+}
+
+void scale_bias_gpu(float *output, float *biases, int batch, int n, int size)
+{
+ dim3 dimGrid((size-1)/BLOCK + 1, n, batch);
+ dim3 dimBlock(BLOCK, 1, 1);
+
+ scale_bias_kernel<<<dimGrid, dimBlock>>>(output, biases, n, size);
+ check_error(cudaPeekAtLastError());
+}
+
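+// One block per filter: each thread accumulates a strided partial sum of
+// delta * x_norm over all batch items and spatial positions into shared
+// memory, then thread 0 folds the partials into scale_updates[filter].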
+__global__ void backward_scale_kernel(float *x_norm, float *delta, int batch, int n, int size, float *scale_updates)
+{
+ __shared__ float part[BLOCK];
+ int i,b;
+ int filter = blockIdx.x;
+ int p = threadIdx.x;
+ float sum = 0;
+ for(b = 0; b < batch; ++b){
+ for(i = 0; i < size; i += BLOCK){
+ int index = p + i + size*(filter + n*b);
+ sum += (p+i < size) ? delta[index]*x_norm[index] : 0;
+ }
+ }
+ part[p] = sum;
+ __syncthreads();
+ if (p == 0) {
+ for(i = 0; i < BLOCK; ++i) scale_updates[filter] += part[i];
+ }
+}
+
+void backward_scale_gpu(float *x_norm, float *delta, int batch, int n, int size, float *scale_updates)
+{
+ backward_scale_kernel<<<n, BLOCK>>>(x_norm, delta, batch, n, size, scale_updates);
+ check_error(cudaPeekAtLastError());
+}
+
+__global__ void add_bias_kernel(float *output, float *biases, int n, int size)
+{
+ int offset = blockIdx.x * blockDim.x + threadIdx.x;
+ int filter = blockIdx.y;
+ int batch = blockIdx.z;
+
+ if(offset < size) output[(batch*n+filter)*size + offset] += biases[filter];
+}
+
+void add_bias_gpu(float *output, float *biases, int batch, int n, int size)
+{
+ dim3 dimGrid((size-1)/BLOCK + 1, n, batch);
+ dim3 dimBlock(BLOCK, 1, 1);
+
+ add_bias_kernel<<<dimGrid, dimBlock>>>(output, biases, n, size);
+ check_error(cudaPeekAtLastError());
+}
+
+__global__ void backward_bias_kernel(float *bias_updates, float *delta, int batch, int n, int size)
+{
+ __shared__ float part[BLOCK];
+ int i,b;
+ int filter = blockIdx.x;
+ int p = threadIdx.x;
+ float sum = 0;
+ for(b = 0; b < batch; ++b){
+ for(i = 0; i < size; i += BLOCK){
+ int index = p + i + size*(filter + n*b);
+ sum += (p+i < size) ? delta[index] : 0;
+ }
+ }
+ part[p] = sum;
+ __syncthreads();
+ if (p == 0) {
+ for(i = 0; i < BLOCK; ++i) bias_updates[filter] += part[i];
+ }
+}
+
+/*
+__global__ void dot_kernel(float *output, float scale, int batch, int n, int size, float *delta)
+{
+ int index = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+ int f1 = index / n;
+ int f2 = index % n;
+ if (f2 <= f1) return;
+
+ float sum = 0;
+ float norm1 = 0;
+ float norm2 = 0;
+ int b, i;
+ for(b = 0; b < batch; ++b){
+ for(i = 0; i < size; ++i){
+ int i1 = b * size * n + f1 * size + i;
+ int i2 = b * size * n + f2 * size + i;
+ sum += output[i1] * output[i2];
+ norm1 += output[i1] * output[i1];
+ norm2 += output[i2] * output[i2];
+ }
+ }
+ norm1 = sqrt(norm1);
+ norm2 = sqrt(norm2);
+ float norm = norm1 * norm2;
+ sum = sum / norm;
+ for(b = 0; b < batch; ++b){
+ for(i = 0; i < size; ++i){
+ int i1 = b * size * n + f1 * size + i;
+ int i2 = b * size * n + f2 * size + i;
+ delta[i1] += - scale * sum * output[i2] / norm;
+ delta[i2] += - scale * sum * output[i1] / norm;
+ }
+ }
+}
+
+void dot_error_gpu(layer l)
+{
+ dot_kernel<<<cuda_gridsize(l.n*l.n), BLOCK>>>(l.output_gpu, l.dot, l.batch, l.n, l.out_w * l.out_h, l.delta_gpu);
+ check_error(cudaPeekAtLastError());
+}
+*/
+
+void backward_bias_gpu(float *bias_updates, float *delta, int batch, int n, int size)
+{
+ backward_bias_kernel<<<n, BLOCK>>>(bias_updates, delta, batch, n, size);
+ check_error(cudaPeekAtLastError());
+}
+
+
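+// Normalize each element with its filter's statistics, (x - mean) / (sqrt(variance) + eps),
+// recovering the filter index from the flat [batch][filter][spatial] offset.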
+__global__ void normalize_kernel(int N, float *x, float *mean, float *variance, int batch, int filters, int spatial)
+{
+ int index = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+ if (index >= N) return;
+ int f = (index/spatial)%filters;
+
+ x[index] = (x[index] - mean[f])/(sqrt(variance[f]) + .000001f);
+}
+
+__global__ void normalize_delta_kernel(int N, float *x, float *mean, float *variance, float *mean_delta, float *variance_delta, int batch, int filters, int spatial, float *delta)
+{
+ int index = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+ if (index >= N) return;
+ int f = (index/spatial)%filters;
+
+ delta[index] = delta[index] * 1./(sqrt(variance[f]) + .000001f) + variance_delta[f] * 2. * (x[index] - mean[f]) / (spatial * batch) + mean_delta[f]/(spatial*batch);
+}
+
+extern "C" void normalize_delta_gpu(float *x, float *mean, float *variance, float *mean_delta, float *variance_delta, int batch, int filters, int spatial, float *delta)
+{
+ size_t N = batch*filters*spatial;
+ normalize_delta_kernel<<<cuda_gridsize(N), BLOCK>>>(N, x, mean, variance, mean_delta, variance_delta, batch, filters, spatial, delta);
+ check_error(cudaPeekAtLastError());
+}
+
+__global__ void variance_delta_kernel(float *x, float *delta, float *mean, float *variance, int batch, int filters, int spatial, float *variance_delta)
+{
+ int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+ if (i >= filters) return;
+ int j,k;
+ variance_delta[i] = 0;
+ for(j = 0; j < batch; ++j){
+ for(k = 0; k < spatial; ++k){
+ int index = j*filters*spatial + i*spatial + k;
+ variance_delta[i] += delta[index]*(x[index] - mean[i]);
+ }
+ }
+ variance_delta[i] *= -.5 * pow(variance[i] + .000001f, (float)(-3./2.));
+}
+
+__global__ void accumulate_kernel(float *x, int n, int groups, float *sum)
+{
+ int k;
+ int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+ if (i >= groups) return;
+ sum[i] = 0;
+ for(k = 0; k < n; ++k){
+ sum[i] += x[k*groups + i];
+ }
+}
+
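+// Block-per-filter reduction of the mean gradient: BLOCK threads accumulate
+// delta over their strided slices in shared memory; thread 0 sums the partials
+// and scales by -1/sqrt(variance + eps).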
+__global__ void fast_mean_delta_kernel(float *delta, float *variance, int batch, int filters, int spatial, float *mean_delta)
+{
+ const int threads = BLOCK;
+ __shared__ float local[threads];
+
+ int id = threadIdx.x;
+ local[id] = 0;
+
+ int filter = blockIdx.x;
+
+ int i, j;
+ for(j = 0; j < batch; ++j){
+ for(i = 0; i < spatial; i += threads){
+ int index = j*spatial*filters + filter*spatial + i + id;
+ local[id] += (i+id < spatial) ? delta[index] : 0;
+ }
+ }
+
+ if(id == 0){
+ mean_delta[filter] = 0;
+ for(i = 0; i < threads; ++i){
+ mean_delta[filter] += local[i];
+ }
+ mean_delta[filter] *= (-1./sqrt(variance[filter] + .000001f));
+ }
+}
+
+__global__ void fast_variance_delta_kernel(float *x, float *delta, float *mean, float *variance, int batch, int filters, int spatial, float *variance_delta)
+{
+ const int threads = BLOCK;
+ __shared__ float local[threads];
+
+ int id = threadIdx.x;
+ local[id] = 0;
+
+ int filter = blockIdx.x;
+
+ int i, j;
+ for(j = 0; j < batch; ++j){
+ for(i = 0; i < spatial; i += threads){
+ int index = j*spatial*filters + filter*spatial + i + id;
+
+ local[id] += (i+id < spatial) ? delta[index]*(x[index] - mean[filter]) : 0;
+ }
+ }
+
+ if(id == 0){
+ variance_delta[filter] = 0;
+ for(i = 0; i < threads; ++i){
+ variance_delta[filter] += local[i];
+ }
+ variance_delta[filter] *= -.5 * pow(variance[filter] + .000001f, (float)(-3./2.));
+ }
+}
+
+
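+// Naive one-thread-per-filter counterparts of the fast_* reductions above;
+// each thread walks every batch item and spatial position of its filter.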
+__global__ void mean_delta_kernel(float *delta, float *variance, int batch, int filters, int spatial, float *mean_delta)
+{
+ int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+ if (i >= filters) return;
+ int j,k;
+ mean_delta[i] = 0;
+ for (j = 0; j < batch; ++j) {
+ for (k = 0; k < spatial; ++k) {
+ int index = j*filters*spatial + i*spatial + k;
+ mean_delta[i] += delta[index];
+ }
+ }
+ mean_delta[i] *= (-1./sqrt(variance[i] + .000001f));
+}
+
+extern "C" void mean_delta_gpu(float *delta, float *variance, int batch, int filters, int spatial, float *mean_delta)
+{
+ mean_delta_kernel<<<cuda_gridsize(filters), BLOCK>>>(delta, variance, batch, filters, spatial, mean_delta);
+ check_error(cudaPeekAtLastError());
+}
+
+extern "C" void fast_mean_delta_gpu(float *delta, float *variance, int batch, int filters, int spatial, float *mean_delta)
+{
+ fast_mean_delta_kernel<<<filters, BLOCK>>>(delta, variance, batch, filters, spatial, mean_delta);
+ check_error(cudaPeekAtLastError());
+}
+
+extern "C" void fast_variance_delta_gpu(float *x, float *delta, float *mean, float *variance, int batch, int filters, int spatial, float *variance_delta)
+{
+ fast_variance_delta_kernel<<<filters, BLOCK>>>(x, delta, mean, variance, batch, filters, spatial, variance_delta);
+ check_error(cudaPeekAtLastError());
+}
+
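+// Forward statistics, one thread per filter: mean averages over batch*spatial,
+// variance uses the unbiased 1/(batch*spatial - 1) normalization.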
+__global__ void mean_kernel(float *x, int batch, int filters, int spatial, float *mean)
+{
+ float scale = 1./(batch * spatial);
+ int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+ if (i >= filters) return;
+ int j,k;
+ mean[i] = 0;
+ for(j = 0; j < batch; ++j){
+ for(k = 0; k < spatial; ++k){
+ int index = j*filters*spatial + i*spatial + k;
+ mean[i] += x[index];
+ }
+ }
+ mean[i] *= scale;
+}
+
+__global__ void variance_kernel(float *x, float *mean, int batch, int filters, int spatial, float *variance)
+{
+ float scale = 1./(batch * spatial - 1);
+ int j,k;
+ int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+ if (i >= filters) return;
+ variance[i] = 0;
+ for(j = 0; j < batch; ++j){
+ for(k = 0; k < spatial; ++k){
+ int index = j*filters*spatial + i*spatial + k;
+ variance[i] += pow((x[index] - mean[i]), 2);
+ }
+ }
+ variance[i] *= scale;
}
__global__ void axpy_kernel(int N, float ALPHA, float *X, int OFFX, int INCX, float *Y, int OFFY, int INCY)
@@ -9,16 +318,40 @@
if(i < N) Y[OFFY+i*INCY] += ALPHA*X[OFFX+i*INCX];
}
+__global__ void pow_kernel(int N, float ALPHA, float *X, int INCX, float *Y, int INCY)
+{
+ int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+ if(i < N) Y[i*INCY] = pow(X[i*INCX], ALPHA);
+}
+
+__global__ void const_kernel(int N, float ALPHA, float *X, int INCX)
+{
+ int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+ if(i < N) X[i*INCX] = ALPHA;
+}
+
+__global__ void constrain_kernel(int N, float ALPHA, float *X, int INCX)
+{
+ int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+ if(i < N) X[i*INCX] = min(ALPHA, max(-ALPHA, X[i*INCX]));
+}
+
__global__ void scal_kernel(int N, float ALPHA, float *X, int INCX)
{
int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
if(i < N) X[i*INCX] *= ALPHA;
}
-__global__ void mask_kernel(int n, float *x, float *mask, int mod)
+__global__ void fill_kernel(int N, float ALPHA, float *X, int INCX)
{
int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
- if(i < n) x[i] = (i%mod && !mask[(i/mod)*mod]) ? 0 : x[i];
+ if(i < N) X[i*INCX] = ALPHA;
+}
+
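+// Wherever mask[i] equals mask_num, overwrite x[i] with mask_num.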
+__global__ void mask_kernel(int n, float *x, float mask_num, float *mask)
+{
+ int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+ if(i < n && mask[i] == mask_num) x[i] = mask_num;
}
__global__ void copy_kernel(int N, float *X, int OFFX, int INCX, float *Y, int OFFY, int INCY)
@@ -27,11 +360,111 @@
if(i < N) Y[i*INCY + OFFY] = X[i*INCX + OFFX];
}
+__global__ void mul_kernel(int N, float *X, int INCX, float *Y, int INCY)
+{
+ int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+ if(i < N) Y[i*INCY] *= X[i*INCX];
+}
+
+
+extern "C" void normalize_gpu(float *x, float *mean, float *variance, int batch, int filters, int spatial)
+{
+ size_t N = batch*filters*spatial;
+ normalize_kernel<<<cuda_gridsize(N), BLOCK>>>(N, x, mean, variance, batch, filters, spatial);
+ check_error(cudaPeekAtLastError());
+}
+
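+// Shared-memory versions of the forward statistics: one block per filter,
+// strided accumulation per thread, final reduction by thread 0.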
+__global__ void fast_mean_kernel(float *x, int batch, int filters, int spatial, float *mean)
+{
+ const int threads = BLOCK;
+ __shared__ float local[threads];
+
+ int id = threadIdx.x;
+ local[id] = 0;
+
+ int filter = blockIdx.x;
+
+ int i, j;
+ for(j = 0; j < batch; ++j){
+ for(i = 0; i < spatial; i += threads){
+ int index = j*spatial*filters + filter*spatial + i + id;
+ local[id] += (i+id < spatial) ? x[index] : 0;
+ }
+ }
+
+ if(id == 0){
+ mean[filter] = 0;
+ for(i = 0; i < threads; ++i){
+ mean[filter] += local[i];
+ }
+ mean[filter] /= spatial * batch;
+ }
+}
+
+__global__ void fast_variance_kernel(float *x, float *mean, int batch, int filters, int spatial, float *variance)
+{
+ const int threads = BLOCK;
+ __shared__ float local[threads];
+
+ int id = threadIdx.x;
+ local[id] = 0;
+
+ int filter = blockIdx.x;
+
+ int i, j;
+ for(j = 0; j < batch; ++j){
+ for(i = 0; i < spatial; i += threads){
+ int index = j*spatial*filters + filter*spatial + i + id;
+
+ local[id] += (i+id < spatial) ? pow((x[index] - mean[filter]), 2) : 0;
+ }
+ }
+
+ if(id == 0){
+ variance[filter] = 0;
+ for(i = 0; i < threads; ++i){
+ variance[filter] += local[i];
+ }
+ variance[filter] /= (spatial * batch - 1);
+ }
+}
+
+extern "C" void fast_mean_gpu(float *x, int batch, int filters, int spatial, float *mean)
+{
+ fast_mean_kernel<<<filters, BLOCK>>>(x, batch, filters, spatial, mean);
+ check_error(cudaPeekAtLastError());
+}
+
+extern "C" void fast_variance_gpu(float *x, float *mean, int batch, int filters, int spatial, float *variance)
+{
+ fast_variance_kernel<<<filters, BLOCK>>>(x, mean, batch, filters, spatial, variance);
+ check_error(cudaPeekAtLastError());
+}
+
+
+extern "C" void mean_gpu(float *x, int batch, int filters, int spatial, float *mean)
+{
+ mean_kernel<<<cuda_gridsize(filters), BLOCK>>>(x, batch, filters, spatial, mean);
+ check_error(cudaPeekAtLastError());
+}
+
+extern "C" void variance_gpu(float *x, float *mean, int batch, int filters, int spatial, float *variance)
+{
+ variance_kernel<<<cuda_gridsize(filters), BLOCK>>>(x, mean, batch, filters, spatial, variance);
+ check_error(cudaPeekAtLastError());
+}
+
extern "C" void axpy_ongpu(int N, float ALPHA, float * X, int INCX, float * Y, int INCY)
{
axpy_ongpu_offset(N, ALPHA, X, 0, INCX, Y, 0, INCY);
}
+extern "C" void pow_ongpu(int N, float ALPHA, float * X, int INCX, float * Y, int INCY)
+{
+ pow_kernel<<<cuda_gridsize(N), BLOCK>>>(N, ALPHA, X, INCX, Y, INCY);
+ check_error(cudaPeekAtLastError());
+}
+
extern "C" void axpy_ongpu_offset(int N, float ALPHA, float * X, int OFFX, int INCX, float * Y, int OFFY, int INCY)
{
axpy_kernel<<<cuda_gridsize(N), BLOCK>>>(N, ALPHA, X, OFFX, INCX, Y, OFFY, INCY);
@@ -43,20 +476,164 @@
copy_ongpu_offset(N, X, 0, INCX, Y, 0, INCY);
}
+extern "C" void mul_ongpu(int N, float * X, int INCX, float * Y, int INCY)
+{
+ mul_kernel<<<cuda_gridsize(N), BLOCK>>>(N, X, INCX, Y, INCY);
+ check_error(cudaPeekAtLastError());
+}
+
extern "C" void copy_ongpu_offset(int N, float * X, int OFFX, int INCX, float * Y, int OFFY, int INCY)
{
copy_kernel<<<cuda_gridsize(N), BLOCK>>>(N, X, OFFX, INCX, Y, OFFY, INCY);
check_error(cudaPeekAtLastError());
}
-extern "C" void mask_ongpu(int N, float * X, float * mask, float mod)
+extern "C" void mask_ongpu(int N, float * X, float mask_num, float * mask)
{
- mask_kernel<<<cuda_gridsize(N), BLOCK>>>(N, X, mask, mod);
+ mask_kernel<<<cuda_gridsize(N), BLOCK>>>(N, X, mask_num, mask);
check_error(cudaPeekAtLastError());
}
+extern "C" void const_ongpu(int N, float ALPHA, float * X, int INCX)
+{
+ const_kernel<<<cuda_gridsize(N), BLOCK>>>(N, ALPHA, X, INCX);
+ check_error(cudaPeekAtLastError());
+}
+
+extern "C" void constrain_ongpu(int N, float ALPHA, float * X, int INCX)
+{
+ constrain_kernel<<<cuda_gridsize(N), BLOCK>>>(N, ALPHA, X, INCX);
+ check_error(cudaPeekAtLastError());
+}
+
+
extern "C" void scal_ongpu(int N, float ALPHA, float * X, int INCX)
{
scal_kernel<<<cuda_gridsize(N), BLOCK>>>(N, ALPHA, X, INCX);
check_error(cudaPeekAtLastError());
}
+
+extern "C" void fill_ongpu(int N, float ALPHA, float * X, int INCX)
+{
+ fill_kernel<<<cuda_gridsize(N), BLOCK>>>(N, ALPHA, X, INCX);
+ check_error(cudaPeekAtLastError());
+}
+
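+// Element-wise add between two layers that may differ in resolution and depth:
+// indices run over the smaller extents, `sample` scales into the output volume
+// and `stride` into the volume being added.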
+__global__ void shortcut_kernel(int size, int minw, int minh, int minc, int stride, int sample, int batch, int w1, int h1, int c1, float *add, int w2, int h2, int c2, float *out)
+{
+ int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+ if (id >= size) return;
+ int i = id % minw;
+ id /= minw;
+ int j = id % minh;
+ id /= minh;
+ int k = id % minc;
+ id /= minc;
+ int b = id % batch;
+
+ int out_index = i*sample + w2*(j*sample + h2*(k + c2*b));
+ int add_index = i*stride + w1*(j*stride + h1*(k + c1*b));
+ out[out_index] += add[add_index];
+}
+
+extern "C" void shortcut_gpu(int batch, int w1, int h1, int c1, float *add, int w2, int h2, int c2, float *out)
+{
+ int minw = (w1 < w2) ? w1 : w2;
+ int minh = (h1 < h2) ? h1 : h2;
+ int minc = (c1 < c2) ? c1 : c2;
+
+ int stride = w1/w2;
+ int sample = w2/w1;
+ assert(stride == h1/h2);
+ assert(sample == h2/h1);
+ if(stride < 1) stride = 1;
+ if(sample < 1) sample = 1;
+
+ int size = batch * minw * minh * minc;
+ shortcut_kernel<<<cuda_gridsize(size), BLOCK>>>(size, minw, minh, minc, stride, sample, batch, w1, h1, c1, add, w2, h2, c2, out);
+ check_error(cudaPeekAtLastError());
+}
+
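+// Smooth L1 loss: quadratic (diff^2, gradient diff) while |diff| < 1,
+// linear (2|diff| - 1, gradient +/-1) otherwise.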
+__global__ void smooth_l1_kernel(int n, float *pred, float *truth, float *delta, float *error)
+{
+ int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+ if(i < n){
+ float diff = truth[i] - pred[i];
+ float abs_val = abs(diff);
+ if(abs_val < 1) {
+ error[i] = diff * diff;
+ delta[i] = diff;
+ }
+ else {
+ error[i] = 2*abs_val - 1;
+ delta[i] = (diff < 0) ? -1 : 1;
+ }
+ }
+}
+
+extern "C" void smooth_l1_gpu(int n, float *pred, float *truth, float *delta, float *error)
+{
+ smooth_l1_kernel<<<cuda_gridsize(n), BLOCK>>>(n, pred, truth, delta, error);
+ check_error(cudaPeekAtLastError());
+}
+
+__global__ void l2_kernel(int n, float *pred, float *truth, float *delta, float *error)
+{
+ int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+ if(i < n){
+ float diff = truth[i] - pred[i];
+ error[i] = diff * diff; //I know this is technically wrong, deal with it.
+ delta[i] = diff;
+ }
+}
+
+extern "C" void l2_gpu(int n, float *pred, float *truth, float *delta, float *error)
+{
+ l2_kernel<<<cuda_gridsize(n), BLOCK>>>(n, pred, truth, delta, error);
+ check_error(cudaPeekAtLastError());
+}
+
+
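+// Element-wise blend c = s*a + (1-s)*b; b may be NULL, in which case it is
+// treated as zero.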
+__global__ void weighted_sum_kernel(int n, float *a, float *b, float *s, float *c)
+{
+ int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+ if(i < n){
+ c[i] = s[i]*a[i] + (1-s[i])*(b ? b[i] : 0);
+ }
+}
+
+extern "C" void weighted_sum_gpu(float *a, float *b, float *s, int num, float *c)
+{
+ weighted_sum_kernel<<<cuda_gridsize(num), BLOCK>>>(num, a, b, s, c);
+ check_error(cudaPeekAtLastError());
+}
+
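+// Backward pass of the weighted sum: da += dc*s, db += dc*(1-s), ds += dc*(a - b).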
+__global__ void weighted_delta_kernel(int n, float *a, float *b, float *s, float *da, float *db, float *ds, float *dc)
+{
+ int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+ if(i < n){
+ if(da) da[i] += dc[i] * s[i];
+ db[i] += dc[i] * (1-s[i]);
+ ds[i] += dc[i] * a[i] + dc[i] * -b[i];
+ }
+}
+
+extern "C" void weighted_delta_gpu(float *a, float *b, float *s, float *da, float *db, float *ds, int num, float *dc)
+{
+ weighted_delta_kernel<<<cuda_gridsize(num), BLOCK>>>(num, a, b, s, da, db, ds, dc);
+ check_error(cudaPeekAtLastError());
+}
+
+__global__ void mult_add_into_kernel(int n, float *a, float *b, float *c)
+{
+ int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+ if(i < n){
+ c[i] += a[i]*b[i];
+ }
+}
+
+extern "C" void mult_add_into_gpu(int num, float *a, float *b, float *c)
+{
+ mult_add_into_kernel<<<cuda_gridsize(num), BLOCK>>>(num, a, b, c);
+ check_error(cudaPeekAtLastError());
+}
--
Gitblit v1.10.0