#include "cuda_runtime.h"
|
#include "curand.h"
|
#include "cublas_v2.h"
|
|
extern "C" {
|
#include "convolutional_layer.h"
|
#include "gemm.h"
|
#include "blas.h"
|
#include "im2col.h"
|
#include "col2im.h"
|
#include "utils.h"
|
#include "cuda.h"
|
}
|
|
/*
 * binarize_filters_kernel: one thread per filter.
 *
 * For filter f (its `size` weights occupy filters[f*size .. f*size+size-1]) this
 * computes the mean absolute weight and writes, for every weight, +mean if the
 * weight is > 0 and -mean otherwise into the matching slot of `binary`.
 *
 * Launch: a possibly 2D grid flattened as
 *   (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
 * threads whose index is >= n exit immediately.
 */
__global__ void binarize_filters_kernel(float *filters, int n, int size, float *binary)
{
    int f = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    if (f >= n) return;

    /* size_t offset so n*size beyond INT_MAX does not overflow. */
    const size_t offset = (size_t)f * (size_t)size;
    const float *w = filters + offset;
    float *out = binary + offset;

    int i;
    float mean = 0;
    for(i = 0; i < size; ++i){
        /* fabsf keeps the float overload explicit; an integer abs would truncate
           fractional weights. */
        mean += fabsf(w[i]);
    }
    mean = mean / size;

    for(i = 0; i < size; ++i){
        out[i] = (w[i] > 0) ? mean : -mean;
    }
}
|
|
/*
 * scale_bias_kernel: output[batch][filter][pos] *= biases[filter].
 *
 * Grid layout: blockIdx.x/threadIdx.x cover the `size` positions of one plane,
 * blockIdx.y selects the filter (0..n-1), blockIdx.z selects the batch item.
 * Threads past the end of the plane do nothing.
 */
__global__ void scale_bias_kernel(float *output, float *biases, int n, int size)
{
    int pos = threadIdx.x + blockDim.x * blockIdx.x;
    if (pos >= size) return;

    int filter = blockIdx.y;
    int item = blockIdx.z;
    int plane = item*n + filter;
    output[plane*size + pos] *= biases[filter];
}
|
|
/*
 * Host launcher for scale_bias_kernel: multiplies every element of each
 * (batch item, filter) plane of `output` by biases[filter]. The grid is
 * (ceil(size/BLOCK), n, batch) with BLOCK threads per block. Launch errors are
 * reported through check_error.
 */
void scale_bias_gpu(float *output, float *biases, int batch, int n, int size)
{
    int chunks = (size-1)/BLOCK + 1;
    dim3 threads(BLOCK, 1, 1);
    dim3 blocks(chunks, n, batch);

    scale_bias_kernel<<<blocks, threads>>>(output, biases, n, size);
    check_error(cudaPeekAtLastError());
}
|
|
/*
 * backward_scale_kernel: adds sum over batch and positions of delta*x_norm for
 * one filter to scale_updates[filter] (accumulates, does not overwrite).
 *
 * Grid: one block per filter (blockIdx.x = filter). blockDim.x must equal BLOCK,
 * because the shared partial-sum array has BLOCK entries and thread 0 reads all
 * of them. Data layout of x_norm and delta: [batch][n][size].
 */
__global__ void backward_scale_kernel(float *x_norm, float *delta, int batch, int n, int size, float *scale_updates)
{
    __shared__ float partial[BLOCK];
    int filter = blockIdx.x;
    int tid = threadIdx.x;

    /* Each thread walks the plane in steps of BLOCK, across every batch item. */
    float acc = 0;
    for (int item = 0; item < batch; ++item) {
        for (int start = 0; start < size; start += BLOCK) {
            int idx = tid + start + size*(filter + n*item);
            acc += (tid+start < size) ? delta[idx]*x_norm[idx] : 0;
        }
    }
    partial[tid] = acc;
    __syncthreads();

    if (tid != 0) return;

    /* Thread 0 folds the per-thread partials in order into the global slot. */
    float total = scale_updates[filter];
    for (int j = 0; j < BLOCK; ++j) total += partial[j];
    scale_updates[filter] = total;
}
|
|
/*
 * Host launcher for binarize_filters_kernel: one thread per filter. It writes the
 * sign-times-mean-magnitude version of the n filters (each `size` weights long)
 * from `filters` into `binary`. Launch errors go through check_error.
 */
void binarize_filters_gpu(float *filters, int n, int size, float *binary)
{
    dim3 grid = cuda_gridsize(n);
    binarize_filters_kernel<<<grid, BLOCK>>>(filters, n, size, binary);
    check_error(cudaPeekAtLastError());
}
|
|
/*
 * Host launcher for backward_scale_kernel: n blocks (one per filter) of BLOCK
 * threads each. For each filter it adds sum(delta * x_norm) to
 * scale_updates[filter].
 */
void backward_scale_gpu(float *x_norm, float *delta, int batch, int n, int size, float *scale_updates)
{
    int blocks = n;
    int threads = BLOCK;
    backward_scale_kernel<<<blocks, threads>>>(x_norm, delta, batch, n, size, scale_updates);
    check_error(cudaPeekAtLastError());
}
|
|
/*
 * add_bias_kernel: output[batch][filter][pos] += biases[filter].
 *
 * Grid layout: blockIdx.x/threadIdx.x cover the `size` positions of one plane,
 * blockIdx.y selects the filter (0..n-1), blockIdx.z selects the batch item.
 * Threads past the end of the plane do nothing.
 */
__global__ void add_bias_kernel(float *output, float *biases, int n, int size)
{
    int pos = threadIdx.x + blockDim.x * blockIdx.x;
    if (pos >= size) return;

    int filter = blockIdx.y;
    int item = blockIdx.z;
    int plane = item*n + filter;
    output[plane*size + pos] += biases[filter];
}
|
|
/*
 * Host launcher for add_bias_kernel: adds biases[filter] to every element of each
 * (batch item, filter) plane of `output`. The grid is (ceil(size/BLOCK), n, batch)
 * with BLOCK threads per block. Launch errors are reported through check_error.
 */
void add_bias_gpu(float *output, float *biases, int batch, int n, int size)
{
    int chunks = (size-1)/BLOCK + 1;
    dim3 threads(BLOCK, 1, 1);
    dim3 blocks(chunks, n, batch);

    add_bias_kernel<<<blocks, threads>>>(output, biases, n, size);
    check_error(cudaPeekAtLastError());
}
|
|
/*
 * backward_bias_kernel: adds the sum of delta over all batch items and positions
 * of one filter to bias_updates[filter] (accumulates, does not overwrite).
 *
 * Grid: one block per filter (blockIdx.x = filter). blockDim.x must equal BLOCK,
 * because the shared partial-sum array has BLOCK entries and thread 0 reads all
 * of them. Data layout of delta: [batch][n][size].
 */
__global__ void backward_bias_kernel(float *bias_updates, float *delta, int batch, int n, int size)
{
    __shared__ float partial[BLOCK];
    int filter = blockIdx.x;
    int tid = threadIdx.x;

    /* Each thread walks the plane in steps of BLOCK, across every batch item. */
    float acc = 0;
    for (int item = 0; item < batch; ++item) {
        for (int start = 0; start < size; start += BLOCK) {
            int idx = tid + start + size*(filter + n*item);
            acc += (tid+start < size) ? delta[idx] : 0;
        }
    }
    partial[tid] = acc;
    __syncthreads();

    if (tid != 0) return;

    /* Thread 0 folds the per-thread partials in order into the global slot. */
    float total = bias_updates[filter];
    for (int j = 0; j < BLOCK; ++j) total += partial[j];
    bias_updates[filter] = total;
}
|
|
/*
 * dot_kernel: similarity penalty between pairs of filter responses.
 *
 * Thread index = f1*n + f2. Only pairs with f2 > f1 do work, so each unordered pair
 * of distinct filters is handled once. For that pair, over all batch items and
 * positions (layout [batch][n][size]):
 *   norm = |o1| * |o2|,  sum = <o1, o2> / norm   (cosine similarity)
 * and then, element-wise,
 *   delta[f1] += -scale * sum * o2 / norm
 *   delta[f2] += -scale * sum * o1 / norm
 *
 * Launch: a possibly 2D grid flattened as
 *   (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x
 * with at least n*n threads. Requires float atomicAdd (SM 2.0+).
 */
__global__ void dot_kernel(float *output, float scale, int batch, int n, int size, float *delta)
{
    int index = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    /* The grid is rounded up to whole blocks; drop threads past the last pair.
       This also avoids dividing by n when n == 0. */
    if (index >= n*n) return;

    int f1 = index / n;
    int f2 = index % n;
    /* Each unordered pair once; skip f1 == f2. */
    if (f2 <= f1) return;

    float sum = 0;
    float norm1 = 0;
    float norm2 = 0;
    int b, i;
    for(b = 0; b < batch; ++b){
        for(i = 0; i < size; ++i){
            int i1 = b * size * n + f1 * size + i;
            int i2 = b * size * n + f2 * size + i;
            sum += output[i1] * output[i2];
            norm1 += output[i1] * output[i1];
            norm2 += output[i2] * output[i2];
        }
    }

    float norm = sqrtf(norm1) * sqrtf(norm2);
    /* An all-zero response makes the similarity 0/0; skip the pair rather than
       writing NaN into delta. */
    if (norm == 0) return;
    sum = sum / norm;

    for(b = 0; b < batch; ++b){
        for(i = 0; i < size; ++i){
            int i1 = b * size * n + f1 * size + i;
            int i2 = b * size * n + f2 * size + i;
            /* Every pair containing f1 (or f2) updates the same delta plane from a
               different thread, so these read-modify-writes must be atomic. */
            atomicAdd(&delta[i1], - scale * sum * output[i2] / norm);
            atomicAdd(&delta[i2], - scale * sum * output[i1] / norm);
        }
    }
}
|
|
/*
 * Host launcher for dot_kernel on layer l. It launches cuda_gridsize(l.n*l.n)
 * blocks of BLOCK threads (one thread per ordered filter pair) with scale l.dot.
 * The kernel reads l.output_gpu and adjusts l.delta_gpu in place.
 */
void dot_error_gpu(layer l)
{
    int spatial = l.out_w * l.out_h;
    int pairs = l.n*l.n;
    dot_kernel<<<cuda_gridsize(pairs), BLOCK>>>(l.output_gpu, l.dot, l.batch, l.n, spatial, l.delta_gpu);
    check_error(cudaPeekAtLastError());
}
|
|
/*
 * Host launcher for backward_bias_kernel: n blocks (one per filter) of BLOCK
 * threads each. For each filter it adds the sum of its delta values to
 * bias_updates[filter].
 */
void backward_bias_gpu(float *bias_updates, float *delta, int batch, int n, int size)
{
    int blocks = n;
    int threads = BLOCK;
    backward_bias_kernel<<<blocks, threads>>>(bias_updates, delta, batch, n, size);
    check_error(cudaPeekAtLastError());
}
|
|
/*
 * Forward pass of a convolutional layer on the GPU.
 *
 * Local sizes: m = l.n (filters), k = size*size*c (weights per filter),
 * n = out_h*out_w (output positions per filter).
 *
 * Steps, in order:
 *  1. Zero l.output_gpu (l.outputs*l.batch floats).
 *  2. If l.binary: binarize l.filters_gpu into l.binary_filters_gpu, then
 *     swap_binary(&l). NOTE(review): presumably this exchanges the filter and
 *     binary-filter pointers so the GEMM below uses the binarized weights.
 *     Confirm in swap_binary.
 *  3. For each batch item: im2col_ongpu fills l.col_image_gpu from that input
 *     image, then gemm_ongpu with alpha = 1 and beta = 1 adds filters (m x k) times
 *     col_image (k x n) into that item's slice of l.output_gpu. This assumes
 *     gemm_ongpu computes C = alpha*A*B + beta*C.
 *  4. If l.batch_normalize:
 *     - training: compute per-filter batch mean and variance, update the rolling
 *       statistics as rolling = 0.95*rolling + 0.05*batch_stat, save the raw
 *       output in l.x_gpu, normalize, and save the normalized output in
 *       l.x_norm_gpu for the backward pass;
 *     - inference: normalize with the rolling statistics.
 *     In both cases, then multiply each filter plane by l.scales_gpu.
 *  5. Add per-filter biases and apply the activation. If l.dot > 0, run the
 *     filter-similarity penalty (dot_error_gpu), which adjusts l.delta_gpu.
 *     Finally undo the binary swap.
 *
 * l is passed by value, so the swaps act on this local copy of the layer struct.
 */
void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
{
    int i;
    int m = l.n;
    int k = l.size*l.size*l.c;
    int n = convolutional_out_height(l)*
        convolutional_out_width(l);

    /* The GEMM below accumulates (beta = 1), so the output must start at zero. */
    fill_ongpu(l.outputs*l.batch, 0, l.output_gpu, 1);
    if(l.binary){
        binarize_filters_gpu(l.filters_gpu, l.n, l.c*l.size*l.size, l.binary_filters_gpu);
        swap_binary(&l);
    }

    for(i = 0; i < l.batch; ++i){
        im2col_ongpu(state.input + i*l.c*l.h*l.w, l.c, l.h, l.w, l.size, l.stride, l.pad, l.col_image_gpu);
        float * a = l.filters_gpu;
        float * b = l.col_image_gpu;
        float * c = l.output_gpu;
        /* Output slice for batch item i starts at c + i*m*n. */
        gemm_ongpu(0,0,m,n,k,1.,a,k,b,n,1.,c+i*m*n,n);
    }

    if (l.batch_normalize) {
        if (state.train) {
            fast_mean_gpu(l.output_gpu, l.batch, l.n, l.out_h*l.out_w, l.mean_gpu);
            fast_variance_gpu(l.output_gpu, l.mean_gpu, l.batch, l.n, l.out_h*l.out_w, l.variance_gpu);

            /* Exponential moving average of the statistics, used at inference time. */
            scal_ongpu(l.n, .95, l.rolling_mean_gpu, 1);
            axpy_ongpu(l.n, .05, l.mean_gpu, 1, l.rolling_mean_gpu, 1);
            scal_ongpu(l.n, .95, l.rolling_variance_gpu, 1);
            axpy_ongpu(l.n, .05, l.variance_gpu, 1, l.rolling_variance_gpu, 1);

            /* Keep pre- and post-normalization activations for the backward pass. */
            copy_ongpu(l.outputs*l.batch, l.output_gpu, 1, l.x_gpu, 1);
            normalize_gpu(l.output_gpu, l.mean_gpu, l.variance_gpu, l.batch, l.n, l.out_h*l.out_w);
            copy_ongpu(l.outputs*l.batch, l.output_gpu, 1, l.x_norm_gpu, 1);
        } else {
            normalize_gpu(l.output_gpu, l.rolling_mean_gpu, l.rolling_variance_gpu, l.batch, l.n, l.out_h*l.out_w);
        }

        scale_bias_gpu(l.output_gpu, l.scales_gpu, l.batch, l.n, l.out_h*l.out_w);
    }
    add_bias_gpu(l.output_gpu, l.biases_gpu, l.batch, l.n, n);

    activate_array_ongpu(l.output_gpu, m*n*l.batch, l.activation);
    if(l.dot > 0) dot_error_gpu(l);
    if(l.binary) swap_binary(&l);
}
|
|
/*
 * Backward pass of a convolutional layer on the GPU.
 *
 * Local sizes here are named differently from the forward pass:
 * m = l.n (filters), n = size*size*c (weights per filter),
 * k = out_h*out_w (output positions per filter).
 *
 * Steps, in order:
 *  1. gradient_array_ongpu combines l.delta_gpu with the activation derivative at
 *     l.output_gpu. It presumably multiplies delta element-wise; confirm in its
 *     implementation.
 *  2. Add the per-filter sum of delta into l.bias_updates_gpu.
 *  3. If l.batch_normalize: add the sum of delta*x_norm into l.scale_updates_gpu,
 *     multiply delta by l.scales_gpu, then compute mean/variance deltas and call
 *     normalize_delta_gpu. That call presumably rewrites l.delta_gpu as the
 *     gradient with respect to the pre-normalization output.
 *  4. For each batch item i:
 *     - rebuild l.col_image_gpu from the input with im2col_ongpu;
 *     - gemm_ongpu with beta = 1 accumulates delta_i (m x k) times col^T into
 *       l.filter_updates_gpu;
 *     - if state.delta is non-NULL, gemm_ongpu with beta = 0 writes
 *       filters^T times delta_i into l.col_image_gpu, and col2im_ongpu maps it back
 *       into state.delta + i*c*h*w. When l.binary, swap_binary brackets this step
 *       so the binarized filters are used.
 *
 * l is passed by value, so the swaps act on this local copy of the layer struct.
 */
void backward_convolutional_layer_gpu(convolutional_layer l, network_state state)
{
    int i;
    int m = l.n;
    int n = l.size*l.size*l.c;
    int k = convolutional_out_height(l)*
        convolutional_out_width(l);

    gradient_array_ongpu(l.output_gpu, m*k*l.batch, l.activation, l.delta_gpu);

    backward_bias_gpu(l.bias_updates_gpu, l.delta_gpu, l.batch, l.n, k);

    if(l.batch_normalize){
        /* Scale gradient uses delta before it is multiplied by the scales. */
        backward_scale_gpu(l.x_norm_gpu, l.delta_gpu, l.batch, l.n, l.out_w*l.out_h, l.scale_updates_gpu);

        scale_bias_gpu(l.delta_gpu, l.scales_gpu, l.batch, l.n, l.out_h*l.out_w);

        fast_mean_delta_gpu(l.delta_gpu, l.variance_gpu, l.batch, l.n, l.out_w*l.out_h, l.mean_delta_gpu);
        fast_variance_delta_gpu(l.x_gpu, l.delta_gpu, l.mean_gpu, l.variance_gpu, l.batch, l.n, l.out_w*l.out_h, l.variance_delta_gpu);
        normalize_delta_gpu(l.x_gpu, l.mean_gpu, l.variance_gpu, l.mean_delta_gpu, l.variance_delta_gpu, l.batch, l.n, l.out_w*l.out_h, l.delta_gpu);
    }

    for(i = 0; i < l.batch; ++i){
        float * a = l.delta_gpu;
        float * b = l.col_image_gpu;
        float * c = l.filter_updates_gpu;

        im2col_ongpu(state.input + i*l.c*l.h*l.w, l.c, l.h, l.w, l.size, l.stride, l.pad, l.col_image_gpu);
        /* Weight gradient for item i, accumulated across the batch (beta = 1). */
        gemm_ongpu(0,1,m,n,k,1,a + i*m*k,k,b,k,1,c,n);

        if(state.delta){
            if(l.binary) swap_binary(&l);
            float * a = l.filters_gpu;
            float * b = l.delta_gpu;
            float * c = l.col_image_gpu;

            /* Input gradient in column form; overwrites col_image (beta = 0). */
            gemm_ongpu(1,0,n,k,m,1,a,n,b + i*k*m,k,0,c,k);

            col2im_ongpu(l.col_image_gpu, l.c, l.h, l.w, l.size, l.stride, l.pad, state.delta + i*l.c*l.h*l.w);
            if(l.binary) swap_binary(&l);
        }
    }
}
|
|
/*
 * Copies a convolutional layer's state from device to host arrays:
 *  - filters and filter_updates (c*n*size*size floats each),
 *  - biases and bias_updates (n floats each),
 *  - and, for batch-normalized layers, scales, rolling_mean and rolling_variance
 *    (n floats each).
 */
void pull_convolutional_layer(convolutional_layer layer)
{
    int weight_count = layer.c*layer.n*layer.size*layer.size;

    cuda_pull_array(layer.filters_gpu, layer.filters, weight_count);
    cuda_pull_array(layer.biases_gpu, layer.biases, layer.n);
    cuda_pull_array(layer.filter_updates_gpu, layer.filter_updates, weight_count);
    cuda_pull_array(layer.bias_updates_gpu, layer.bias_updates, layer.n);

    if (!layer.batch_normalize) return;

    cuda_pull_array(layer.scales_gpu, layer.scales, layer.n);
    cuda_pull_array(layer.rolling_mean_gpu, layer.rolling_mean, layer.n);
    cuda_pull_array(layer.rolling_variance_gpu, layer.rolling_variance, layer.n);
}
|
|
/*
 * Copies a convolutional layer's state from host arrays to the device:
 *  - filters and filter_updates (c*n*size*size floats each),
 *  - biases and bias_updates (n floats each),
 *  - and, for batch-normalized layers, scales, rolling_mean and rolling_variance
 *    (n floats each).
 */
void push_convolutional_layer(convolutional_layer layer)
{
    int weight_count = layer.c*layer.n*layer.size*layer.size;

    cuda_push_array(layer.filters_gpu, layer.filters, weight_count);
    cuda_push_array(layer.biases_gpu, layer.biases, layer.n);
    cuda_push_array(layer.filter_updates_gpu, layer.filter_updates, weight_count);
    cuda_push_array(layer.bias_updates_gpu, layer.bias_updates, layer.n);

    if (!layer.batch_normalize) return;

    cuda_push_array(layer.scales_gpu, layer.scales, layer.n);
    cuda_push_array(layer.rolling_mean_gpu, layer.rolling_mean, layer.n);
    cuda_push_array(layer.rolling_variance_gpu, layer.rolling_variance, layer.n);
}
|
|
/*
 * One SGD-with-momentum step for a convolutional layer on the GPU.
 *
 * Biases (and, for batch-normalized layers, scales):
 *   param  += (learning_rate/batch) * update
 *   update *= momentum
 * Filters, with L2 weight decay folded into the update first:
 *   update  += (-decay*batch) * filters
 *   filters += (learning_rate/batch) * update
 *   update  *= momentum
 *
 * layer: the layer whose device buffers are updated in place (the struct itself
 *        is passed by value; only pointed-to device memory changes).
 * batch: batch size used to normalize accumulated gradients.
 */
void update_convolutional_layer_gpu(convolutional_layer layer, int batch, float learning_rate, float momentum, float decay)
{
    int size = layer.size*layer.size*layer.c*layer.n;

    axpy_ongpu(layer.n, learning_rate/batch, layer.bias_updates_gpu, 1, layer.biases_gpu, 1);
    scal_ongpu(layer.n, momentum, layer.bias_updates_gpu, 1);

    /* Scales are only used, accumulated, pushed and pulled for batch-normalized
       layers. For other layers there is nothing to update, and their scale
       buffers need not exist. */
    if (layer.batch_normalize) {
        axpy_ongpu(layer.n, learning_rate/batch, layer.scale_updates_gpu, 1, layer.scales_gpu, 1);
        scal_ongpu(layer.n, momentum, layer.scale_updates_gpu, 1);
    }

    axpy_ongpu(size, -decay*batch, layer.filters_gpu, 1, layer.filter_updates_gpu, 1);
    axpy_ongpu(size, learning_rate/batch, layer.filter_updates_gpu, 1, layer.filters_gpu, 1);
    scal_ongpu(size, momentum, layer.filter_updates_gpu, 1);
}
|