From 71a9929af6c3d3ffb9527bb921c5cc4a20971ff6 Mon Sep 17 00:00:00 2001
From: AlexeyAB <alexeyab84@gmail.com>
Date: Fri, 17 Mar 2017 22:47:21 +0000
Subject: [PATCH] Fixed x & y coords less than 0
---
src/rnn.c | 102 ++++++++++++++++++++++++++++++++++++++++++++++++++-
1 files changed, 100 insertions(+), 2 deletions(-)
diff --git a/src/rnn.c b/src/rnn.c
index c7c70a4..eca6f55 100644
--- a/src/rnn.c
+++ b/src/rnn.c
@@ -129,7 +129,6 @@
void train_char_rnn(char *cfgfile, char *weightfile, char *filename, int clear, int tokenized)
{
srand(time(0));
- data_seed = time(0);
unsigned char *text = 0;
int *tokens = 0;
size_t size;
@@ -199,7 +198,7 @@
}
}
- if(i%100==0){
+ if(i%1000==0){
char buff[256];
sprintf(buff, "%s/%s_%d.weights", backup_directory, base, i);
save_weights(net, buff);
@@ -280,6 +279,103 @@
printf("\n");
}
+void test_tactic_rnn(char *cfgfile, char *weightfile, int num, float temp, int rseed, char *token_file)
+{
+ char **tokens = 0;
+ if(token_file){
+ size_t n;
+ tokens = read_tokens(token_file, &n);
+ }
+
+ srand(rseed);
+ char *base = basecfg(cfgfile);
+ fprintf(stderr, "%s\n", base);
+
+ network net = parse_network_cfg(cfgfile);
+ if(weightfile){
+ load_weights(&net, weightfile);
+ }
+ int inputs = get_network_input_size(net);
+
+ int i, j;
+ for(i = 0; i < net.n; ++i) net.layers[i].temperature = temp;
+ int c = 0;
+ float *input = calloc(inputs, sizeof(float));
+ float *out = 0;
+
+ while((c = getc(stdin)) != EOF){
+ input[c] = 1;
+ out = network_predict(net, input);
+ input[c] = 0;
+ }
+ for(i = 0; i < num; ++i){
+ for(j = 0; j < inputs; ++j){
+ if (out[j] < .0001) out[j] = 0;
+ }
+ int next = sample_array(out, inputs);
+ if(c == '.' && next == '\n') break;
+ c = next;
+ print_symbol(c, tokens);
+
+ input[c] = 1;
+ out = network_predict(net, input);
+ input[c] = 0;
+ }
+ printf("\n");
+}
+
+void valid_tactic_rnn(char *cfgfile, char *weightfile, char *seed)
+{
+ char *base = basecfg(cfgfile);
+ fprintf(stderr, "%s\n", base);
+
+ network net = parse_network_cfg(cfgfile);
+ if(weightfile){
+ load_weights(&net, weightfile);
+ }
+ int inputs = get_network_input_size(net);
+
+ int count = 0;
+ int words = 1;
+ int c;
+ int len = strlen(seed);
+ float *input = calloc(inputs, sizeof(float));
+ int i;
+ for(i = 0; i < len; ++i){
+ c = seed[i];
+ input[(int)c] = 1;
+ network_predict(net, input);
+ input[(int)c] = 0;
+ }
+ float sum = 0;
+ c = getc(stdin);
+ float log2 = log(2);
+ int in = 0;
+ while(c != EOF){
+ int next = getc(stdin);
+ if(next == EOF) break;
+ if(next < 0 || next >= 255) error("Out of range character");
+
+ input[c] = 1;
+ float *out = network_predict(net, input);
+ input[c] = 0;
+
+ if(c == '.' && next == '\n') in = 0;
+ if(!in) {
+ if(c == '>' && next == '>'){
+ in = 1;
+ ++words;
+ }
+ c = next;
+ continue;
+ }
+ ++count;
+ sum += log(out[next])/log2;
+ c = next;
+ printf("%d %d Perplexity: %4.4f Word Perplexity: %4.4f\n", count, words, pow(2, -sum/count), pow(2, -sum/words));
+ }
+}
+
void valid_char_rnn(char *cfgfile, char *weightfile, char *seed)
{
char *base = basecfg(cfgfile);
@@ -389,6 +485,8 @@
char *weights = (argc > 4) ? argv[4] : 0;
if(0==strcmp(argv[2], "train")) train_char_rnn(cfg, weights, filename, clear, tokenized);
else if(0==strcmp(argv[2], "valid")) valid_char_rnn(cfg, weights, seed);
+ else if(0==strcmp(argv[2], "validtactic")) valid_tactic_rnn(cfg, weights, seed);
else if(0==strcmp(argv[2], "vec")) vec_char_rnn(cfg, weights, seed);
else if(0==strcmp(argv[2], "generate")) test_char_rnn(cfg, weights, len, seed, temp, rseed, tokens);
+ else if(0==strcmp(argv[2], "generatetactic")) test_tactic_rnn(cfg, weights, len, temp, rseed, tokens);
}
--
Gitblit v1.10.0