Joseph Redmon
2016-03-16 13d3b038b83a7e59191828304231dd0a6c6f2f9c
src/go.c
@@ -8,8 +8,9 @@
#include "opencv2/highgui/highgui_c.h"
#endif
int inverted = 0;
int noi = 0;
int inverted = 1;
int noi = 1;
static const int nind = 5;
void train_go(char *cfgfile, char *weightfile)
{
@@ -147,12 +148,14 @@
            int index = j*19 + i;
            if(indexes){
                int found = 0;
                for(n = 0; n < 3; ++n){
                for(n = 0; n < nind; ++n){
                    if(index == indexes[n]){
                        found = 1;
                        if(n == 0) printf("\uff11");
                        else if(n == 1) printf("\uff12");
                        else if(n == 2) printf("\uff13");
                        else if(n == 3) printf("\uff14");
                        else if(n == 4) printf("\uff15");
                    }
                }
                if(found) continue;
@@ -173,7 +176,7 @@
    }
}
void test_go(char *filename, char *weightfile)
void test_go(char *filename, char *weightfile, int multi)
{
    network net = parse_network_cfg(filename);
    if(weightfile){
@@ -188,75 +191,79 @@
        float *output = network_predict(net, board);
        copy_cpu(19*19, output, 1, move, 1);
        int i;
#ifdef GPU
        image bim = float_to_image(19, 19, 1, board);
        for(i = 1; i < 8; ++i){
            rotate_image_cw(bim, i);
            if(i >= 4) flip_image(bim);
        if(multi){
            image bim = float_to_image(19, 19, 1, board);
            for(i = 1; i < 8; ++i){
                rotate_image_cw(bim, i);
                if(i >= 4) flip_image(bim);
            float *output = network_predict(net, board);
            image oim = float_to_image(19, 19, 1, output);
                float *output = network_predict(net, board);
                image oim = float_to_image(19, 19, 1, output);
            if(i >= 4) flip_image(oim);
            rotate_image_cw(oim, -i);
                if(i >= 4) flip_image(oim);
                rotate_image_cw(oim, -i);
            axpy_cpu(19*19, 1, output, 1, move, 1);
                axpy_cpu(19*19, 1, output, 1, move, 1);
            if(i >= 4) flip_image(bim);
            rotate_image_cw(bim, -i);
                if(i >= 4) flip_image(bim);
                rotate_image_cw(bim, -i);
            }
            scal_cpu(19*19, 1./8., move, 1);
        }
        scal_cpu(19*19, 1./8., move, 1);
#endif
        for(i = 0; i < 19*19; ++i){
            if(board[i]) move[i] = 0;
        }
        int indexes[3];
        int indexes[nind];
        int row, col;
        top_k(move, 19*19, 3, indexes);
        top_k(move, 19*19, nind, indexes);
        print_board(board, color, indexes);
        for(i = 0; i < 3; ++i){
        for(i = 0; i < nind; ++i){
            int index = indexes[i];
            row = index / 19;
            col = index % 19;
            printf("Suggested: %c %d, %.2f%%\n", col + 'A' + 1*(col > 7 && noi), (inverted)?19 - row : row+1, move[index]*100);
            printf("%d: %c %d, %.2f%%\n", i+1, col + 'A' + 1*(col > 7 && noi), (inverted)?19 - row : row+1, move[index]*100);
        }
        int index = indexes[0];
        int rec_row = index / 19;
        int rec_col = index % 19;
        if(color == 1) printf("\u25EF Enter move: ");
        else printf("\u25C9 Enter move: ");
        char c;
        char *line = fgetl(stdin);
        int num = sscanf(line, "%c %d", &c, &row);
        if (strlen(line) == 0){
            row = rec_row;
            col = rec_col;
            board[row*19 + col] = 1;
        }else if (c < 'A' || c > 'T'){
            if (c == 'p'){
                flip_board(board);
                color = -color;
                free(line);
                continue;
            }else{
        int picked = 1;
        int dnum = sscanf(line, "%d", &picked);
        int cnum = sscanf(line, "%c", &c);
        if (strlen(line) == 0 || dnum) {
            --picked;
            if (picked < nind){
                int index = indexes[picked];
                row = index / 19;
                col = index % 19;
                board[row*19 + col] = 1;
            }
        } else if (cnum){
            if (c <= 'T' && c >= 'A'){
                int num = sscanf(line, "%c %d", &c, &row);
                row = (inverted)?19 - row : row-1;
                col = c - 'A';
                if (col > 7 && noi) col -= 1;
                if (num == 2) board[row*19 + col] = 1;
            } else if (c == 'p') {
                // Pass
            } else if(c=='b' || c == 'w'){
                char g;
                num = sscanf(line, "%c %c %d", &g, &c, &row);
                row = (inverted)?19 - row : row+1;
                int num = sscanf(line, "%c %c %d", &g, &c, &row);
                row = (inverted)?19 - row : row-1;
                col = c - 'A';
                if (col > 7 && noi) col -= 1;
                if (num == 3) board[row*19 + col] = (g == 'b') ? color : -color;
            } else if(c == 'c'){
                char g;
                int num = sscanf(line, "%c %c %d", &g, &c, &row);
                row = (inverted)?19 - row : row-1;
                col = c - 'A';
                if (col > 7 && noi) col -= 1;
                if (num == 3) board[row*19 + col] = 0;
            }
        } else if(num == 2){
            row = (inverted)?19 - row : row+1;
            col = c - 'A';
            if (col > 7 && noi) col -= 1;
            board[row*19 + col] = 1;
        }else{
            free(line);
            continue;
        }
        free(line);
        update_board(board);
@@ -275,8 +282,9 @@
    char *cfg = argv[3];
    char *weights = (argc > 4) ? argv[4] : 0;
    int multi = find_arg(argc, argv, "-multi");
    if(0==strcmp(argv[2], "train")) train_go(cfg, weights);
    else if(0==strcmp(argv[2], "test")) test_go(cfg, weights);
    else if(0==strcmp(argv[2], "test")) test_go(cfg, weights, multi);
}