Alexey
2018-02-23 175e47317ee473e7aa15a23221d9954a6952a613
src/convolutional_kernels.cu
@@ -2,6 +2,10 @@
#include "curand.h"
#include "cublas_v2.h"
#ifdef CUDNN
#pragma comment(lib, "cudnn.lib")
#endif
extern "C" {
#include "convolutional_layer.h"
#include "batchnorm_layer.h"
@@ -70,6 +74,40 @@
    check_error(cudaPeekAtLastError());
}
__global__ void cuda_f32_to_f16(float* input_f32, size_t size, half *output_f16)
{
   int idx = blockIdx.x * blockDim.x + threadIdx.x;
   if (idx < size) output_f16[idx] = __float2half(input_f32[idx]);
   //if (idx < size) *((unsigned short *)output_f16 + idx) = __float2half(input_f32[idx]);
}
void cuda_convert_f32_to_f16(float* input_f32, size_t size, half *output_f16) {
   cuda_f32_to_f16 <<< size / BLOCK + 1, BLOCK, 0, get_cuda_stream() >>> (input_f32, size, output_f16);
}
__global__ void cuda_f16_to_f32(half* input_f16, size_t size, float *output_f32)
{
   int idx = blockIdx.x * blockDim.x + threadIdx.x;
   if (idx < size) output_f32[idx] = __half2float(input_f16[idx]);
   //if (idx < size) output_f32[idx] = __half2float(*((unsigned short *)input_f16 + idx));
}
void cuda_convert_f16_to_f32(half* input_f16, size_t size, float *output_f32) {
   cuda_f16_to_f32 <<< size / BLOCK + 1, BLOCK, 0, get_cuda_stream() >>> (input_f16, size, output_f32);
}
half *cuda_make_f16_from_f32_array(float *src, size_t n)
{
   half *dst16;
   size_t size = sizeof(half)*n;
   check_error(cudaMalloc((void **)&dst16, size));
   if (src) {
      cuda_convert_f32_to_f16(src, n, dst16);
   }
   if (!dst16) error("Cuda malloc failed\n");
   return dst16;
}
void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
{
    fill_ongpu(l.outputs*l.batch, 0, l.output_gpu, 1);
@@ -86,9 +124,57 @@
    }
#ifdef CUDNN
    float one = 1;
   //float one = 1;  // alpha[0], beta[0] is float for HALF and FLOAT
   float alpha = 1, beta = 0;
#ifdef CUDNN_HALF
   // Note: For improved performance it is advised to use beta[0] = 0.0.
   // For Tensor Core: cudnnSetConvolutionMathType() where cudnnMathType_t mathType = CUDNN_TENSOR_OP_MATH;
   // 1. or CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM and use CUDNN_DATA_HALF
   // 2. or CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED
   // More: http://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#tensor_ops
   const size_t input16_size = l.batch*l.c*l.w*l.h;
   static size_t max_input16_size = input16_size;
   static half* input16 = cuda_make_f16_from_f32_array(NULL, max_input16_size);
   const size_t output16_size = l.batch*l.out_c*l.out_h*l.out_w;
   static size_t max_output16_size = output16_size;
   static half* output16 = cuda_make_f16_from_f32_array(NULL, max_output16_size);
   if (max_input16_size < input16_size) {
      max_input16_size = input16_size;
      cuda_free((float *)input16);
      input16 = cuda_make_f16_from_f32_array(state.input, max_input16_size);
   }
   if (max_output16_size < output16_size) {
      max_output16_size = output16_size;
      cuda_free((float *)output16);
      output16 = cuda_make_f16_from_f32_array(NULL, max_output16_size);
   }
   cuda_convert_f32_to_f16(state.input, input16_size, input16);
   cudnnConvolutionForward(cudnn_handle(),
      &alpha,
      l.srcTensorDesc,
      input16,
      l.weightDesc,
      l.weights_gpu16,
      l.convDesc,
      l.fw_algo,
      state.workspace,
      l.workspace_size,
      &beta,
      l.dstTensorDesc,
      output16);
   cuda_convert_f16_to_f32(output16, output16_size, l.output_gpu);
#else
    cudnnConvolutionForward(cudnn_handle(),
                &one,
                &alpha,
                l.srcTensorDesc,
                state.input,
                l.weightDesc,
@@ -97,9 +183,11 @@
                l.fw_algo,
                state.workspace,
                l.workspace_size,
                &one,
                &beta,
                l.dstTensorDesc,
                l.output_gpu);
#endif
#else
    int i;
@@ -123,6 +211,7 @@
    activate_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation);
    //if(l.dot > 0) dot_error_gpu(l);
    if(l.binary || l.xnor) swap_binary(&l);
   //cudaDeviceSynchronize(); // for correct profiling of performance
}
void backward_convolutional_layer_gpu(convolutional_layer l, network_state state)
@@ -133,6 +222,9 @@
    if(l.batch_normalize){
        backward_batchnorm_layer_gpu(l, state);
        //axpy_ongpu(l.outputs*l.batch, -state.net.decay, l.x_gpu, 1, l.delta_gpu, 1);
    } else {
        //axpy_ongpu(l.outputs*l.batch, -state.net.decay, l.output_gpu, 1, l.delta_gpu, 1);
    }
    float *original_input = state.input;
@@ -155,6 +247,7 @@
    if(state.delta){
        if(l.binary || l.xnor) swap_binary(&l);
      // http://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#cudnnConvolutionBackwardData
        cudnnConvolutionBackwardData(cudnn_handle(),
                &one,
                l.weightDesc,
@@ -215,11 +308,18 @@
        cuda_pull_array(layer.rolling_mean_gpu, layer.rolling_mean, layer.n);
        cuda_pull_array(layer.rolling_variance_gpu, layer.rolling_variance, layer.n);
    }
    if (layer.adam){
        cuda_pull_array(layer.m_gpu, layer.m, layer.c*layer.n*layer.size*layer.size);
        cuda_pull_array(layer.v_gpu, layer.v, layer.c*layer.n*layer.size*layer.size);
    }
}
void push_convolutional_layer(convolutional_layer layer)
{
    cuda_push_array(layer.weights_gpu, layer.weights, layer.c*layer.n*layer.size*layer.size);
#ifdef CUDNN_HALF
   cuda_convert_f32_to_f16(layer.weights_gpu, layer.c*layer.n*layer.size*layer.size, (half *)layer.weights_gpu16);
#endif
    cuda_push_array(layer.biases_gpu, layer.biases, layer.n);
    cuda_push_array(layer.weight_updates_gpu, layer.weight_updates, layer.c*layer.n*layer.size*layer.size);
    cuda_push_array(layer.bias_updates_gpu, layer.bias_updates, layer.n);
@@ -228,21 +328,40 @@
        cuda_push_array(layer.rolling_mean_gpu, layer.rolling_mean, layer.n);
        cuda_push_array(layer.rolling_variance_gpu, layer.rolling_variance, layer.n);
    }
    if (layer.adam){
        cuda_push_array(layer.m_gpu, layer.m, layer.c*layer.n*layer.size*layer.size);
        cuda_push_array(layer.v_gpu, layer.v, layer.c*layer.n*layer.size*layer.size);
    }
}
void update_convolutional_layer_gpu(convolutional_layer layer, int batch, float learning_rate, float momentum, float decay)
{
    int size = layer.size*layer.size*layer.c*layer.n;
    axpy_ongpu(layer.n, learning_rate/batch, layer.bias_updates_gpu, 1, layer.biases_gpu, 1);
    scal_ongpu(layer.n, momentum, layer.bias_updates_gpu, 1);
    axpy_ongpu(layer.n, learning_rate/batch, layer.scale_updates_gpu, 1, layer.scales_gpu, 1);
    scal_ongpu(layer.n, momentum, layer.scale_updates_gpu, 1);
    if(layer.scales_gpu){
        axpy_ongpu(layer.n, learning_rate/batch, layer.scale_updates_gpu, 1, layer.scales_gpu, 1);
        scal_ongpu(layer.n, momentum, layer.scale_updates_gpu, 1);
    }
    axpy_ongpu(size, -decay*batch, layer.weights_gpu, 1, layer.weight_updates_gpu, 1);
    axpy_ongpu(size, learning_rate/batch, layer.weight_updates_gpu, 1, layer.weights_gpu, 1);
    scal_ongpu(size, momentum, layer.weight_updates_gpu, 1);
    if(layer.adam){
        scal_ongpu(size, layer.B1, layer.m_gpu, 1);
        scal_ongpu(size, layer.B2, layer.v_gpu, 1);
        axpy_ongpu(size, -decay*batch, layer.weights_gpu, 1, layer.weight_updates_gpu, 1);
        axpy_ongpu(size, -(1-layer.B1), layer.weight_updates_gpu, 1, layer.m_gpu, 1);
        mul_ongpu(size, layer.weight_updates_gpu, 1, layer.weight_updates_gpu, 1);
        axpy_ongpu(size, (1-layer.B2), layer.weight_updates_gpu, 1, layer.v_gpu, 1);
        adam_gpu(size, layer.weights_gpu, layer.m_gpu, layer.v_gpu, layer.B1, layer.B2, learning_rate/batch, layer.eps, layer.t+1);
        fill_ongpu(size, 0, layer.weight_updates_gpu, 1);
    }else{
        axpy_ongpu(size, -decay*batch, layer.weights_gpu, 1, layer.weight_updates_gpu, 1);
        axpy_ongpu(size, learning_rate/batch, layer.weight_updates_gpu, 1, layer.weights_gpu, 1);
        scal_ongpu(size, momentum, layer.weight_updates_gpu, 1);
    }
}