From c604f2d9947413b75e5b35f6997458f1f8f89166 Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Sun, 31 Jan 2016 23:52:03 +0000
Subject: [PATCH] ...

---
 src/rnn.c |   52 ++++++++++++++++++++++++++++++++++++++++++++++++----
 1 files changed, 48 insertions(+), 4 deletions(-)

diff --git a/src/rnn.c b/src/rnn.c
index d3e7e51..aee53ff 100644
--- a/src/rnn.c
+++ b/src/rnn.c
@@ -19,6 +19,12 @@
     int i,j;
     for(i = 0; i < batch; ++i){
         int index = rand() %(len - steps - 1);
+        int done = 1;
+        while(!done){
+            index = rand() %(len - steps - 1);
+            while(index < len-steps-1 && text[index++] != '\n');
+            if (index < len-steps-1) done = 1;
+        }
         for(j = 0; j < steps; ++j){
             x[(j*batch + i)*256 + text[index + j]] = 1;
             y[(j*batch + i)*256 + text[index + j + 1]] = 1;
@@ -48,13 +54,13 @@
     srand(time(0));
     data_seed = time(0);
     char *base = basecfg(cfgfile);
-    printf("%s\n", base);
+    fprintf(stderr, "%s\n", base);
     float avg_loss = -1;
     network net = parse_network_cfg(cfgfile);
     if(weightfile){
         load_weights(&net, weightfile);
     }
-    printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
+    fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
     int batch = net.batch;
     int steps = net.time_steps;
     int i = (*net.seen)/net.batch;
@@ -71,7 +77,7 @@
         if (avg_loss < 0) avg_loss = loss;
         avg_loss = avg_loss*.9 + loss*.1;
 
-        printf("%d: %f, %f avg, %f rate, %lf seconds\n", i, loss, avg_loss, get_current_rate(net), sec(clock()-time));
+        fprintf(stderr, "%d: %f, %f avg, %f rate, %lf seconds\n", i, loss, avg_loss, get_current_rate(net), sec(clock()-time));
         if(i%100==0){
             char buff[256];
             sprintf(buff, "%s/%s_%d.weights", backup_directory, base, i);
@@ -92,7 +98,7 @@
 {
     srand(rseed);
     char *base = basecfg(cfgfile);
-    printf("%s\n", base);
+    fprintf(stderr, "%s\n", base);
 
     network net = parse_network_cfg(cfgfile);
     if(weightfile){
@@ -128,6 +134,43 @@
     printf("\n");
 }
 
+/* Print (stdout) the summed natural-log probability the net gives each actual next byte of
+ * filename, fed one one-hot byte at a time; reports to stderr and returns if it can't be read. */
+void valid_char_rnn(char *cfgfile, char *weightfile, char *filename)
+{
+    FILE *fp = fopen(filename, "rb");
+    if(!fp){
+        fprintf(stderr, "Couldn't open file: %s\n", filename);
+        return;
+    }
+    fseek(fp, 0, SEEK_END);
+    long end = ftell(fp);   // -1 on failure, handled below by leaving text NULL
+    fseek(fp, 0, SEEK_SET);
+    unsigned char *text = (end >= 0) ? calloc((size_t)end + 1, sizeof(unsigned char)) : NULL;
+    size_t size = text ? fread(text, 1, (size_t)end, fp) : 0;   // count only bytes actually read
+    fclose(fp);
+    if(!text){
+        fprintf(stderr, "Couldn't read file: %s\n", filename);
+        return;
+    }
+    char *base = basecfg(cfgfile);
+    fprintf(stderr, "%s\n", base);
+    network net = parse_network_cfg(cfgfile);
+    if(weightfile) load_weights(&net, weightfile);
+    size_t i;
+    float sum = 0;
+    float input[256] = {0};   // one-hot input vector, one slot per possible byte value
+    for(i = 0; i + 1 < size; ++i){   // i + 1 < size: no unsigned wrap-around when size is 0
+        input[text[i]] = 1;   // text is unsigned: bytes >= 0x80 index 128..255, never negative
+        float *out = network_predict(net, input);
+        input[text[i]] = 0;   // clear the slot so the vector is all-zero for the next byte
+        sum += log(out[text[i+1]]);
+    }
+    printf("Log Probability: %f\n", sum);
+    free(text);
+}
+
+
 void run_char_rnn(int argc, char **argv)
 {
     if(argc < 4){
@@ -143,5 +186,6 @@
     char *cfg = argv[3];
     char *weights = (argc > 4) ? argv[4] : 0;
     if(0==strcmp(argv[2], "train")) train_char_rnn(cfg, weights, filename);
+    else if(0==strcmp(argv[2], "valid")) valid_char_rnn(cfg, weights, filename);
     else if(0==strcmp(argv[2], "test")) test_char_rnn(cfg, weights, len, seed, temp, rseed);
 }

--
Gitblit v1.10.0