Joseph Redmon
2016-03-14 68213b835b9f15cb449ad2037a8b51c17a3de07b
src/blas_kernels.cu
@@ -1,6 +1,7 @@
#include "cuda_runtime.h"
#include "curand.h"
#include "cublas_v2.h"
#include <assert.h>
extern "C" {
#include "blas.h"
@@ -14,7 +15,7 @@
    if (index >= N) return;
    int f = (index/spatial)%filters;
    
    x[index] = (x[index] - mean[f])/(sqrt(variance[f]) + .00001f);
    x[index] = (x[index] - mean[f])/(sqrt(variance[f]) + .000001f);
}
__global__ void normalize_delta_kernel(int N, float *x, float *mean, float *variance, float *mean_delta, float *variance_delta, int batch, int filters, int spatial, float *delta)
@@ -23,7 +24,7 @@
    if (index >= N) return;
    int f = (index/spatial)%filters;
    
    delta[index] = delta[index] * 1./(sqrt(variance[f]) + .00001f) + variance_delta[f] * 2. * (x[index] - mean[f]) / (spatial * batch) + mean_delta[f]/(spatial*batch);
    delta[index] = delta[index] * 1./(sqrt(variance[f]) + .000001f) + variance_delta[f] * 2. * (x[index] - mean[f]) / (spatial * batch) + mean_delta[f]/(spatial*batch);
}
extern "C" void normalize_delta_gpu(float *x, float *mean, float *variance, float *mean_delta, float *variance_delta, int batch, int filters, int spatial, float *delta)
@@ -45,29 +46,7 @@
            variance_delta[i] += delta[index]*(x[index] - mean[i]);
        }
    }
    variance_delta[i] *= -.5 * pow(variance[i] + .00001f, (float)(-3./2.));
}
__global__ void spatial_variance_delta_kernel(float *x, float *delta, float *mean, float *variance, int batch, int filters, int spatial, float *spatial_variance_delta)
{
    int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    if (i >= batch*filters) return;
    int f = i%filters;
    int b = i/filters;
    int k;
    spatial_variance_delta[i] = 0;
    for (k = 0; k < spatial; ++k) {
        int index = b*filters*spatial + f*spatial + k;
        spatial_variance_delta[i] += delta[index]*(x[index] - mean[f]);
    }
    spatial_variance_delta[i] *= -.5 * pow(variance[f] + .00001f, (float)(-3./2.));
}
extern "C" void variance_delta_gpu(float *x, float *delta, float *mean, float *variance, int batch, int filters, int spatial, float *variance_delta)
{
    variance_delta_kernel<<<cuda_gridsize(filters), BLOCK>>>(x, delta, mean, variance, batch, filters, spatial, variance_delta);
    check_error(cudaPeekAtLastError());
    variance_delta[i] *= -.5 * pow(variance[i] + .000001f, (float)(-3./2.));
}
__global__ void accumulate_kernel(float *x, int n, int groups, float *sum)
@@ -81,38 +60,62 @@
    }
}
extern "C" void fast_variance_delta_gpu(float *x, float *delta, float *mean, float *variance, int batch, int filters, int spatial, float *spatial_variance_delta, float *variance_delta)
__global__ void fast_mean_delta_kernel(float *delta, float *variance, int batch, int filters, int spatial, float *mean_delta)
{
    spatial_variance_delta_kernel<<<cuda_gridsize(filters*batch), BLOCK>>>(x, delta, mean, variance, batch, filters, spatial, spatial_variance_delta);
    check_error(cudaPeekAtLastError());
    accumulate_kernel<<<cuda_gridsize(filters), BLOCK>>>(spatial_variance_delta, batch, filters, variance_delta);
    check_error(cudaPeekAtLastError());
}
    const int threads = BLOCK;
    __shared__ float local[threads];
__global__ void spatial_mean_delta_kernel(float *delta, float *variance, int batch, int filters, int spatial, float *spatial_mean_delta)
{
    int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    if (i >= batch*filters) return;
    int f = i%filters;
    int b = i/filters;
    int id = threadIdx.x;
    local[id] = 0;
    int k;
    spatial_mean_delta[i] = 0;
    for (k = 0; k < spatial; ++k) {
        int index = b*filters*spatial + f*spatial + k;
        spatial_mean_delta[i] += delta[index];
    int filter = blockIdx.x;
    int i, j;
    for(j = 0; j < batch; ++j){
        for(i = 0; i < spatial; i += threads){
            int index = j*spatial*filters + filter*spatial + i + id;
            local[id] += (i+id < spatial) ? delta[index] : 0;
        }
    }
    spatial_mean_delta[i] *= (-1./sqrt(variance[f] + .00001f));
    if(id == 0){
        mean_delta[filter] = 0;
        for(i = 0; i < threads; ++i){
            mean_delta[filter] += local[i];
        }
        mean_delta[filter] *= (-1./sqrt(variance[filter] + .000001f));
    }
}
extern "C" void fast_mean_delta_gpu(float *delta, float *variance, int batch, int filters, int spatial, float *spatial_mean_delta, float *mean_delta)
__global__ void  fast_variance_delta_kernel(float *x, float *delta, float *mean, float *variance, int batch, int filters, int spatial, float *variance_delta)
{
    spatial_mean_delta_kernel<<<cuda_gridsize(filters*batch), BLOCK>>>(delta, variance, batch, filters, spatial, spatial_mean_delta);
    check_error(cudaPeekAtLastError());
    accumulate_kernel<<<cuda_gridsize(filters), BLOCK>>>(spatial_mean_delta, batch, filters, mean_delta);
    check_error(cudaPeekAtLastError());
    const int threads = BLOCK;
    __shared__ float local[threads];
    int id = threadIdx.x;
    local[id] = 0;
    int filter = blockIdx.x;
    int i, j;
    for(j = 0; j < batch; ++j){
        for(i = 0; i < spatial; i += threads){
            int index = j*spatial*filters + filter*spatial + i + id;
            local[id] += (i+id < spatial) ? delta[index]*(x[index] - mean[filter]) : 0;
        }
    }
    if(id == 0){
        variance_delta[filter] = 0;
        for(i = 0; i < threads; ++i){
            variance_delta[filter] += local[i];
        }
        variance_delta[filter] *= -.5 * pow(variance[filter] + .000001f, (float)(-3./2.));
    }
}
__global__ void mean_delta_kernel(float *delta, float *variance, int batch, int filters, int spatial, float *mean_delta)
{
    int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
@@ -125,7 +128,7 @@
            mean_delta[i] += delta[index];
        }
    }
    mean_delta[i] *= (-1./sqrt(variance[i] + .00001f));
    mean_delta[i] *= (-1./sqrt(variance[i] + .000001f));
}
extern "C" void mean_delta_gpu(float *delta, float *variance, int batch, int filters, int spatial, float *mean_delta)
@@ -134,6 +137,18 @@
    check_error(cudaPeekAtLastError());
}
extern "C" void fast_mean_delta_gpu(float *delta, float *variance, int batch, int filters, int spatial, float *mean_delta)
{
    fast_mean_delta_kernel<<<filters, BLOCK>>>(delta, variance, batch, filters, spatial, mean_delta);
    check_error(cudaPeekAtLastError());
}
extern "C" void fast_variance_delta_gpu(float *x, float *delta, float *mean, float *variance, int batch, int filters, int spatial, float *variance_delta)
{
    fast_variance_delta_kernel<<<filters, BLOCK>>>(x, delta, mean, variance, batch, filters, spatial, variance_delta);
    check_error(cudaPeekAtLastError());
}
__global__ void  mean_kernel(float *x, int batch, int filters, int spatial, float *mean)
{
    float scale = 1./(batch * spatial);
@@ -150,26 +165,9 @@
    mean[i] *= scale;
}
__global__ void spatial_variance_kernel(float *x, float *mean, int batch, int filters, int spatial, float *variance)
{
    float scale = 1./(spatial*batch-1);
    int k;
    int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    if (i >= batch*filters) return;
    int f = i%filters;
    int b = i/filters;
    variance[i] = 0;
    for(k = 0; k < spatial; ++k){
        int index = b*filters*spatial + f*spatial + k;
        variance[i] += pow((x[index] - mean[f]), 2);
    }
    variance[i] *= scale;
}
__global__ void variance_kernel(float *x, float *mean, int batch, int filters, int spatial, float *variance)
{
    float scale = 1./(batch * spatial);
    float scale = 1./(batch * spatial - 1);
    int j,k;
    int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    if (i >= filters) return;
@@ -231,6 +229,7 @@
    if(i < N) Y[i*INCY] *= X[i*INCX];
}
extern "C" void normalize_gpu(float *x, float *mean, float *variance, int batch, int filters, int spatial)
{
    size_t N = batch*filters*spatial;
@@ -238,28 +237,80 @@
    check_error(cudaPeekAtLastError());
}
__global__ void  fast_mean_kernel(float *x, int batch, int filters, int spatial, float *mean)
{
    const int threads = BLOCK;
    __shared__ float local[threads];
    int id = threadIdx.x;
    local[id] = 0;
    int filter = blockIdx.x;
    int i, j;
    for(j = 0; j < batch; ++j){
        for(i = 0; i < spatial; i += threads){
            int index = j*spatial*filters + filter*spatial + i + id;
            local[id] += (i+id < spatial) ? x[index] : 0;
        }
    }
    if(id == 0){
        mean[filter] = 0;
        for(i = 0; i < threads; ++i){
            mean[filter] += local[i];
        }
        mean[filter] /= spatial * batch;
    }
}
__global__ void  fast_variance_kernel(float *x, float *mean, int batch, int filters, int spatial, float *variance)
{
    const int threads = BLOCK;
    __shared__ float local[threads];
    int id = threadIdx.x;
    local[id] = 0;
    int filter = blockIdx.x;
    int i, j;
    for(j = 0; j < batch; ++j){
        for(i = 0; i < spatial; i += threads){
            int index = j*spatial*filters + filter*spatial + i + id;
            local[id] += (i+id < spatial) ? pow((x[index] - mean[filter]), 2) : 0;
        }
    }
    if(id == 0){
        variance[filter] = 0;
        for(i = 0; i < threads; ++i){
            variance[filter] += local[i];
        }
        variance[filter] /= (spatial * batch - 1);
    }
}
extern "C" void fast_mean_gpu(float *x, int batch, int filters, int spatial, float *mean)
{
    fast_mean_kernel<<<filters, BLOCK>>>(x, batch, filters, spatial, mean);
    check_error(cudaPeekAtLastError());
}
extern "C" void fast_variance_gpu(float *x, float *mean, int batch, int filters, int spatial, float *variance)
{
    fast_variance_kernel<<<filters, BLOCK>>>(x, mean, batch, filters, spatial, variance);
    check_error(cudaPeekAtLastError());
}
extern "C" void mean_gpu(float *x, int batch, int filters, int spatial, float *mean)
{
    mean_kernel<<<cuda_gridsize(filters), BLOCK>>>(x, batch, filters, spatial, mean);
    check_error(cudaPeekAtLastError());
}
extern "C" void fast_mean_gpu(float *x, int batch, int filters, int spatial, float *spatial_mean, float *mean)
{
    mean_kernel<<<cuda_gridsize(filters*batch), BLOCK>>>(x, 1, filters*batch, spatial, spatial_mean);
    check_error(cudaPeekAtLastError());
    mean_kernel<<<cuda_gridsize(filters), BLOCK>>>(spatial_mean, batch, filters, 1, mean);
    check_error(cudaPeekAtLastError());
}
extern "C" void fast_variance_gpu(float *x, float *mean, int batch, int filters, int spatial, float *spatial_variance, float *variance)
{
    spatial_variance_kernel<<<cuda_gridsize(batch*filters), BLOCK>>>(x, mean, batch, filters, spatial, spatial_variance);
    check_error(cudaPeekAtLastError());
    accumulate_kernel<<<cuda_gridsize(filters), BLOCK>>>(spatial_variance, batch, filters, variance);
    check_error(cudaPeekAtLastError());
}
extern "C" void variance_gpu(float *x, float *mean, int batch, int filters, int spatial, float *variance)
{
    variance_kernel<<<cuda_gridsize(filters), BLOCK>>>(x, mean, batch, filters, spatial, variance);
@@ -323,3 +374,77 @@
    fill_kernel<<<cuda_gridsize(N), BLOCK>>>(N, ALPHA, X, INCX);
    check_error(cudaPeekAtLastError());
}
__global__ void shortcut_kernel(int size, int minw, int minh, int minc, int stride, int sample, int batch, int w1, int h1, int c1, float *add, int w2, int h2, int c2, float *out)
{
    int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    if (id >= size) return;
    int i = id % minw;
    id /= minw;
    int j = id % minh;
    id /= minh;
    int k = id % minc;
    id /= minc;
    int b = id % batch;
    int out_index = i*sample + w2*(j*sample + h2*(k + c2*b));
    int add_index = i*stride + w1*(j*stride + h1*(k + c1*b));
    out[out_index] += add[add_index];
}
extern "C" void shortcut_gpu(int batch, int w1, int h1, int c1, float *add, int w2, int h2, int c2, float *out)
{
    int minw = (w1 < w2) ? w1 : w2;
    int minh = (h1 < h2) ? h1 : h2;
    int minc = (c1 < c2) ? c1 : c2;
    int stride = w1/w2;
    int sample = w2/w1;
    assert(stride == h1/h2);
    assert(sample == h2/h1);
    if(stride < 1) stride = 1;
    if(sample < 1) sample = 1;
    int size = batch * minw * minh * minc;
    shortcut_kernel<<<cuda_gridsize(size), BLOCK>>>(size, minw, minh, minc, stride, sample, batch, w1, h1, c1, add, w2, h2, c2, out);
    check_error(cudaPeekAtLastError());
}
__global__ void smooth_l1_kernel(int n, float *pred, float *truth, float *delta, float *error)
{
    int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    if(i < n){
        float diff = truth[i] - pred[i];
        float abs_val = abs(diff);
        if(abs_val < 1) {
            error[i] = diff * diff;
            delta[i] = diff;
        }
        else {
            error[i] = 2*abs_val - 1;
            delta[i] = (diff < 0) ? -1 : 1;
        }
    }
}
extern "C" void smooth_l1_gpu(int n, float *pred, float *truth, float *delta, float *error)
{
    smooth_l1_kernel<<<cuda_gridsize(n), BLOCK>>>(n, pred, truth, delta, error);
    check_error(cudaPeekAtLastError());
}
__global__ void l2_kernel(int n, float *pred, float *truth, float *delta, float *error)
{
    int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    if(i < n){
        float diff = truth[i] - pred[i];
        error[i] = diff * diff; //I know this is technically wrong, deal with it.
        delta[i] = diff;
    }
}
extern "C" void l2_gpu(int n, float *pred, float *truth, float *delta, float *error)
{
    l2_kernel<<<cuda_gridsize(n), BLOCK>>>(n, pred, truth, delta, error);
    check_error(cudaPeekAtLastError());
}