From 7e9e289b164e9bbfb77f644b1cbfac48bc8c8408 Mon Sep 17 00:00:00 2001
From: Edmond Yoo <hj3yoo@uwaterloo.ca>
Date: Tue, 11 Sep 2018 00:51:08 +0000
Subject: [PATCH] training w/ full yolo cfg

---
 src/batchnorm_layer.c |  130 +++++++++++++++++++++++++++++++++++++-----
 1 files changed, 113 insertions(+), 17 deletions(-)

diff --git a/src/batchnorm_layer.c b/src/batchnorm_layer.c
index 6ea4040..3fa129d 100644
--- a/src/batchnorm_layer.c
+++ b/src/batchnorm_layer.c
@@ -28,7 +28,13 @@
 
     layer.rolling_mean = calloc(c, sizeof(float));
     layer.rolling_variance = calloc(c, sizeof(float));
+
+    layer.forward = forward_batchnorm_layer;
+    layer.backward = backward_batchnorm_layer;
 #ifdef GPU
+    layer.forward_gpu = forward_batchnorm_layer_gpu;
+    layer.backward_gpu = backward_batchnorm_layer_gpu;
+
     layer.output_gpu =  cuda_make_array(layer.output, h * w * c * batch);
     layer.delta_gpu =   cuda_make_array(layer.delta, h * w * c * batch);
 
@@ -46,6 +52,12 @@
 
     layer.x_gpu = cuda_make_array(layer.output, layer.batch*layer.outputs);
     layer.x_norm_gpu = cuda_make_array(layer.output, layer.batch*layer.outputs);
+#ifdef CUDNN
+    cudnnCreateTensorDescriptor(&layer.normTensorDesc);
+    cudnnCreateTensorDescriptor(&layer.normDstTensorDesc);
+    cudnnSetTensor4dDescriptor(layer.normDstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, layer.batch, layer.out_c, layer.out_h, layer.out_w);
+    cudnnSetTensor4dDescriptor(layer.normTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, layer.out_c, 1, 1);
+#endif
 #endif
     return layer;
 }
@@ -121,48 +133,131 @@
         l.out_h = l.out_w = 1;
     }
     if(state.train){
-        mean_cpu(l.output, l.batch, l.out_c, l.out_h*l.out_w, l.mean);   
-        variance_cpu(l.output, l.mean, l.batch, l.out_c, l.out_h*l.out_w, l.variance);   
+        mean_cpu(l.output, l.batch, l.out_c, l.out_h*l.out_w, l.mean);
+        variance_cpu(l.output, l.mean, l.batch, l.out_c, l.out_h*l.out_w, l.variance);
+
+        scal_cpu(l.out_c, .9, l.rolling_mean, 1);
+        axpy_cpu(l.out_c, .1, l.mean, 1, l.rolling_mean, 1);
+        scal_cpu(l.out_c, .9, l.rolling_variance, 1);
+        axpy_cpu(l.out_c, .1, l.variance, 1, l.rolling_variance, 1);
+
+        copy_cpu(l.outputs*l.batch, l.output, 1, l.x, 1);
         normalize_cpu(l.output, l.mean, l.variance, l.batch, l.out_c, l.out_h*l.out_w);   
+        copy_cpu(l.outputs*l.batch, l.output, 1, l.x_norm, 1);
     } else {
         normalize_cpu(l.output, l.rolling_mean, l.rolling_variance, l.batch, l.out_c, l.out_h*l.out_w);
     }
     scale_bias(l.output, l.scales, l.batch, l.out_c, l.out_h*l.out_w);
 }
 
-void backward_batchnorm_layer(const layer layer, network_state state)
+void backward_batchnorm_layer(const layer l, network_state state)
 {
+    backward_scale_cpu(l.x_norm, l.delta, l.batch, l.out_c, l.out_w*l.out_h, l.scale_updates);
+
+    scale_bias(l.delta, l.scales, l.batch, l.out_c, l.out_h*l.out_w);
+
+    mean_delta_cpu(l.delta, l.variance, l.batch, l.out_c, l.out_w*l.out_h, l.mean_delta);
+    variance_delta_cpu(l.x, l.delta, l.mean, l.variance, l.batch, l.out_c, l.out_w*l.out_h, l.variance_delta);
+    normalize_delta_cpu(l.x, l.mean, l.variance, l.mean_delta, l.variance_delta, l.batch, l.out_c, l.out_w*l.out_h, l.delta);
+    if(l.type == BATCHNORM) copy_cpu(l.outputs*l.batch, l.delta, 1, state.delta, 1);
 }
 
 #ifdef GPU
+
+void pull_batchnorm_layer(layer l)
+{
+    cuda_pull_array(l.scales_gpu, l.scales, l.c);
+    cuda_pull_array(l.rolling_mean_gpu, l.rolling_mean, l.c);
+    cuda_pull_array(l.rolling_variance_gpu, l.rolling_variance, l.c);
+}
+void push_batchnorm_layer(layer l)
+{
+    cuda_push_array(l.scales_gpu, l.scales, l.c);
+    cuda_push_array(l.rolling_mean_gpu, l.rolling_mean, l.c);
+    cuda_push_array(l.rolling_variance_gpu, l.rolling_variance, l.c);
+}
+
 void forward_batchnorm_layer_gpu(layer l, network_state state)
 {
-    if(l.type == BATCHNORM) copy_ongpu(l.outputs*l.batch, state.input, 1, l.output_gpu, 1);
-    if(l.type == CONNECTED){
-        l.out_c = l.outputs;
-        l.out_h = l.out_w = 1;
-    }
+    if (l.type == BATCHNORM) copy_ongpu(l.outputs*l.batch, state.input, 1, l.output_gpu, 1);
+    copy_ongpu(l.outputs*l.batch, l.output_gpu, 1, l.x_gpu, 1);
     if (state.train) {
+#ifdef CUDNN
+        float one = 1;
+        float zero = 0;
+        cudnnBatchNormalizationForwardTraining(cudnn_handle(),
+            CUDNN_BATCHNORM_SPATIAL,
+            &one,
+            &zero,
+            l.normDstTensorDesc,
+            l.x_gpu,                // input
+            l.normDstTensorDesc,
+            l.output_gpu,            // output
+            l.normTensorDesc,
+            l.scales_gpu,
+            l.biases_gpu,
+            .01,
+            l.rolling_mean_gpu,        // output (should be FP32)
+            l.rolling_variance_gpu,    // output (should be FP32)
+            .00001,
+            l.mean_gpu,            // output (should be FP32)
+            l.variance_gpu);    // output (should be FP32)
+#else
         fast_mean_gpu(l.output_gpu, l.batch, l.out_c, l.out_h*l.out_w, l.mean_gpu);
         fast_variance_gpu(l.output_gpu, l.mean_gpu, l.batch, l.out_c, l.out_h*l.out_w, l.variance_gpu);
 
-        scal_ongpu(l.out_c, .95, l.rolling_mean_gpu, 1);
-        axpy_ongpu(l.out_c, .05, l.mean_gpu, 1, l.rolling_mean_gpu, 1);
-        scal_ongpu(l.out_c, .95, l.rolling_variance_gpu, 1);
-        axpy_ongpu(l.out_c, .05, l.variance_gpu, 1, l.rolling_variance_gpu, 1);
+        scal_ongpu(l.out_c, .99, l.rolling_mean_gpu, 1);
+        axpy_ongpu(l.out_c, .01, l.mean_gpu, 1, l.rolling_mean_gpu, 1);
+        scal_ongpu(l.out_c, .99, l.rolling_variance_gpu, 1);
+        axpy_ongpu(l.out_c, .01, l.variance_gpu, 1, l.rolling_variance_gpu, 1);
 
         copy_ongpu(l.outputs*l.batch, l.output_gpu, 1, l.x_gpu, 1);
         normalize_gpu(l.output_gpu, l.mean_gpu, l.variance_gpu, l.batch, l.out_c, l.out_h*l.out_w);
         copy_ongpu(l.outputs*l.batch, l.output_gpu, 1, l.x_norm_gpu, 1);
-    } else {
+
+        scale_bias_gpu(l.output_gpu, l.scales_gpu, l.batch, l.out_c, l.out_h*l.out_w);
+        add_bias_gpu(l.output_gpu, l.biases_gpu, l.batch, l.out_c, l.out_w*l.out_h);
+#endif
+    }
+    else {
         normalize_gpu(l.output_gpu, l.rolling_mean_gpu, l.rolling_variance_gpu, l.batch, l.out_c, l.out_h*l.out_w);
+        scale_bias_gpu(l.output_gpu, l.scales_gpu, l.batch, l.out_c, l.out_h*l.out_w);
+        add_bias_gpu(l.output_gpu, l.biases_gpu, l.batch, l.out_c, l.out_w*l.out_h);
     }
 
-    scale_bias_gpu(l.output_gpu, l.scales_gpu, l.batch, l.out_c, l.out_h*l.out_w);
 }
 
-void backward_batchnorm_layer_gpu(const layer l, network_state state)
+void backward_batchnorm_layer_gpu(layer l, network_state state)
 {
+    if (!state.train) {
+        l.mean_gpu = l.rolling_mean_gpu;
+        l.variance_gpu = l.rolling_variance_gpu;
+    }
+#ifdef CUDNN
+    float one = 1;
+    float zero = 0;
+    cudnnBatchNormalizationBackward(cudnn_handle(),
+        CUDNN_BATCHNORM_SPATIAL,
+        &one,
+        &zero,
+        &one,
+        &one,
+        l.normDstTensorDesc,
+        l.x_gpu,                // input
+        l.normDstTensorDesc,
+        l.delta_gpu,            // input
+        l.normDstTensorDesc,
+        l.x_norm_gpu,            // output
+        l.normTensorDesc,
+        l.scales_gpu,            // output (should be FP32)
+        l.scale_updates_gpu,    // output (should be FP32)
+        l.bias_updates_gpu,        // output (should be FP32)
+        .00001,
+        l.mean_gpu,                // input (should be FP32)
+        l.variance_gpu);        // input (should be FP32)
+    copy_ongpu(l.outputs*l.batch, l.x_norm_gpu, 1, l.delta_gpu, 1);
+#else
+    backward_bias_gpu(l.bias_updates_gpu, l.delta_gpu, l.batch, l.out_c, l.out_w*l.out_h);
     backward_scale_gpu(l.x_norm_gpu, l.delta_gpu, l.batch, l.out_c, l.out_w*l.out_h, l.scale_updates_gpu);
 
     scale_bias_gpu(l.delta_gpu, l.scales_gpu, l.batch, l.out_c, l.out_h*l.out_w);
@@ -170,6 +265,7 @@
     fast_mean_delta_gpu(l.delta_gpu, l.variance_gpu, l.batch, l.out_c, l.out_w*l.out_h, l.mean_delta_gpu);
     fast_variance_delta_gpu(l.x_gpu, l.delta_gpu, l.mean_gpu, l.variance_gpu, l.batch, l.out_c, l.out_w*l.out_h, l.variance_delta_gpu);
     normalize_delta_gpu(l.x_gpu, l.mean_gpu, l.variance_gpu, l.mean_delta_gpu, l.variance_delta_gpu, l.batch, l.out_c, l.out_w*l.out_h, l.delta_gpu);
-    if(l.type == BATCHNORM) copy_ongpu(l.outputs*l.batch, l.delta_gpu, 1, state.delta, 1);
-}
 #endif
+    if (l.type == BATCHNORM) copy_ongpu(l.outputs*l.batch, l.delta_gpu, 1, state.delta, 1);
+}
+#endif
\ No newline at end of file

--
Gitblit v1.10.0