AlexeyAB
2018-02-22 dda993f3dd3c753dfd580d485b39c1001830fee4
src/convolutional_kernels.cu
@@ -74,6 +74,38 @@
    check_error(cudaPeekAtLastError());
}
__global__ void cuda_f32_to_f16(float* input_f32, size_t size, half *output_f16)
{
   int idx = blockIdx.x * blockDim.x + threadIdx.x;
   if (idx < size) output_f16[idx] = input_f32[idx];
}
void cuda_convert_f32_to_f16(float* input_f32, size_t size, half *output_f16) {
   cuda_f32_to_f16 <<< size / BLOCK + 1, BLOCK, 0, get_cuda_stream() >>> (input_f32, size, output_f16);
}
__global__ void cuda_f16_to_f32(half* input_f16, size_t size, float *output_f32)
{
   int idx = blockIdx.x * blockDim.x + threadIdx.x;
   if (idx < size) output_f32[idx] = input_f16[idx];
}
void cuda_convert_f16_to_f32(half* input_f16, size_t size, float *output_f32) {
   cuda_f16_to_f32 <<< size / BLOCK + 1, BLOCK, 0, get_cuda_stream() >>> (input_f16, size, output_f32);
}
half *cuda_make_f16_from_f32_array(float *src, size_t n)
{
   half *dst16;
   size_t size = sizeof(half)*n;
   check_error(cudaMalloc((void **)&dst16, size));
   if (src) {
      cuda_convert_f32_to_f16(src, n, dst16);
   }
   if (!dst16) error("Cuda malloc failed\n");
   return dst16;
}
void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
{
    fill_ongpu(l.outputs*l.batch, 0, l.output_gpu, 1);
@@ -90,9 +122,57 @@
    }
#ifdef CUDNN
    float one = 1;
   //float one = 1;  // alpha[0], beta[0] is float for HALF and FLOAT
   float alpha = 1, beta = 0;
#ifdef CUDNN_HALF
   // Note: For improved performance it is advised to use beta[0] = 0.0.
   // For Tensor Core: cudnnSetConvolutionMathType() where cudnnMathType_t mathType = CUDNN_TENSOR_OP_MATH;
   // 1. or CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM and use CUDNN_DATA_HALF
   // 2. or CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED
   // More: http://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#tensor_ops
   const size_t input16_size = l.batch*l.c*l.w*l.h;
   static size_t max_input16_size = input16_size;
   static half* input16 = cuda_make_f16_from_f32_array(NULL, max_input16_size);
   const size_t output16_size = l.batch*l.out_c*l.out_h*l.out_w;
   static size_t max_output16_size = output16_size;
   static half* output16 = cuda_make_f16_from_f32_array(NULL, max_output16_size);
   if (max_input16_size < input16_size) {
      max_input16_size = input16_size;
      cuda_free((float *)input16);
      input16 = cuda_make_f16_from_f32_array(state.input, max_input16_size);
   }
   if (max_output16_size < output16_size) {
      max_output16_size = output16_size;
      cuda_free((float *)output16);
      output16 = cuda_make_f16_from_f32_array(NULL, max_output16_size);
   }
   cuda_convert_f32_to_f16(state.input, input16_size, input16);
   cudnnConvolutionForward(cudnn_handle(),
      &alpha,
      l.srcTensorDesc,
      input16,
      l.weightDesc,
      l.weights_gpu16,
      l.convDesc,
      l.fw_algo,
      state.workspace,
      l.workspace_size,
      &beta,
      l.dstTensorDesc,
      output16);
   cuda_convert_f16_to_f32(output16, output16_size, l.output_gpu);
#else
    cudnnConvolutionForward(cudnn_handle(),
                &one,
                &alpha,
                l.srcTensorDesc,
                state.input,
                l.weightDesc,
@@ -101,9 +181,11 @@
                l.fw_algo,
                state.workspace,
                l.workspace_size,
                &one,
                &beta,
                l.dstTensorDesc,
                l.output_gpu);
#endif
#else
    int i;
@@ -232,6 +314,9 @@
void push_convolutional_layer(convolutional_layer layer)
{
    cuda_push_array(layer.weights_gpu, layer.weights, layer.c*layer.n*layer.size*layer.size);
#ifdef CUDNN_HALF
   cuda_convert_f32_to_f16(layer.weights_gpu, layer.c*layer.n*layer.size*layer.size, (half *)layer.weights_gpu16);
#endif
    cuda_push_array(layer.biases_gpu, layer.biases, layer.n);
    cuda_push_array(layer.weight_updates_gpu, layer.weight_updates, layer.c*layer.n*layer.size*layer.size);
    cuda_push_array(layer.bias_updates_gpu, layer.bias_updates, layer.n);