| | |
| | | printf("\n"); |
| | | } |
| | | |
| | | void valid_char_rnn(char *cfgfile, char *weightfile) |
| | | void valid_char_rnn(char *cfgfile, char *weightfile, char *seed) |
| | | { |
| | | char *base = basecfg(cfgfile); |
| | | fprintf(stderr, "%s\n", base); |
| | |
| | | |
| | | int count = 0; |
| | | int c; |
| | | int len = strlen(seed); |
| | | float *input = calloc(inputs, sizeof(float)); |
| | | int i; |
| | | for(i = 0; i < 100; ++i){ |
| | | for(i = 0; i < len; ++i){ |
| | | c = seed[i]; |
| | | input[(int)c] = 1; |
| | | network_predict(net, input); |
| | | input[(int)c] = 0; |
| | | } |
| | | float sum = 0; |
| | | c = getc(stdin); |
| | | float log2 = log(2); |
| | | while(c != EOF){ |
| | | int next = getc(stdin); |
| | | if(next < 0 || next >= 255) error("Out of range character"); |
| | | if(next == EOF) break; |
| | | if(next < 0 || next >= 255) error("Out of range character"); |
| | | ++count; |
| | | input[c] = 1; |
| | | float *out = network_predict(net, input); |
| | |
| | | } |
| | | } |
| | | |
| | | void vec_char_rnn(char *cfgfile, char *weightfile, char *seed) |
| | | { |
| | | char *base = basecfg(cfgfile); |
| | | fprintf(stderr, "%s\n", base); |
| | | |
| | | network net = parse_network_cfg(cfgfile); |
| | | if(weightfile){ |
| | | load_weights(&net, weightfile); |
| | | } |
| | | int inputs = get_network_input_size(net); |
| | | |
| | | int c; |
| | | int seed_len = strlen(seed); |
| | | float *input = calloc(inputs, sizeof(float)); |
| | | int i; |
| | | char *line; |
| | | while((line=fgetl(stdin)) != 0){ |
| | | reset_rnn_state(net, 0); |
| | | for(i = 0; i < seed_len; ++i){ |
| | | c = seed[i]; |
| | | input[(int)c] = 1; |
| | | network_predict(net, input); |
| | | input[(int)c] = 0; |
| | | } |
| | | strip(line); |
| | | int str_len = strlen(line); |
| | | for(i = 0; i < str_len; ++i){ |
| | | c = line[i]; |
| | | input[(int)c] = 1; |
| | | network_predict(net, input); |
| | | input[(int)c] = 0; |
| | | } |
| | | c = ' '; |
| | | input[(int)c] = 1; |
| | | network_predict(net, input); |
| | | input[(int)c] = 0; |
| | | |
| | | layer l = net.layers[0]; |
| | | cuda_pull_array(l.output_gpu, l.output, l.outputs); |
| | | printf("%s", line); |
| | | for(i = 0; i < l.outputs; ++i){ |
| | | printf(",%g", l.output[i]); |
| | | } |
| | | printf("\n"); |
| | | } |
| | | } |
| | | |
/*
 * run_char_rnn: command-line dispatcher for the character-level RNN tools.
 *
 * Positional arguments: argv[2] = train | valid | vec | generate,
 * argv[3] = cfg file, argv[4] = optional weights file.
 * Options (parsed before the positional arguments are read):
 *   -file <path>  training text (default data/shakespeare.txt)
 *   -seed <text>  priming text (default "\n\n")
 *   -len <n>      characters to generate (default 1000)
 *   -temp <t>     sampling temperature (default 0.7)
 *   -srand <n>    RNG seed (default time(0))
 *   -clear        passed to train_char_rnn as `clear`
 */
void run_char_rnn(int argc, char **argv)
{
    if(argc < 4){
        fprintf(stderr, "usage: %s <module> [train/valid/vec/generate] [cfg] [weights (optional)]\n", argv[0]);
        return;
    }
    char *filename = find_char_arg(argc, argv, "-file", "data/shakespeare.txt");
    char *seed = find_char_arg(argc, argv, "-seed", "\n\n");
    int len = find_int_arg(argc, argv, "-len", 1000);
    float temp = find_float_arg(argc, argv, "-temp", .7);
    int rseed = find_int_arg(argc, argv, "-srand", time(0));
    int clear = find_arg(argc, argv, "-clear");

    char *cfg = argv[3];
    char *weights = (argc > 4) ? argv[4] : 0;
    if(0==strcmp(argv[2], "train")) train_char_rnn(cfg, weights, filename, clear);
    else if(0==strcmp(argv[2], "valid")) valid_char_rnn(cfg, weights, seed);
    else if(0==strcmp(argv[2], "vec")) vec_char_rnn(cfg, weights, seed);
    else if(0==strcmp(argv[2], "generate")) test_char_rnn(cfg, weights, len, seed, temp, rseed);
    else fprintf(stderr, "Unknown command: %s\n", argv[2]);
}