| | |
| | | } |
| | | } |
| | | |
| | | if(i%100==0){ |
| | | if(i%1000==0){ |
| | | char buff[256]; |
| | | sprintf(buff, "%s/%s_%d.weights", backup_directory, base, i); |
| | | save_weights(net, buff); |
| | |
| | | printf("\n"); |
| | | } |
| | | |
| | | void test_tactic_rnn(char *cfgfile, char *weightfile, int num, char *seed, float temp, int rseed, char *token_file) |
| | | { |
| | | char **tokens = 0; |
| | | if(token_file){ |
| | | size_t n; |
| | | tokens = read_tokens(token_file, &n); |
| | | } |
| | | |
| | | srand(rseed); |
| | | char *base = basecfg(cfgfile); |
| | | fprintf(stderr, "%s\n", base); |
| | | |
| | | network net = parse_network_cfg(cfgfile); |
| | | if(weightfile){ |
| | | load_weights(&net, weightfile); |
| | | } |
| | | int inputs = get_network_input_size(net); |
| | | |
| | | int i, j; |
| | | for(i = 0; i < net.n; ++i) net.layers[i].temperature = temp; |
| | | int c = 0; |
| | | int len = strlen(seed); |
| | | float *input = calloc(inputs, sizeof(float)); |
| | | float *out; |
| | | |
| | | while((c = getc(stdin)) != EOF){ |
| | | input[c] = 1; |
| | | out = network_predict(net, input); |
| | | input[c] = 0; |
| | | } |
| | | for(i = 0; i < num; ++i){ |
| | | for(j = 0; j < inputs; ++j){ |
| | | if (out[j] < .0001) out[j] = 0; |
| | | } |
| | | int next = sample_array(out, inputs); |
| | | if(c == '.' && next == '\n') break; |
| | | c = next; |
| | | print_symbol(c, tokens); |
| | | |
| | | input[c] = 1; |
| | | out = network_predict(net, input); |
| | | input[c] = 0; |
| | | } |
| | | printf("\n"); |
| | | } |
| | | |
| | | void valid_tactic_rnn(char *cfgfile, char *weightfile, char *seed) |
| | | { |
| | | char *base = basecfg(cfgfile); |
| | | fprintf(stderr, "%s\n", base); |
| | | |
| | | network net = parse_network_cfg(cfgfile); |
| | | if(weightfile){ |
| | | load_weights(&net, weightfile); |
| | | } |
| | | int inputs = get_network_input_size(net); |
| | | |
| | | int count = 0; |
| | | int words = 1; |
| | | int c; |
| | | int len = strlen(seed); |
| | | float *input = calloc(inputs, sizeof(float)); |
| | | int i; |
| | | for(i = 0; i < len; ++i){ |
| | | c = seed[i]; |
| | | input[(int)c] = 1; |
| | | network_predict(net, input); |
| | | input[(int)c] = 0; |
| | | } |
| | | float sum = 0; |
| | | c = getc(stdin); |
| | | float log2 = log(2); |
| | | int in = 0; |
| | | while(c != EOF){ |
| | | int next = getc(stdin); |
| | | if(next == EOF) break; |
| | | if(next < 0 || next >= 255) error("Out of range character"); |
| | | |
| | | input[c] = 1; |
| | | float *out = network_predict(net, input); |
| | | input[c] = 0; |
| | | |
| | | if(c == '.' && next == '\n') in = 0; |
| | | if(!in) { |
| | | if(c == '>' && next == '>'){ |
| | | in = 1; |
| | | ++words; |
| | | } |
| | | c = next; |
| | | continue; |
| | | } |
| | | ++count; |
| | | sum += log(out[next])/log2; |
| | | c = next; |
| | | printf("%d %d Perplexity: %4.4f Word Perplexity: %4.4f\n", count, words, pow(2, -sum/count), pow(2, -sum/words)); |
| | | } |
| | | } |
| | | |
| | | void valid_char_rnn(char *cfgfile, char *weightfile, char *seed) |
| | | { |
| | | char *base = basecfg(cfgfile); |
| | |
| | | char *weights = (argc > 4) ? argv[4] : 0; |
| | | if(0==strcmp(argv[2], "train")) train_char_rnn(cfg, weights, filename, clear, tokenized); |
| | | else if(0==strcmp(argv[2], "valid")) valid_char_rnn(cfg, weights, seed); |
| | | else if(0==strcmp(argv[2], "validtactic")) valid_tactic_rnn(cfg, weights, seed); |
| | | else if(0==strcmp(argv[2], "vec")) vec_char_rnn(cfg, weights, seed); |
| | | else if(0==strcmp(argv[2], "generate")) test_char_rnn(cfg, weights, len, seed, temp, rseed, tokens); |
| | | else if(0==strcmp(argv[2], "generatetactic")) test_tactic_rnn(cfg, weights, len, seed, temp, rseed, tokens); |
| | | } |