From 9942d484122c346650bb5431fd209d9437b5310a Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Wed, 11 May 2016 17:45:50 +0000
Subject: [PATCH] stable

---
 src/parser.c |   33 +++++++++++++++++++++++++++++++--
 1 files changed, 31 insertions(+), 2 deletions(-)

diff --git a/src/parser.c b/src/parser.c
index 6c88fd5..b900ad7 100644
--- a/src/parser.c
+++ b/src/parser.c
@@ -852,6 +852,18 @@
     fwrite(l.filters, sizeof(float), num, fp);
 }
 
+void save_batchnorm_weights(layer l, FILE *fp)
+{
+#ifdef GPU
+    if(gpu_index >= 0){
+        pull_batchnorm_layer(l);
+    }
+#endif
+    fwrite(l.scales, sizeof(float), l.c, fp);
+    fwrite(l.rolling_mean, sizeof(float), l.c, fp);
+    fwrite(l.rolling_variance, sizeof(float), l.c, fp);
+}
+
 void save_connected_weights(layer l, FILE *fp)
 {
 #ifdef GPU
@@ -889,6 +901,8 @@
             save_convolutional_weights(l, fp);
         } if(l.type == CONNECTED){
             save_connected_weights(l, fp);
+        } if(l.type == BATCHNORM){
+            save_batchnorm_weights(l, fp);
         } if(l.type == RNN){
             save_connected_weights(*(l.input_layer), fp);
             save_connected_weights(*(l.self_layer), fp);
@@ -943,8 +957,8 @@
     if(transpose){
         transpose_matrix(l.weights, l.inputs, l.outputs);
     }
-        //printf("Biases: %f mean %f variance\n", mean_array(l.biases, l.outputs), variance_array(l.biases, l.outputs));
-        //printf("Weights: %f mean %f variance\n", mean_array(l.weights, l.outputs*l.inputs), variance_array(l.weights, l.outputs*l.inputs));
+    //printf("Biases: %f mean %f variance\n", mean_array(l.biases, l.outputs), variance_array(l.biases, l.outputs));
+    //printf("Weights: %f mean %f variance\n", mean_array(l.weights, l.outputs*l.inputs), variance_array(l.weights, l.outputs*l.inputs));
     if (l.batch_normalize && (!l.dontloadscales)){
         fread(l.scales, sizeof(float), l.outputs, fp);
         fread(l.rolling_mean, sizeof(float), l.outputs, fp);
@@ -960,6 +974,18 @@
 #endif
 }
 
+void load_batchnorm_weights(layer l, FILE *fp)
+{
+    fread(l.scales, sizeof(float), l.c, fp);
+    fread(l.rolling_mean, sizeof(float), l.c, fp);
+    fread(l.rolling_variance, sizeof(float), l.c, fp);
+#ifdef GPU
+    if(gpu_index >= 0){
+        push_batchnorm_layer(l);
+    }
+#endif
+}
+
 void load_convolutional_weights_binary(layer l, FILE *fp)
 {
     fread(l.biases, sizeof(float), l.n, fp);
@@ -1053,6 +1079,9 @@
         if(l.type == CONNECTED){
             load_connected_weights(l, fp, transpose);
         }
+        if(l.type == BATCHNORM){
+            load_batchnorm_weights(l, fp);
+        }
         if(l.type == CRNN){
             load_convolutional_weights(*(l.input_layer), fp);
             load_convolutional_weights(*(l.self_layer), fp);

--
Gitblit v1.10.0