From eeb81f29375298dc1c8fc64374f40b6a743577a9 Mon Sep 17 00:00:00 2001
From: Alexey <AlexeyAB@users.noreply.github.com>
Date: Sun, 26 Nov 2017 17:47:37 +0000
Subject: [PATCH] Update Readme.md

---
 src/rnn.c |  102 ++++++++++++++++++++++++++++++++++++++++++++++++++-
 1 files changed, 100 insertions(+), 2 deletions(-)

diff --git a/src/rnn.c b/src/rnn.c
index c7c70a4..eca6f55 100644
--- a/src/rnn.c
+++ b/src/rnn.c
@@ -129,7 +129,6 @@
 void train_char_rnn(char *cfgfile, char *weightfile, char *filename, int clear, int tokenized)
 {
     srand(time(0));
-    data_seed = time(0);
     unsigned char *text = 0;
     int *tokens = 0;
     size_t size;
@@ -199,7 +198,7 @@
             }
         }
 
-        if(i%100==0){
+        if(i%1000==0){
             char buff[256];
             sprintf(buff, "%s/%s_%d.weights", backup_directory, base, i);
             save_weights(net, buff);
@@ -280,6 +279,103 @@
     printf("\n");
 }
 
+void test_tactic_rnn(char *cfgfile, char *weightfile, int num, float temp, int rseed, char *token_file)
+{
+    char **tokens = 0;
+    if(token_file){
+        size_t n;
+        tokens = read_tokens(token_file, &n);
+    }
+
+    srand(rseed);
+    char *base = basecfg(cfgfile);
+    fprintf(stderr, "%s\n", base);
+
+    network net = parse_network_cfg(cfgfile);
+    if(weightfile){
+        load_weights(&net, weightfile);
+    }
+    int inputs = get_network_input_size(net);
+
+    int i, j;
+    for(i = 0; i < net.n; ++i) net.layers[i].temperature = temp;
+    int c = 0;
+    float *input = calloc(inputs, sizeof(float));
+    float *out = 0;
+
+    while((c = getc(stdin)) != EOF){
+        input[c] = 1;
+        out = network_predict(net, input);
+        input[c] = 0;
+    }
+    for(i = 0; i < num; ++i){
+        for(j = 0; j < inputs; ++j){
+            if (out[j] < .0001) out[j] = 0;
+        }
+        int next = sample_array(out, inputs);
+        if(c == '.' && next == '\n') break;
+        c = next;
+        print_symbol(c, tokens);
+
+        input[c] = 1;
+        out = network_predict(net, input);
+        input[c] = 0;
+    }
+    printf("\n");
+}
+
+void valid_tactic_rnn(char *cfgfile, char *weightfile, char *seed)
+{
+    char *base = basecfg(cfgfile);
+    fprintf(stderr, "%s\n", base);
+
+    network net = parse_network_cfg(cfgfile);
+    if(weightfile){
+        load_weights(&net, weightfile);
+    }
+    int inputs = get_network_input_size(net);
+
+    int count = 0;
+    int words = 1;
+    int c;
+    int len = strlen(seed);
+    float *input = calloc(inputs, sizeof(float));
+    int i;
+    for(i = 0; i < len; ++i){
+        c = seed[i];
+        input[(int)c] = 1;
+        network_predict(net, input);
+        input[(int)c] = 0;
+    }
+    float sum = 0;
+    c = getc(stdin);
+    float log2 = log(2);
+    int in = 0;
+    while(c != EOF){
+        int next = getc(stdin);
+        if(next == EOF) break;
+        if(next < 0 || next >= 255) error("Out of range character");
+
+        input[c] = 1;
+        float *out = network_predict(net, input);
+        input[c] = 0;
+
+        if(c == '.' && next == '\n') in = 0;
+        if(!in) {
+            if(c == '>' && next == '>'){
+                in = 1;
+                ++words;
+            }
+            c = next;
+            continue;
+        }
+        ++count;
+        sum += log(out[next])/log2;
+        c = next;
+        printf("%d %d Perplexity: %4.4f    Word Perplexity: %4.4f\n", count, words, pow(2, -sum/count), pow(2, -sum/words));
+    }
+}
+
 void valid_char_rnn(char *cfgfile, char *weightfile, char *seed)
 {
     char *base = basecfg(cfgfile);
@@ -389,6 +485,8 @@
     char *weights = (argc > 4) ? argv[4] : 0;
     if(0==strcmp(argv[2], "train")) train_char_rnn(cfg, weights, filename, clear, tokenized);
     else if(0==strcmp(argv[2], "valid")) valid_char_rnn(cfg, weights, seed);
+    else if(0==strcmp(argv[2], "validtactic")) valid_tactic_rnn(cfg, weights, seed);
     else if(0==strcmp(argv[2], "vec")) vec_char_rnn(cfg, weights, seed);
     else if(0==strcmp(argv[2], "generate")) test_char_rnn(cfg, weights, len, seed, temp, rseed, tokens);
+    else if(0==strcmp(argv[2], "generatetactic")) test_tactic_rnn(cfg, weights, len, temp, rseed, tokens);
 }

--
Gitblit v1.10.0