From ae43c2bc32fbb838bfebeeaf2c2b058ccab5c83c Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@burninator.cs.washington.edu>
Date: Thu, 23 Jun 2016 05:31:14 +0000
Subject: [PATCH] hi

---
 src/go.c |  289 ++++++++++++++++++++++++++++++++++++++++++++++++++-------
 1 files changed, 252 insertions(+), 37 deletions(-)

diff --git a/src/go.c b/src/go.c
index 8d0cf52..91beaf1 100644
--- a/src/go.c
+++ b/src/go.c
@@ -98,6 +98,7 @@
         int col = b[1];
         labels[col + 19*(row + i*19)] = 1;
         string_to_board(b+2, boards+i*19*19);
+        boards[col + 19*(row + i*19)] = 0;
 
         int flip = rand()%2;
         int rotate = rand()%4;
@@ -132,6 +133,7 @@
     float *board = calloc(19*19*net.batch, sizeof(float));
     float *move = calloc(19*19*net.batch, sizeof(float));
     moves m = load_go_moves("/home/pjreddie/go.train");
+    //moves m = load_go_moves("games.txt");
 
     int N = m.n;
     int epoch = (*net.seen)/N;
@@ -215,7 +217,7 @@
     }
     fprintf(stream, "\n");
     for(j = 0; j < 19; ++j){
-        fprintf(stream, "%2d ", (inverted) ? 19-j : j+1);
+        fprintf(stream, "%2d", (inverted) ? 19-j : j+1);
         for(i = 0; i < 19; ++i){
             int index = j*19 + i;
             if(indexes){
@@ -223,17 +225,26 @@
                 for(n = 0; n < nind; ++n){
                     if(index == indexes[n]){
                         found = 1;
+                        /*
                         if(n == 0) fprintf(stream, "\uff11");
                         else if(n == 1) fprintf(stream, "\uff12");
                         else if(n == 2) fprintf(stream, "\uff13");
                         else if(n == 3) fprintf(stream, "\uff14");
                         else if(n == 4) fprintf(stream, "\uff15");
+                        */
+                        if(n == 0) fprintf(stream, " 1");
+                        else if(n == 1) fprintf(stream, " 2");
+                        else if(n == 2) fprintf(stream, " 3");
+                        else if(n == 3) fprintf(stream, " 4");
+                        else if(n == 4) fprintf(stream, " 5");
                     }
                 }
                 if(found) continue;
             }
-            if(board[index]*-swap > 0) fprintf(stream, "\u25C9 ");
-            else if(board[index]*-swap < 0) fprintf(stream, "\u25EF ");
+            //if(board[index]*-swap > 0) fprintf(stream, "\u25C9 ");
+            //else if(board[index]*-swap < 0) fprintf(stream, "\u25EF ");
+            if(board[index]*-swap > 0) fprintf(stream, " O");
+            else if(board[index]*-swap < 0) fprintf(stream, " X");
             else fprintf(stream, "  ");
         }
         fprintf(stream, "\n");
@@ -337,6 +348,90 @@
     return 1;
 }
 
+/* Choose a move for `player`: returns a board index (row*19 + col), or -1 when the top-scoring move fails suicide_go. */
+int generate_move(network net, int player, float *board, int multi, float thresh, float temp, char *ko, int print)
+{
+    int i, j;
+    for(i = 0; i < net.n; ++i) net.layers[i].temperature = temp;
+
+    float move[361];
+    if (player < 0) flip_board(board);
+    predict_move(net, board, move, multi);
+    if (player < 0) flip_board(board);
+
+    // Zero the score of every point that legal_go rejects (given the `ko` string).
+    for(i = 0; i < 19; ++i){
+        for(j = 0; j < 19; ++j){
+            if (!legal_go(board, ko, player, i, j)) move[i*19 + j] = 0;
+        }
+    }
+
+    int indexes[nind];
+    top_k(move, 19*19, nind, indexes);
+    // If thresh exceeds move[indexes[0]], lower it to move[indexes[nind-1]] (assumes top_k orders best-first; confirm).
+    if(thresh > move[indexes[0]]) thresh = move[indexes[nind-1]];
+    for(i = 0; i < 19; ++i){
+        for(j = 0; j < 19; ++j){
+            if (move[i*19 + j] < thresh) move[i*19 + j] = 0;
+        }
+    }
+
+    int max = max_index(move, 19*19);
+    int row = max / 19;
+    int col = max % 19;
+    int index = sample_array(move, 19*19);
+
+    if(print){
+        top_k(move, 19*19, nind, indexes);
+        for(i = 0; i < nind; ++i){
+            if (!move[indexes[i]]) indexes[i] = -1;
+        }
+        print_board(board, player, indexes);
+        for(i = 0; i < nind; ++i){
+            // Entries set to -1 above had a zero score; print 0 rather than reading move[-1].
+            fprintf(stderr, "%d: %f\n", i+1, (indexes[i] < 0) ? 0.f : move[indexes[i]]);
+        }
+    }
+
+    if(suicide_go(board, player, row, col)) return -1;
+    // A sampled move that fails suicide_go is replaced by the arg-max move checked above.
+    if(suicide_go(board, player, index/19, index%19)) index = max;
+    return index;
+}
+
+/* Report running top-1 accuracy of predict_move over the positions loaded from go.test. */
+void valid_go(char *cfgfile, char *weightfile, int multi)
+{
+    data_seed = time(0);
+    srand(time(0));
+    char *base = basecfg(cfgfile);
+    printf("%s\n", base);
+    network net = parse_network_cfg(cfgfile);
+    if(weightfile) load_weights(&net, weightfile);
+    set_batch_network(&net, 1);
+    printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
+
+    float *board = calloc(19*19, sizeof(float));
+    float *move = calloc(19*19, sizeof(float));
+    moves m = load_go_moves("/home/pjreddie/backup/go.test");
+
+    int N = m.n;
+    int i, correct = 0;
+    for(i = 0; i <N; ++i){
+        char *b = m.data[i];  // b[0] = row, b[1] = col, b+2 = board string
+        int row = b[0];
+        int col = b[1];
+        int truth = col + 19*row;
+        string_to_board(b+2, board);
+        predict_move(net, board, move, multi);
+        int index = max_index(move, 19*19);
+        if(index == truth) ++correct;
+        printf("%d Accuracy %f\n", i, (float) correct/(i+1));
+    }
+    free(board);  // the two scratch buffers were previously leaked
+    free(move);
+}
+
 void engine_go(char *filename, char *weightfile, int multi)
 {
     network net = parse_network_cfg(filename);
@@ -346,12 +441,10 @@
     srand(time(0));
     set_batch_network(&net, 1);
     float *board = calloc(19*19, sizeof(float));
-    float *move = calloc(19*19, sizeof(float));
     char *one = calloc(91, sizeof(char));
     char *two = calloc(91, sizeof(char));
     int passed = 0;
     while(1){
-        print_board(board, 1, 0);
         char buff[256];
         int id = 0;
         int has_id = (scanf("%d", &id) == 1);
@@ -436,42 +529,34 @@
             board_to_string(one, board);
 
             printf("=%s \n\n", ids);
+            print_board(board, 1, 0);
         } else if (!strcmp(buff, "genmove")){
             char color[256];
             scanf("%s", color);
             int player = (color[0] == 'b' || color[0] == 'B') ? 1 : -1;
 
-            if(player < 0) flip_board(board);
-            predict_move(net, board, move, multi);
-            if(player < 0) flip_board(board);
-
-            int i, j;
-            for(i = 0; i < 19; ++i){
-                for(j = 0; j < 19; ++j){
-                    if (!legal_go(board, two, player, i, j)) move[i*19 + j] = 0;
-                }
-            }
-            int index = max_index(move, 19*19);
-            int row = index / 19;
-            char col = index % 19;
-
-            char *swap = two;
-            two = one;
-            one = swap;
-
-            if(passed || suicide_go(board, player, row, col)){
+            int index = generate_move(net, player, board, multi, .1, .7, two, 1);
+            if(passed || index < 0){
                 printf("=%s pass\n\n", ids);
                 passed = 0;
             } else {
+                int row = index / 19;
+                int col = index % 19;
+
+                char *swap = two;
+                two = one;
+                one = swap;
+
                 move_go(board, player, row, col);
                 board_to_string(one, board);
-
                 row = 19 - row;
                 if (col >= 8) ++col;
                 printf("=%s %c%d\n\n", ids, 'A' + col, row);
+                print_board(board, 1, 0);
             }
+
         } else if (!strcmp(buff, "p")){
-            print_board(board, 1, 0);
+            //print_board(board, 1, 0);
         } else if (!strcmp(buff, "final_status_list")){
             char type[256];
             scanf("%s", type);
@@ -479,7 +564,30 @@
             char *line = fgetl(stdin);
             free(line);
             if(type[0] == 'd' || type[0] == 'D'){
-                printf("=%s \n\n", ids);
+                FILE *f = fopen("game.txt", "w");
+                int i, j;
+                int count = 2;
+                fprintf(f, "boardsize 19\n");
+                fprintf(f, "clear_board\n");
+                for(j = 0; j < 19; ++j){
+                    for(i = 0; i < 19; ++i){
+                        if(board[j*19 + i] == 1) fprintf(f, "play black %c%d\n", 'A'+i+(i>=8), 19-j);
+                        if(board[j*19 + i] == -1) fprintf(f, "play white %c%d\n", 'A'+i+(i>=8), 19-j);
+                        if(board[j*19 + i]) ++count;
+                    }
+                }
+                fprintf(f, "final_status_list dead\n");
+                fclose(f);
+                FILE *p = popen("./gnugo --mode gtp < game.txt", "r");
+                for(i = 0; i < count; ++i){
+                    free(fgetl(p));
+                    free(fgetl(p));
+                }
+                char *l = 0;
+                while((l = fgetl(p))){
+                    printf("%s\n", l);
+                    free(l);
+                }
             } else {
                 printf("?%s unknown command\n\n", ids);
             }
@@ -541,8 +649,10 @@
             col = index % 19;
             printf("%d: %c %d, %.2f%%\n", i+1, col + 'A' + 1*(col > 7 && noi), (inverted)?19 - row : row+1, move[index]*100);
         }
-        if(color == 1) printf("\u25EF Enter move: ");
-        else printf("\u25C9 Enter move: ");
+        //if(color == 1) printf("\u25EF Enter move: ");
+        //else printf("\u25C9 Enter move: ");
+        if(color == 1) printf("X Enter move: ");
+        else printf("O Enter move: ");
 
         char c;
         char *line = fgetl(stdin);
@@ -588,17 +698,118 @@
     }
 }
 
-void boards_go()
+/* Score `board` via ./gnugo (GTP script in game.txt); returns the parsed margin, negated when the winner is 'W', or 0 on failure. */
+float score_game(float *board)
 {
-    moves m = load_go_moves("/home/pjreddie/go.train");
-    int i;
-    float board[361];
-    for(i = 0; i < 10; ++i){
-        printf("%d %d\n", m.data[i][0], m.data[i][1]);
-        string_to_board(m.data[i]+2, board);
-        print_board(board, 1, 0);
+    FILE *f = fopen("game.txt", "w");
+    if(!f){ fprintf(stderr, "Couldn't open file: game.txt\n"); return 0; }
+    int i, j, count = 3;  // count starts at the three setup commands written below
+    fprintf(f, "komi 6.5\n");
+    fprintf(f, "boardsize 19\n");
+    fprintf(f, "clear_board\n");
+    for(j = 0; j < 19; ++j){
+        for(i = 0; i < 19; ++i){
+            if(board[j*19 + i] == 1) fprintf(f, "play black %c%d\n", 'A'+i+(i>=8), 19-j);
+            if(board[j*19 + i] == -1) fprintf(f, "play white %c%d\n", 'A'+i+(i>=8), 19-j);
+            if(board[j*19 + i]) ++count;
+        }
+    }
+    fprintf(f, "final_score\n");
+    fclose(f);
+    FILE *p = popen("./gnugo --mode gtp < game.txt", "r");
+    if(!p){ fprintf(stderr, "Couldn't run ./gnugo\n"); return 0; }
+    // Skip the replies to the `count` commands before final_score, read as two lines each.
+    for(i = 0; i < 2*count; ++i) free(fgetl(p));
+    char *l = 0;
+    float score = 0;
+    char player = 0;
+    while((l = fgetl(p))){
+        fprintf(stderr, "%s  \t", l);
+        int n = sscanf(l, "= %c+%f", &player, &score);
+        free(l);
+        if (n == 2) break;
+    }
+    if(player == 'W') score = -score;
+    pclose(p);
+    return score;
+}
+
+/* Self-play loop: alternates generate_move between two networks and prints recorded positions after each game; never returns. */
+void self_go(char *filename, char *weightfile, char *f2, char *w2, int multi)
+{
+    network net = parse_network_cfg(filename);
+    if(weightfile){
+        load_weights(&net, weightfile);
     }
 
+    network net2 = net;  // the second player reuses net unless f2 names another cfg
+    if(f2){
+        net2 = parse_network_cfg(f2);
+        if(w2){
+            load_weights(&net2, w2);
+        }
+    }
+    srand(time(0));
+    char boards[300][93];  // one record per move: row byte, col byte, then board_to_string output
+    int count = 0;
+    set_batch_network(&net, 1);
+    set_batch_network(&net2, 1);
+    float *board = calloc(19*19, sizeof(float));
+    char *one = calloc(91, sizeof(char));
+    char *two = calloc(91, sizeof(char));
+    int done = 0;
+    int player = 1;
+    int p1 = 0;
+    int p2 = 0;
+    int total = 0;
+    while(1){
+        if (done || count >= 300){  // game ended, or the boards buffer is full
+            float score = score_game(board);
+            int i = (score > 0)? 0 : 1;  // first record to print; the stride-2 loop below prints every other move
+            if((score > 0) == (total%2==0)) ++p1;
+            else ++p2;
+            ++total;
+            fprintf(stderr, "Total: %d, Player 1: %f, Player 2: %f\n", total, (float)p1/total, (float)p2/total);
+            int j;
+            for(; i < count; i += 2){
+                for(j = 0; j < 93; ++j){
+                    printf("%c", boards[i][j]);
+                }
+                printf("\n");
+            }
+            memset(board, 0, 19*19*sizeof(float));  // reset state for the next game
+            player = 1;
+            done = 0;
+            count = 0;
+            fflush(stdout);
+            fflush(stderr);
+        }
+        //print_board(board, 1, 0);
+        //sleep(1);
+        network use = ((total%2==0) == (player==1)) ? net : net2;  // net moves first in even-numbered games, net2 in odd ones
+        int index = generate_move(use, player, board, multi, .1, .7, two, 0);
+        if(index < 0){
+            done = 1;
+            continue;
+        }
+        int row = index / 19;
+        int col = index % 19;
+        char *swap = two;
+        two = one;
+        one = swap;
+        // Record the move and the pre-move board; for player < 0 the board is flipped for the record, then restored.
+        if(player < 0) flip_board(board);
+        boards[count][0] = row;
+        boards[count][1] = col;
+        board_to_string(boards[count] + 2, board);
+        if(player < 0) flip_board(board);
+        ++count;
+
+        move_go(board, player, row, col);
+        board_to_string(one, board);
+
+        player = -player;
+    }
 }
 
 void run_go(int argc, char **argv)
@@ -611,8 +822,12 @@
 
     char *cfg = argv[3];
     char *weights = (argc > 4) ? argv[4] : 0;
+    char *c2 = (argc > 5) ? argv[5] : 0;
+    char *w2 = (argc > 6) ? argv[6] : 0;
     int multi = find_arg(argc, argv, "-multi");
     if(0==strcmp(argv[2], "train")) train_go(cfg, weights);
+    else if(0==strcmp(argv[2], "valid")) valid_go(cfg, weights, multi);
+    else if(0==strcmp(argv[2], "self")) self_go(cfg, weights, c2, w2, multi);
     else if(0==strcmp(argv[2], "test")) test_go(cfg, weights, multi);
     else if(0==strcmp(argv[2], "engine")) engine_go(cfg, weights, multi);
 }

--
Gitblit v1.10.0