From c604f2d9947413b75e5b35f6997458f1f8f89166 Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Sun, 31 Jan 2016 23:52:03 +0000
Subject: [PATCH] ...
---
src/rnn.c | 52 ++++++++++++++++++++++++++++++++++++++++++++++++----
1 files changed, 48 insertions(+), 4 deletions(-)
diff --git a/src/rnn.c b/src/rnn.c
index d3e7e51..aee53ff 100644
--- a/src/rnn.c
+++ b/src/rnn.c
@@ -19,6 +19,12 @@
int i,j;
for(i = 0; i < batch; ++i){
int index = rand() %(len - steps - 1);
+ int done = 1;
+ while(!done){
+ index = rand() %(len - steps - 1);
+ while(index < len-steps-1 && text[index++] != '\n');
+ if (index < len-steps-1) done = 1;
+ }
for(j = 0; j < steps; ++j){
x[(j*batch + i)*256 + text[index + j]] = 1;
y[(j*batch + i)*256 + text[index + j + 1]] = 1;
@@ -48,13 +54,13 @@
srand(time(0));
data_seed = time(0);
char *base = basecfg(cfgfile);
- printf("%s\n", base);
+ fprintf(stderr, "%s\n", base);
float avg_loss = -1;
network net = parse_network_cfg(cfgfile);
if(weightfile){
load_weights(&net, weightfile);
}
- printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
+ fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
int batch = net.batch;
int steps = net.time_steps;
int i = (*net.seen)/net.batch;
@@ -71,7 +77,7 @@
if (avg_loss < 0) avg_loss = loss;
avg_loss = avg_loss*.9 + loss*.1;
- printf("%d: %f, %f avg, %f rate, %lf seconds\n", i, loss, avg_loss, get_current_rate(net), sec(clock()-time));
+ fprintf(stderr, "%d: %f, %f avg, %f rate, %lf seconds\n", i, loss, avg_loss, get_current_rate(net), sec(clock()-time));
if(i%100==0){
char buff[256];
sprintf(buff, "%s/%s_%d.weights", backup_directory, base, i);
@@ -92,7 +98,7 @@
{
srand(rseed);
char *base = basecfg(cfgfile);
- printf("%s\n", base);
+ fprintf(stderr, "%s\n", base);
network net = parse_network_cfg(cfgfile);
if(weightfile){
@@ -128,6 +134,43 @@
printf("\n");
}
+/* Feed each byte of `filename` to the network as a 256-wide one-hot vector and
+ * print the summed log of the output at each following byte. weightfile may be 0. */
+void valid_char_rnn(char *cfgfile, char *weightfile, char *filename)
+{
+    FILE *fp = fopen(filename, "r");
+    if(!fp){
+        fprintf(stderr, "Couldn't open file: %s\n", filename);
+        return;
+    }
+    fseek(fp, 0, SEEK_END);
+    long end = ftell(fp);
+    fseek(fp, 0, SEEK_SET);
+    size_t size = (end > 0) ? (size_t)end : 0; /* ftell returns -1 on failure */
+    /* unsigned char: bytes >= 0x80 must index 128..255, never a negative offset */
+    unsigned char *text = calloc(size + 1, sizeof(unsigned char));
+    float *input = calloc(256, sizeof(float));
+    if(text && input) size = fread(text, 1, size, fp); /* score only what was read */
+    else { size = 0; fprintf(stderr, "Out of memory\n"); }
+    fclose(fp);
+    char *base = basecfg(cfgfile);
+    fprintf(stderr, "%s\n", base);
+    network net = parse_network_cfg(cfgfile);
+    if(weightfile) load_weights(&net, weightfile);
+    float sum = 0;
+    size_t i;
+    for(i = 0; i + 1 < size; ++i){ /* i + 1 < size cannot wrap when size == 0 */
+        input[text[i]] = 1;
+        float *out = network_predict(net, input);
+        input[text[i]] = 0;
+        sum += log(out[text[i+1]]);
+    }
+    printf("Log Probability: %f\n", sum);
+    free(input);
+    free(text);
+}
+
+
void run_char_rnn(int argc, char **argv)
{
if(argc < 4){
@@ -143,5 +186,6 @@
char *cfg = argv[3];
char *weights = (argc > 4) ? argv[4] : 0;
if(0==strcmp(argv[2], "train")) train_char_rnn(cfg, weights, filename);
+ else if(0==strcmp(argv[2], "valid")) valid_char_rnn(cfg, weights, filename);
else if(0==strcmp(argv[2], "test")) test_char_rnn(cfg, weights, len, seed, temp, rseed);
}
--
Gitblit v1.10.0