From 393dc8eb6f3a9dd92ec665200444186c1addc5d2 Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Wed, 09 Sep 2015 19:48:40 +0000
Subject: [PATCH] stable

---
 src/parser.c |   74 ++++++++++++++++++++++++++++++++-----
 1 files changed, 64 insertions(+), 10 deletions(-)

diff --git a/src/parser.c b/src/parser.c
index ad324e9..94dc0fa 100644
--- a/src/parser.c
+++ b/src/parser.c
@@ -169,7 +169,7 @@
     int rescore = option_find_int(options, "rescore", 0);
     int joint = option_find_int(options, "joint", 0);
     int objectness = option_find_int(options, "objectness", 0);
-    int background = 0;
+    int background = option_find_int(options, "background", 0);
     detection_layer layer = make_detection_layer(params.batch, params.inputs, classes, coords, joint, rescore, background, objectness);
     return layer;
 }
@@ -189,7 +189,8 @@
 {
     char *type_s = option_find_str(options, "type", "sse");
     COST_TYPE type = get_cost_type(type_s);
-    cost_layer layer = make_cost_layer(params.batch, params.inputs, type);
+    float scale = option_find_float_quiet(options, "scale",1);
+    cost_layer layer = make_cost_layer(params.batch, params.inputs, type, scale);
     return layer;
 }
 
@@ -305,6 +306,18 @@
     return layer;
 }
 
+learning_rate_policy get_policy(char *s)
+{
+    if (strcmp(s, "poly")==0) return POLY;
+    if (strcmp(s, "constant")==0) return CONSTANT;
+    if (strcmp(s, "step")==0) return STEP;
+    if (strcmp(s, "exp")==0) return EXP;
+    if (strcmp(s, "sigmoid")==0) return SIG;
+    if (strcmp(s, "steps")==0) return STEPS;
+    fprintf(stderr, "Couldn't find policy %s, going with constant\n", s);
+    return CONSTANT;
+}
+
 void parse_net_options(list *options, network *net)
 {
     net->batch = option_find_int(options, "batch",1);
@@ -319,7 +332,47 @@
     net->w = option_find_int_quiet(options, "width",0);
     net->c = option_find_int_quiet(options, "channels",0);
     net->inputs = option_find_int_quiet(options, "inputs", net->h * net->w * net->c);
+
     if(!net->inputs && !(net->h && net->w && net->c)) error("No input parameters supplied");
+
+    char *policy_s = option_find_str(options, "policy", "constant");
+    net->policy = get_policy(policy_s);
+    if(net->policy == STEP){
+        net->step = option_find_int(options, "step", 1);
+        net->scale = option_find_float(options, "scale", 1);
+    } else if (net->policy == STEPS){
+        char *l = option_find(options, "steps");   
+        char *p = option_find(options, "scales");   
+        if(!l || !p) error("STEPS policy must have steps and scales in cfg file");
+
+        int len = strlen(l);
+        int n = 1;
+        int i;
+        for(i = 0; i < len; ++i){
+            if (l[i] == ',') ++n;
+        }
+        int *steps = calloc(n, sizeof(int));
+        float *scales = calloc(n, sizeof(float));
+        for(i = 0; i < n; ++i){
+            int step    = atoi(l);
+            float scale = atof(p);
+            l = strchr(l, ',')+1;
+            p = strchr(p, ',')+1;
+            steps[i] = step;
+            scales[i] = scale;
+        }
+        net->scales = scales;
+        net->steps = steps;
+        net->num_steps = n;
+    } else if (net->policy == EXP){
+        net->gamma = option_find_float(options, "gamma", 1);
+    } else if (net->policy == SIG){
+        net->gamma = option_find_float(options, "gamma", 1);
+        net->step = option_find_int(options, "step", 1);
+    } else if (net->policy == POLY){
+        net->power = option_find_float(options, "power", 1);
+    }
+    net->max_batches = option_find_int(options, "max_batches", 0);
 }
 
 network parse_network_cfg(char *filename)
@@ -377,10 +430,10 @@
             l = parse_dropout(options, params);
             l.output = net.layers[count-1].output;
             l.delta = net.layers[count-1].delta;
-            #ifdef GPU
+#ifdef GPU
             l.output_gpu = net.layers[count-1].output_gpu;
             l.delta_gpu = net.layers[count-1].delta_gpu;
-            #endif
+#endif
         }else{
             fprintf(stderr, "Type not recognized: %s\n", s->type);
         }
@@ -532,7 +585,7 @@
     fwrite(&net.learning_rate, sizeof(float), 1, fp);
     fwrite(&net.momentum, sizeof(float), 1, fp);
     fwrite(&net.decay, sizeof(float), 1, fp);
-    fwrite(&net.seen, sizeof(int), 1, fp);
+    fwrite(net.seen, sizeof(int), 1, fp);
 
     int i,j,k;
     for(i = 0; i < net.n; ++i){
@@ -571,7 +624,7 @@
     fwrite(&net.learning_rate, sizeof(float), 1, fp);
     fwrite(&net.momentum, sizeof(float), 1, fp);
     fwrite(&net.decay, sizeof(float), 1, fp);
-    fwrite(&net.seen, sizeof(int), 1, fp);
+    fwrite(net.seen, sizeof(int), 1, fp);
 
     int i;
     for(i = 0; i < net.n && i < cutoff; ++i){
@@ -620,10 +673,11 @@
     FILE *fp = fopen(filename, "r");
     if(!fp) file_error(filename);
 
-    fread(&net->learning_rate, sizeof(float), 1, fp);
-    fread(&net->momentum, sizeof(float), 1, fp);
-    fread(&net->decay, sizeof(float), 1, fp);
-    fread(&net->seen, sizeof(int), 1, fp);
+    float garbage;
+    fread(&garbage, sizeof(float), 1, fp);
+    fread(&garbage, sizeof(float), 1, fp);
+    fread(&garbage, sizeof(float), 1, fp);
+    fread(net->seen, sizeof(int), 1, fp);
 
     int i;
     for(i = 0; i < net->n && i < cutoff; ++i){

--
Gitblit v1.10.0