Joseph Redmon
2016-03-16 67794a52a1ca19275f186dbc21cb45c1a45d6b92
src/go.c
@@ -176,7 +176,7 @@
    }
}
void test_go(char *filename, char *weightfile)
void test_go(char *filename, char *weightfile, int multi)
{
    network net = parse_network_cfg(filename);
    if(weightfile){
@@ -191,25 +191,25 @@
        float *output = network_predict(net, board);
        copy_cpu(19*19, output, 1, move, 1);
        int i;
#ifdef GPU
        image bim = float_to_image(19, 19, 1, board);
        for(i = 1; i < 8; ++i){
            rotate_image_cw(bim, i);
            if(i >= 4) flip_image(bim);
        if(multi){
            image bim = float_to_image(19, 19, 1, board);
            for(i = 1; i < 8; ++i){
                rotate_image_cw(bim, i);
                if(i >= 4) flip_image(bim);
            float *output = network_predict(net, board);
            image oim = float_to_image(19, 19, 1, output);
                float *output = network_predict(net, board);
                image oim = float_to_image(19, 19, 1, output);
            if(i >= 4) flip_image(oim);
            rotate_image_cw(oim, -i);
                if(i >= 4) flip_image(oim);
                rotate_image_cw(oim, -i);
            axpy_cpu(19*19, 1, output, 1, move, 1);
                axpy_cpu(19*19, 1, output, 1, move, 1);
            if(i >= 4) flip_image(bim);
            rotate_image_cw(bim, -i);
                if(i >= 4) flip_image(bim);
                rotate_image_cw(bim, -i);
            }
            scal_cpu(19*19, 1./8., move, 1);
        }
        scal_cpu(19*19, 1./8., move, 1);
#endif
        for(i = 0; i < 19*19; ++i){
            if(board[i]) move[i] = 0;
        }
@@ -282,8 +282,9 @@
    char *cfg = argv[3];
    char *weights = (argc > 4) ? argv[4] : 0;
    int multi = find_arg(argc, argv, "-multi");
    if(0==strcmp(argv[2], "train")) train_go(cfg, weights);
    else if(0==strcmp(argv[2], "test")) test_go(cfg, weights);
    else if(0==strcmp(argv[2], "test")) test_go(cfg, weights, multi);
}