From 13209df7bb53de19aa3f82e870db11eb5b7587f1 Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Fri, 13 May 2016 18:59:43 +0000
Subject: [PATCH] art, cudnn
---
src/parser.c | 10 +++++++++-
1 files changed, 9 insertions(+), 1 deletions(-)
diff --git a/src/parser.c b/src/parser.c
index b900ad7..d5288aa 100644
--- a/src/parser.c
+++ b/src/parser.c
@@ -432,6 +432,7 @@
learning_rate_policy get_policy(char *s)
{
+ if (strcmp(s, "random")==0) return RANDOM;
if (strcmp(s, "poly")==0) return POLY;
if (strcmp(s, "constant")==0) return CONSTANT;
if (strcmp(s, "step")==0) return STEP;
@@ -497,7 +498,7 @@
} else if (net->policy == SIG){
net->gamma = option_find_float(options, "gamma", 1);
net->step = option_find_int(options, "step", 1);
- } else if (net->policy == POLY){
+ } else if (net->policy == POLY || net->policy == RANDOM){
net->power = option_find_float(options, "power", 1);
}
net->max_batches = option_find_int(options, "max_batches", 0);
@@ -523,6 +524,7 @@
params.batch = net.batch;
params.time_steps = net.time_steps;
+ size_t workspace_size = 0;
n = n->next;
int count = 0;
free_section(s);
@@ -583,6 +585,7 @@
l.dontloadscales = option_find_int_quiet(options, "dontloadscales", 0);
option_unused(options);
net.layers[count] = l;
+ if (l.workspace_size > workspace_size) workspace_size = l.workspace_size;
free_section(s);
n = n->next;
++count;
@@ -596,6 +599,11 @@
free_list(sections);
net.outputs = get_network_output_size(net);
net.output = get_network_output(net);
+ if(workspace_size){
+#ifdef GPU
+ net.workspace = cuda_make_array(0, (workspace_size-1)/sizeof(float)+1);
+#endif
+ }
return net;
}
--
Gitblit v1.10.0