From 054e2b1954aafb15b0e983180dda309cfd5d831f Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Thu, 12 May 2016 20:36:11 +0000
Subject: [PATCH] not sure

---
 src/parser.c |   36 +++++++++++++++++++++++++++++++++---
 1 files changed, 33 insertions(+), 3 deletions(-)

diff --git a/src/parser.c b/src/parser.c
index 6c88fd5..5a9d0a3 100644
--- a/src/parser.c
+++ b/src/parser.c
@@ -432,6 +432,7 @@
 
 learning_rate_policy get_policy(char *s)
 {
+    if (strcmp(s, "random")==0) return RANDOM;
     if (strcmp(s, "poly")==0) return POLY;
     if (strcmp(s, "constant")==0) return CONSTANT;
     if (strcmp(s, "step")==0) return STEP;
@@ -497,7 +498,7 @@
     } else if (net->policy == SIG){
         net->gamma = option_find_float(options, "gamma", 1);
         net->step = option_find_int(options, "step", 1);
-    } else if (net->policy == POLY){
+    } else if (net->policy == POLY || net->policy == RANDOM){
         net->power = option_find_float(options, "power", 1);
     }
     net->max_batches = option_find_int(options, "max_batches", 0);
@@ -852,6 +853,18 @@
     fwrite(l.filters, sizeof(float), num, fp);
 }
 
+void save_batchnorm_weights(layer l, FILE *fp)
+{
+#ifdef GPU
+    if(gpu_index >= 0){
+        pull_batchnorm_layer(l);
+    }
+#endif
+    fwrite(l.scales, sizeof(float), l.c, fp);
+    fwrite(l.rolling_mean, sizeof(float), l.c, fp);
+    fwrite(l.rolling_variance, sizeof(float), l.c, fp);
+}
+
 void save_connected_weights(layer l, FILE *fp)
 {
 #ifdef GPU
@@ -889,6 +902,8 @@
             save_convolutional_weights(l, fp);
         } if(l.type == CONNECTED){
             save_connected_weights(l, fp);
+        } if(l.type == BATCHNORM){
+            save_batchnorm_weights(l, fp);
         } if(l.type == RNN){
             save_connected_weights(*(l.input_layer), fp);
             save_connected_weights(*(l.self_layer), fp);
@@ -943,8 +958,8 @@
     if(transpose){
         transpose_matrix(l.weights, l.inputs, l.outputs);
     }
-        //printf("Biases: %f mean %f variance\n", mean_array(l.biases, l.outputs), variance_array(l.biases, l.outputs));
-        //printf("Weights: %f mean %f variance\n", mean_array(l.weights, l.outputs*l.inputs), variance_array(l.weights, l.outputs*l.inputs));
+    //printf("Biases: %f mean %f variance\n", mean_array(l.biases, l.outputs), variance_array(l.biases, l.outputs));
+    //printf("Weights: %f mean %f variance\n", mean_array(l.weights, l.outputs*l.inputs), variance_array(l.weights, l.outputs*l.inputs));
     if (l.batch_normalize && (!l.dontloadscales)){
         fread(l.scales, sizeof(float), l.outputs, fp);
         fread(l.rolling_mean, sizeof(float), l.outputs, fp);
@@ -960,6 +975,18 @@
 #endif
 }
 
+void load_batchnorm_weights(layer l, FILE *fp)
+{
+    fread(l.scales, sizeof(float), l.c, fp);
+    fread(l.rolling_mean, sizeof(float), l.c, fp);
+    fread(l.rolling_variance, sizeof(float), l.c, fp);
+#ifdef GPU
+    if(gpu_index >= 0){
+        push_batchnorm_layer(l);
+    }
+#endif
+}
+
 void load_convolutional_weights_binary(layer l, FILE *fp)
 {
     fread(l.biases, sizeof(float), l.n, fp);
@@ -1053,6 +1080,9 @@
         if(l.type == CONNECTED){
             load_connected_weights(l, fp, transpose);
         }
+        if(l.type == BATCHNORM){
+            load_batchnorm_weights(l, fp);
+        }
         if(l.type == CRNN){
             load_convolutional_weights(*(l.input_layer), fp);
             load_convolutional_weights(*(l.self_layer), fp);

--
Gitblit v1.10.0