AlexeyAB
2018-03-30 23cb35e6c8eae8b59fab161036ae3f417a55c8db
src/blas_kernels.cu
@@ -145,8 +145,8 @@
    int index = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    if (index >= N) return;
    
    x[index] = x[index] - (rate * sqrt(1.-pow(B2, t)) / (1.-pow(B1, t)) * m[index] / (sqrt(v[index]) + eps));
    //if(index == 0) printf("%f %f %f %f\n", m[index], v[index], (rate * sqrt(1.-pow(B2, t)) / (1.-pow(B1, t)) * m[index] / (sqrt(v[index]) + eps)));
    x[index] = x[index] - (rate * sqrtf(1.F-powf(B2, t)) / (1.F-powf(B1, t)) * m[index] / (sqrtf(v[index]) + eps));
    //if(index == 0) printf("%f %f %f %f\n", m[index], v[index], (rate * sqrtf(1.F-powf(B2, t)) / (1.F-powf(B1, t)) * m[index] / (sqrt(v[index]) + eps)));
}
extern "C" void adam_gpu(int n, float *x, float *m, float *v, float B1, float B2, float rate, float eps, int t)
@@ -155,13 +155,27 @@
    check_error(cudaPeekAtLastError());
}
extern "C" void adam_update_gpu(float *w, float *d, float *m, float *v, float B1, float B2, float eps, float decay, float rate, int n, int batch, int t)
{
   scal_ongpu(n, B1, m, 1);
   scal_ongpu(n, B2, v, 1);
   axpy_ongpu(n, -decay*batch, w, 1, d, 1);
   axpy_ongpu(n, (1 - B1), d, 1, m, 1);
   mul_ongpu(n, d, 1, d, 1);
   axpy_ongpu(n, (1 - B2), d, 1, v, 1);
   adam_gpu(n, w, m, v, B1, B2, rate, eps, t);
   fill_ongpu(n, 0, d, 1);
}
__global__ void normalize_kernel(int N, float *x, float *mean, float *variance, int batch, int filters, int spatial)
{
    int index = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    if (index >= N) return;
    int f = (index/spatial)%filters;
    
    x[index] = (x[index] - mean[f])/(sqrt(variance[f]) + .000001f);
    x[index] = (x[index] - mean[f])/(sqrtf(variance[f]) + .000001f);
}
__global__ void normalize_delta_kernel(int N, float *x, float *mean, float *variance, float *mean_delta, float *variance_delta, int batch, int filters, int spatial, float *delta)
@@ -170,7 +184,7 @@
    if (index >= N) return;
    int f = (index/spatial)%filters;
    
    delta[index] = delta[index] * 1./(sqrt(variance[f]) + .000001f) + variance_delta[f] * 2. * (x[index] - mean[f]) / (spatial * batch) + mean_delta[f]/(spatial*batch);
    delta[index] = delta[index] * 1.F/(sqrtf(variance[f]) + .000001f) + variance_delta[f] * 2. * (x[index] - mean[f]) / (spatial * batch) + mean_delta[f]/(spatial*batch);
}
extern "C" void normalize_delta_gpu(float *x, float *mean, float *variance, float *mean_delta, float *variance_delta, int batch, int filters, int spatial, float *delta)
@@ -192,7 +206,7 @@
            variance_delta[i] += delta[index]*(x[index] - mean[i]);
        }
    }
    variance_delta[i] *= -.5 * pow(variance[i] + .000001f, (float)(-3./2.));
    variance_delta[i] *= -.5 * powf(variance[i] + .000001f, (float)(-3./2.));
}
__global__ void accumulate_kernel(float *x, int n, int groups, float *sum)
@@ -230,7 +244,7 @@
        for(i = 0; i < threads; ++i){
            mean_delta[filter] += local[i];
        }
        mean_delta[filter] *= (-1./sqrt(variance[filter] + .000001f));
        mean_delta[filter] *= (-1.F/sqrtf(variance[filter] + .000001f));
    }
}
@@ -259,7 +273,7 @@
        for(i = 0; i < threads; ++i){
            variance_delta[filter] += local[i];
        }
        variance_delta[filter] *= -.5 * pow(variance[filter] + .000001f, (float)(-3./2.));
        variance_delta[filter] *= -.5 * powf(variance[filter] + .000001f, (float)(-3./2.));
    }
}
@@ -276,7 +290,7 @@
            mean_delta[i] += delta[index];
        }
    }
    mean_delta[i] *= (-1./sqrt(variance[i] + .000001f));
    mean_delta[i] *= (-1.F/sqrtf(variance[i] + .000001f));
}
extern "C" void mean_delta_gpu(float *delta, float *variance, int batch, int filters, int spatial, float *mean_delta)
@@ -299,7 +313,7 @@
__global__ void  mean_kernel(float *x, int batch, int filters, int spatial, float *mean)
{
    float scale = 1./(batch * spatial);
    float scale = 1.F/(batch * spatial);
    int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    if (i >= filters) return;
    int j,k;
@@ -315,7 +329,7 @@
__global__ void variance_kernel(float *x, float *mean, int batch, int filters, int spatial, float *variance)
{
    float scale = 1./(batch * spatial - 1);
    float scale = 1.F/(batch * spatial - 1);
    int j,k;
    int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    if (i >= filters) return;
@@ -323,7 +337,7 @@
    for(j = 0; j < batch; ++j){
        for(k = 0; k < spatial; ++k){
            int index = j*filters*spatial + i*spatial + k;
            variance[i] += pow((x[index] - mean[i]), 2);
            variance[i] += powf((x[index] - mean[i]), 2);
        }
    }
    variance[i] *= scale;
@@ -370,7 +384,7 @@
__global__ void pow_kernel(int N, float ALPHA, float *X, int INCX, float *Y, int INCY)
{
    int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    if(i < N) Y[i*INCY] = pow(X[i*INCX], ALPHA);
    if(i < N) Y[i*INCY] = powf(X[i*INCX], ALPHA);
}
__global__ void const_kernel(int N, float ALPHA, float *X, int INCX)
@@ -474,7 +488,7 @@
        for(i = 0; i < spatial; i += threads){
            int index = j*spatial*filters + filter*spatial + i + id;
            local[id] += (i+id < spatial) ? pow((x[index] - mean[filter]), 2) : 0;
            local[id] += (i+id < spatial) ? powf((x[index] - mean[filter]), 2) : 0;
        }
    }
   __syncthreads();
@@ -646,7 +660,7 @@
    if(sample < 1) sample = 1;
    int size = batch * minw * minh * minc;
    shortcut_kernel<<<cuda_gridsize(size), BLOCK>>>(size, minw, minh, minc, stride, sample, batch, w1, h1, c1, add, w2, h2, c2, out);
    shortcut_kernel<<<cuda_gridsize(size), BLOCK, 0, get_cuda_stream()>>>(size, minw, minh, minc, stride, sample, batch, w1, h1, c1, add, w2, h2, c2, out);
    check_error(cudaPeekAtLastError());
}
@@ -769,3 +783,35 @@
    softmax_kernel<<<cuda_gridsize(batch), BLOCK, 0, get_cuda_stream()>>>(inputs, offset, batch, input, temp, output);
    check_error(cudaPeekAtLastError());
}
__global__ void upsample_kernel(size_t N, float *x, int w, int h, int c, int batch, int stride, int forward, float scale, float *out)
{
   size_t i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
   if (i >= N) return;
   int out_index = i;
   int out_w = i % (w*stride);
   i = i / (w*stride);
   int out_h = i % (h*stride);
   i = i / (h*stride);
   int out_c = i%c;
   i = i / c;
   int b = i%batch;
   int in_w = out_w / stride;
   int in_h = out_h / stride;
   int in_c = out_c;
   int in_index = b*w*h*c + in_c*w*h + in_h*w + in_w;
   if (forward) out[out_index] += scale * x[in_index];
   else atomicAdd(x + in_index, scale * out[out_index]);
}
extern "C" void upsample_gpu(float *in, int w, int h, int c, int batch, int stride, int forward, float scale, float *out)
{
   size_t size = w*h*c*batch*stride*stride;
   upsample_kernel << <cuda_gridsize(size), BLOCK >> >(size, in, w, h, c, batch, stride, forward, scale, out);
   check_error(cudaPeekAtLastError());
}