From cad4d1618fee74471d335314cb77070fee951a42 Mon Sep 17 00:00:00 2001
From: AlexeyAB <alexeyab84@gmail.com>
Date: Sun, 25 Feb 2018 13:29:44 +0000
Subject: [PATCH] Added support for Tensor Cores CC >= 7.0 (V100). For FP16/32 (mixed precision) define CUDNN_HALF should be used.

---
 src/convolutional_kernels.cu |  122 ++++++++++++++++++++++++++++++++++++----
 1 files changed, 108 insertions(+), 14 deletions(-)

diff --git a/src/convolutional_kernels.cu b/src/convolutional_kernels.cu
index 3b2a349..9d88a88 100644
--- a/src/convolutional_kernels.cu
+++ b/src/convolutional_kernels.cu
@@ -81,8 +81,8 @@
 	//if (idx < size) *((unsigned short *)output_f16 + idx) = __float2half(input_f32[idx]);
 }
 
-void cuda_convert_f32_to_f16(float* input_f32, size_t size, half *output_f16) {
-	cuda_f32_to_f16 <<< size / BLOCK + 1, BLOCK, 0, get_cuda_stream() >>> (input_f32, size, output_f16);
+void cuda_convert_f32_to_f16(float* input_f32, size_t size, float *output_f16) {
+	cuda_f32_to_f16 <<< size / BLOCK + 1, BLOCK, 0, get_cuda_stream() >>> (input_f32, size, (half *)output_f16);
 }
 
 __global__ void cuda_f16_to_f32(half* input_f16, size_t size, float *output_f32)
@@ -92,8 +92,8 @@
 	//if (idx < size) output_f32[idx] = __half2float(*((unsigned short *)input_f16 + idx));
 }
 
-void cuda_convert_f16_to_f32(half* input_f16, size_t size, float *output_f32) {
-	cuda_f16_to_f32 <<< size / BLOCK + 1, BLOCK, 0, get_cuda_stream() >>> (input_f16, size, output_f32);
+void cuda_convert_f16_to_f32(float* input_f16, size_t size, float *output_f32) {
+	cuda_f16_to_f32 <<< size / BLOCK + 1, BLOCK, 0, get_cuda_stream() >>> ((half *)input_f16, size, output_f32);
 }
 
 half *cuda_make_f16_from_f32_array(float *src, size_t n)
@@ -102,7 +102,7 @@
 	size_t size = sizeof(half)*n;
 	check_error(cudaMalloc((void **)&dst16, size));
 	if (src) {
-		cuda_convert_f32_to_f16(src, n, dst16);
+		cuda_convert_f32_to_f16(src, n, (float *)dst16);
 	}
 	if (!dst16) error("Cuda malloc failed\n");
 	return dst16;
@@ -124,8 +124,8 @@
     }
 
 #ifdef CUDNN
-	//float one = 1;	// alpha[0], beta[0] is float for HALF and FLOAT
-	float alpha = 1, beta = 0;
+	float one = 1;	// alpha[0], beta[0] is float for HALF and FLOAT
+	float alpha = 1, beta = 0; 
 
 #ifdef CUDNN_HALF
 	// Note: For improved performance it is advised to use beta[0] = 0.0. 
@@ -154,8 +154,9 @@
 		output16 = cuda_make_f16_from_f32_array(NULL, max_output16_size);
 	}
 
-	cuda_convert_f32_to_f16(state.input, input16_size, input16);
+	cuda_convert_f32_to_f16(state.input, input16_size, (float *)input16);
 
+	//fill_ongpu(output16_size / 2, 0, (float *)output16, 1);
 	cudnnConvolutionForward(cudnn_handle(),
 		&alpha,
 		l.srcTensorDesc,
@@ -170,11 +171,12 @@
 		l.dstTensorDesc,
 		output16);
 	
-	cuda_convert_f16_to_f32(output16, output16_size, l.output_gpu);
+	cuda_convert_f16_to_f32((float *)output16, output16_size, l.output_gpu);
+
 #else
 
     cudnnConvolutionForward(cudnn_handle(),
-                &alpha,
+                &one,
                 l.srcTensorDesc,
                 state.input,
                 l.weightDesc,
@@ -183,7 +185,7 @@
                 l.fw_algo,
                 state.workspace,
                 l.workspace_size,
-                &beta,
+                &one,
                 l.dstTensorDesc,
                 l.output_gpu);
 #endif
@@ -230,7 +232,88 @@
 
     if(l.xnor) state.input = l.binary_input_gpu;
 #ifdef CUDNN
-    float one = 1;
+	float one = 1;
+	float alpha = 1, beta = 0;
+
+#ifdef CUDNN_HALF
+		
+	const size_t input16_size = l.batch*l.c*l.w*l.h;
+	static size_t max_input16_size = input16_size;
+	static half* input16 = cuda_make_f16_from_f32_array(NULL, max_input16_size);
+
+	const size_t delta16_size = l.batch*l.n*l.out_w*l.out_h;
+	static size_t max_delta16_size = delta16_size;
+	static half* delta16 = cuda_make_f16_from_f32_array(NULL, max_delta16_size);
+
+	if (max_input16_size < input16_size) {
+		max_input16_size = input16_size;
+		cuda_free((float *)input16);
+		input16 = cuda_make_f16_from_f32_array(state.input, max_input16_size);
+	}
+
+	if (max_delta16_size < delta16_size) {
+		max_delta16_size = delta16_size;
+		cuda_free((float *)delta16);
+		delta16 = cuda_make_f16_from_f32_array(NULL, max_delta16_size);
+	}
+
+	cuda_convert_f32_to_f16(state.input, input16_size, (float *)input16);
+	cuda_convert_f32_to_f16(l.delta_gpu, delta16_size, (float *)delta16);
+	
+	// convert input: state.input (x), l.delta_gpu (y) from fp32 to fp16
+	// get output: l.weight_updates_gpu (dw) and convert it to fp32 (ONLY if it is fp16)
+
+	// calculate conv weight updates
+	// Already: l.weight_updates_gpu = (l.weight_updates_gpu - l.weight*decay*batch*subdivision)*momentum
+	//   so we should copy f32 to f16, or compute: f16=(w_up - w*d*b*s)*m
+	cuda_convert_f32_to_f16(l.weight_updates_gpu, l.c*l.n*l.size*l.size, l.weight_updates_gpu16);
+
+	cudnnConvolutionBackwardFilter(cudnn_handle(),
+		&one,
+		l.srcTensorDesc,
+		input16, //state.input,
+		l.ddstTensorDesc,
+		delta16, //l.delta_gpu,
+		l.convDesc,
+		l.bf_algo,
+		state.workspace,
+		l.workspace_size,
+		&one,
+		l.dweightDesc,
+		l.weight_updates_gpu16);	// l.weight_updates_gpu);
+
+	cuda_convert_f16_to_f32(l.weight_updates_gpu16, l.c*l.n*l.size*l.size, l.weight_updates_gpu);
+
+	if (state.delta) {
+		if (l.binary || l.xnor) swap_binary(&l);
+
+		// http://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#cudnnConvolutionBackwardData
+		// calculate delta for the next layer
+		// convert input: l.weights_gpu (w), l.delta_gpu (dy) from fp32 to fp16
+		// get output: state.delta (dx) and convert it to fp32 (ONLY if it is fp16)	
+		cudnnConvolutionBackwardData(cudnn_handle(),
+			&alpha,
+			l.weightDesc,
+			l.weights_gpu16, //l.weights_gpu,
+			l.ddstTensorDesc,
+			delta16, //l.delta_gpu,
+			l.convDesc,
+			l.bd_algo,
+			state.workspace,
+			l.workspace_size,
+			&beta,
+			l.dsrcTensorDesc,
+			input16);	// state.delta);
+
+		cuda_convert_f16_to_f32((float *)input16, input16_size, state.delta);		
+
+		if (l.binary || l.xnor) swap_binary(&l);
+		if (l.xnor) gradient_array_ongpu(original_input, l.batch*l.c*l.h*l.w, HARDTAN, state.delta);
+	}
+#else	// CUDNN_HALF
+
+	// calculate conv weight updates
+	// if used: beta=1 then loss decreases faster
     cudnnConvolutionBackwardFilter(cudnn_handle(),
             &one,
             l.srcTensorDesc,
@@ -248,6 +331,7 @@
     if(state.delta){
         if(l.binary || l.xnor) swap_binary(&l);
 		// http://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#cudnnConvolutionBackwardData
+		// calculate delta for the next layer
         cudnnConvolutionBackwardData(cudnn_handle(),
                 &one,
                 l.weightDesc,
@@ -265,7 +349,9 @@
         if(l.xnor) gradient_array_ongpu(original_input, l.batch*l.c*l.h*l.w, HARDTAN, state.delta);
     }
 
-#else
+#endif	// CUDNN_HALF
+
+#else	// CUDNN
     int m = l.n;
     int n = l.size*l.size*l.c;
     int k = l.out_w*l.out_h;
@@ -318,7 +404,7 @@
 {
     cuda_push_array(layer.weights_gpu, layer.weights, layer.c*layer.n*layer.size*layer.size);
 #ifdef CUDNN_HALF
-	cuda_convert_f32_to_f16(layer.weights_gpu, layer.c*layer.n*layer.size*layer.size, (half *)layer.weights_gpu16);
+	cuda_convert_f32_to_f16(layer.weights_gpu, layer.c*layer.n*layer.size*layer.size, layer.weights_gpu16);
 #endif
     cuda_push_array(layer.biases_gpu, layer.biases, layer.n);
     cuda_push_array(layer.weight_updates_gpu, layer.weight_updates, layer.c*layer.n*layer.size*layer.size);
@@ -358,6 +444,14 @@
         adam_gpu(size, layer.weights_gpu, layer.m_gpu, layer.v_gpu, layer.B1, layer.B2, learning_rate/batch, layer.eps, layer.t+1);
         fill_ongpu(size, 0, layer.weight_updates_gpu, 1);
     }else{
+		// update weights:
+		// weights_gpu = weights_gpu*(1 - decay*lr) + weight_updates_gpu*lr / (batch*subdivision) =
+		//  weights_gpu*(1 - 0.0005*0.001) + weight_updates_gpu*0.001/(64*8) = 
+		//  weights_gpu * 0.999 999 5 + weight_updates_gpu * 0.000 001 953125
+		// 
+		// weight_updates_gpu = (weight_updates_gpu - weights_gpu*decay*batch*subdivision)*momentum = 
+		//  (weight_updates_gpu - weights_gpu * 0.0005 * 64 * 8) * 0.9 = 
+		//  weight_updates_gpu*0.9 - weights_gpu*0.2304
         axpy_ongpu(size, -decay*batch, layer.weights_gpu, 1, layer.weight_updates_gpu, 1);
         axpy_ongpu(size, learning_rate/batch, layer.weight_updates_gpu, 1, layer.weights_gpu, 1);
         scal_ongpu(size, momentum, layer.weight_updates_gpu, 1);

--
Gitblit v1.10.0