...
Joseph Redmon
2016-01-31 c604f2d9947413b75e5b35f6997458f1f8f89166
src/rnn.c
@@ -19,6 +19,12 @@
    int i,j;
    for(i = 0; i < batch; ++i){
        int index = rand() %(len - steps - 1);
        int done = 1;
        while(!done){
            index = rand() %(len - steps - 1);
            while(index < len-steps-1 && text[index++] != '\n');
            if (index < len-steps-1) done = 1;
        }
        for(j = 0; j < steps; ++j){
            x[(j*batch + i)*256 + text[index + j]] = 1;
            y[(j*batch + i)*256 + text[index + j + 1]] = 1;
@@ -48,13 +54,13 @@
    srand(time(0));
    data_seed = time(0);
    char *base = basecfg(cfgfile);
    printf("%s\n", base);
    fprintf(stderr, "%s\n", base);
    float avg_loss = -1;
    network net = parse_network_cfg(cfgfile);
    if(weightfile){
        load_weights(&net, weightfile);
    }
    printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
    fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
    int batch = net.batch;
    int steps = net.time_steps;
    int i = (*net.seen)/net.batch;
@@ -71,7 +77,7 @@
        if (avg_loss < 0) avg_loss = loss;
        avg_loss = avg_loss*.9 + loss*.1;
        printf("%d: %f, %f avg, %f rate, %lf seconds\n", i, loss, avg_loss, get_current_rate(net), sec(clock()-time));
        fprintf(stderr, "%d: %f, %f avg, %f rate, %lf seconds\n", i, loss, avg_loss, get_current_rate(net), sec(clock()-time));
        if(i%100==0){
            char buff[256];
            sprintf(buff, "%s/%s_%d.weights", backup_directory, base, i);
@@ -92,7 +98,7 @@
{
    srand(rseed);
    char *base = basecfg(cfgfile);
    printf("%s\n", base);
    fprintf(stderr, "%s\n", base);
    network net = parse_network_cfg(cfgfile);
    if(weightfile){
@@ -128,6 +134,43 @@
    printf("\n");
}
void valid_char_rnn(char *cfgfile, char *weightfile, char *filename)
{
    FILE *fp = fopen(filename, "r");
    //FILE *fp = fopen("data/ab.txt", "r");
    //FILE *fp = fopen("data/grrm/asoiaf.txt", "r");
    fseek(fp, 0, SEEK_END);
    size_t size = ftell(fp);
    fseek(fp, 0, SEEK_SET);
    char *text = calloc(size, sizeof(char));
    fread(text, 1, size, fp);
    fclose(fp);
    char *base = basecfg(cfgfile);
    fprintf(stderr, "%s\n", base);
    network net = parse_network_cfg(cfgfile);
    if(weightfile){
        load_weights(&net, weightfile);
    }
    int i;
    char c;
    float *input = calloc(256, sizeof(float));
    float sum = 0;
    for(i = 0; i < size-1; ++i){
        c = text[i];
        input[(int)c] = 1;
        float *out = network_predict(net, input);
        input[(int)c] = 0;
        sum += log(out[(int)text[i+1]]);
    }
    printf("Log Probability: %f\n", sum);
}
void run_char_rnn(int argc, char **argv)
{
    if(argc < 4){
@@ -143,5 +186,6 @@
    char *cfg = argv[3];
    char *weights = (argc > 4) ? argv[4] : 0;
    if(0==strcmp(argv[2], "train")) train_char_rnn(cfg, weights, filename);
    else if(0==strcmp(argv[2], "valid")) valid_char_rnn(cfg, weights, filename);
    else if(0==strcmp(argv[2], "test")) test_char_rnn(cfg, weights, len, seed, temp, rseed);
}