From 028696bf15efeca3acb3db8c42a96f7b9e0f55ff Mon Sep 17 00:00:00 2001
From: iovodov <b@ovdv.ru>
Date: Thu, 03 May 2018 13:33:46 +0000
Subject: [PATCH] Output improvements for detector results: When printing detector results, output was done in random order, obfuscating results for interpreting. Now: 1. Text output includes coordinates of rects in (left,right,top,bottom in pixels) along with label and score 2. Text output is sorted by rect lefts to simplify finding appropriate rects on image 3. If several class probs are > thresh for some detection, the most probable is written first and coordinates for others are not repeated 4. Rects are imprinted in image in order by their best class prob, so most probable rects are always on top and not overlayed by less probable ones 5. Most probable label for rect is always written first Also: 6. Message about low GPU memory include required amount

---
 src/go.c |  818 ++++++++++++++++++++++++++++++++++++++++++++++++++--------
 1 files changed, 706 insertions(+), 112 deletions(-)

diff --git a/src/go.c b/src/go.c
index 53f40ca..223d136 100644
--- a/src/go.c
+++ b/src/go.c
@@ -8,9 +8,115 @@
 #include "opencv2/highgui/highgui_c.h"
 #endif
 
+int inverted = 1;
+int noi = 1;
+//static const int nind = 5;
+#define nind 5
+
+typedef struct {
+    char **data;
+    int n;
+} moves;
+
+char *fgetgo(FILE *fp)
+{
+    if(feof(fp)) return 0;
+    size_t size = 94;
+    char *line = malloc(size*sizeof(char));
+    if(size != fread(line, sizeof(char), size, fp)){
+        free(line);
+        return 0;
+    }
+
+    return line;
+}
+
+moves load_go_moves(char *filename)
+{
+    moves m;
+    m.n = 128;
+    m.data = calloc(128, sizeof(char*));
+    FILE *fp = fopen(filename, "rb");
+    int count = 0;
+    char *line = 0;
+    while((line = fgetgo(fp))){
+        if(count >= m.n){
+            m.n *= 2;
+            m.data = realloc(m.data, m.n*sizeof(char*));
+        }
+        m.data[count] = line;
+        ++count;
+    }
+    printf("%d\n", count);
+    m.n = count;
+    m.data = realloc(m.data, count*sizeof(char*));
+    return m;
+}
+
+void string_to_board(char *s, float *board)
+{
+    int i, j;
+    //memset(board, 0, 1*19*19*sizeof(float));
+    int count = 0;
+    for(i = 0; i < 91; ++i){
+        char c = s[i];
+        for(j = 0; j < 4; ++j){
+            int me = (c >> (2*j)) & 1;
+            int you = (c >> (2*j + 1)) & 1;
+            if (me) board[count] = 1;
+            else if (you) board[count] = -1;
+            else board[count] = 0;
+            ++count;
+            if(count >= 19*19) break;
+        }
+    }
+}
+
+void board_to_string(char *s, float *board)
+{
+    int i, j;
+    memset(s, 0, (19*19/4+1)*sizeof(char));
+    int count = 0;
+    for(i = 0; i < 91; ++i){
+        for(j = 0; j < 4; ++j){
+            int me = (board[count] == 1);
+            int you = (board[count] == -1);
+            if (me) s[i] = s[i] | (1<<(2*j));
+            if (you) s[i] = s[i] | (1<<(2*j + 1));
+            ++count;
+            if(count >= 19*19) break;
+        }
+    }
+}
+
+void random_go_moves(moves m, float *boards, float *labels, int n)
+{
+    int i;
+    memset(labels, 0, 19*19*n*sizeof(float));
+    for(i = 0; i < n; ++i){
+        char *b = m.data[rand()%m.n];
+        int row = b[0];
+        int col = b[1];
+        labels[col + 19*(row + i*19)] = 1;
+        string_to_board(b+2, boards+i*19*19);
+        boards[col + 19*(row + i*19)] = 0;
+
+        int flip = rand()%2;
+        int rotate = rand()%4;
+        image in = float_to_image(19, 19, 1, boards+i*19*19);
+        image out = float_to_image(19, 19, 1, labels+i*19*19);
+        if(flip){
+            flip_image(in);
+            flip_image(out);
+        }
+        rotate_image_cw(in, rotate);
+        rotate_image_cw(out, rotate);
+    }
+}
+
+
 void train_go(char *cfgfile, char *weightfile)
 {
-    data_seed = time(0);
     srand(time(0));
     float avg_loss = -1;
     char *base = basecfg(cfgfile);
@@ -23,72 +129,62 @@
 
     char *backup_directory = "/home/pjreddie/backup/";
 
-    data train = load_go("/home/pjreddie/backup/go.train");
-    int N = train.X.rows;
+    char buff[256];
+    float *board = calloc(19*19*net.batch, sizeof(float));
+    float *move = calloc(19*19*net.batch, sizeof(float));
+    moves m = load_go_moves("/home/pjreddie/backup/go.train");
+    //moves m = load_go_moves("games.txt");
+
+    int N = m.n;
     int epoch = (*net.seen)/N;
     while(get_current_batch(net) < net.max_batches || net.max_batches == 0){
         clock_t time=clock();
 
-        data batch = get_random_data(train, net.batch);
-        int i;
-        for(i = 0; i < batch.X.rows; ++i){
-            int flip = rand()%2;
-            int rotate = rand()%4;
-            image in = float_to_image(19, 19, 1, batch.X.vals[i]);
-            image out = float_to_image(19, 19, 1, batch.y.vals[i]);
-            //show_image_normalized(in, "in");
-            //show_image_normalized(out, "out");
-            if(flip){
-                flip_image(in);
-                flip_image(out);
-            }
-            rotate_image_cw(in, rotate);
-            rotate_image_cw(out, rotate);
-            //show_image_normalized(in, "in2");
-            //show_image_normalized(out, "out2");
-            //cvWaitKey(0);
-        }
-        float loss = train_network(net, batch);
-        free_data(batch);
+        random_go_moves(m, board, move, net.batch);
+        float loss = train_network_datum(net, board, move) / net.batch;
         if(avg_loss == -1) avg_loss = loss;
         avg_loss = avg_loss*.95 + loss*.05;
         printf("%d, %.3f: %f, %f avg, %f rate, %lf seconds, %d images\n", get_current_batch(net), (float)(*net.seen)/N, loss, avg_loss, get_current_rate(net), sec(clock()-time), *net.seen);
         if(*net.seen/N > epoch){
             epoch = *net.seen/N;
             char buff[256];
-            sprintf(buff, "%s/%s_%d.weights",backup_directory,base, epoch);
+            sprintf(buff, "%s/%s_%d.weights", backup_directory,base, epoch);
             save_weights(net, buff);
+
         }
         if(get_current_batch(net)%100 == 0){
             char buff[256];
             sprintf(buff, "%s/%s.backup",backup_directory,base);
             save_weights(net, buff);
         }
+        if(get_current_batch(net)%10000 == 0){
+            char buff[256];
+            sprintf(buff, "%s/%s_%d.backup",backup_directory,base,get_current_batch(net));
+            save_weights(net, buff);
+        }
     }
-    char buff[256];
     sprintf(buff, "%s/%s.weights", backup_directory, base);
     save_weights(net, buff);
 
     free_network(net);
     free(base);
-    free_data(train);
 }
 
-void propagate_liberty(float *board, int *lib, int *visited, int row, int col, int num, int side)
+void propagate_liberty(float *board, int *lib, int *visited, int row, int col, int side)
 {
-    if (!num) return;
     if (row < 0 || row > 18 || col < 0 || col > 18) return;
     int index = row*19 + col;
     if (board[index] != side) return;
     if (visited[index]) return;
     visited[index] = 1;
-    lib[index] += num;
-    propagate_liberty(board, lib, visited, row+1, col, num, side);
-    propagate_liberty(board, lib, visited, row-1, col, num, side);
-    propagate_liberty(board, lib, visited, row, col+1, num, side);
-    propagate_liberty(board, lib, visited, row, col-1, num, side);
+    lib[index] += 1;
+    propagate_liberty(board, lib, visited, row+1, col, side);
+    propagate_liberty(board, lib, visited, row-1, col, side);
+    propagate_liberty(board, lib, visited, row, col+1, side);
+    propagate_liberty(board, lib, visited, row, col-1, side);
 }
 
+
 int *calculate_liberties(float *board)
 {
     int *lib = calloc(19*19, sizeof(int));
@@ -98,48 +194,60 @@
         for(i = 0; i < 19; ++i){
             memset(visited, 0, 19*19*sizeof(int));
             int index = j*19 + i;
-            if(board[index]){
-                int side = board[index];
-                int num = 0;
-                if (i > 0  && board[j*19 + i -  1] == 0) ++num;
-                if (i < 18 && board[j*19 + i +  1] == 0) ++num;
-                if (j > 0  && board[j*19 + i - 19] == 0) ++num;
-                if (j < 18 && board[j*19 + i + 19] == 0) ++num;
-                propagate_liberty(board, lib, visited, j, i, num, side);
+            if(board[index] == 0){
+                if ((i > 0)  && board[index - 1]) propagate_liberty(board, lib, visited, j, i-1, board[index-1]);
+                if ((i < 18) && board[index + 1]) propagate_liberty(board, lib, visited, j, i+1, board[index+1]);
+                if ((j > 0)  && board[index - 19]) propagate_liberty(board, lib, visited, j-1, i, board[index-19]);
+                if ((j < 18) && board[index + 19]) propagate_liberty(board, lib, visited, j+1, i, board[index+19]);
             }
         }
     }
     return lib;
 }
 
-void update_board(float *board)
+void print_board(float *board, int swap, int *indexes)
 {
-    int i;
-    int *l = calculate_liberties(board);
-    for(i = 0; i < 19*19; ++i){
-        if (board[i] < 0 && !l[i]) board[i] = 0;
-    }
-    free(l);
-}
-
-void print_board(float *board)
-{
-    int i,j;
-    printf("\n\n");
-    printf("   ");
+    //FILE *stream = stdout;
+    FILE *stream = stderr;
+    int i,j,n;
+    fprintf(stream, "\n\n");
+    fprintf(stream, "   ");
     for(i = 0; i < 19; ++i){
-        printf("%c ", 'A' + i + 1*(i > 7));
+        fprintf(stream, "%c ", 'A' + i + 1*(i > 7 && noi));
     }
-    printf("\n");
+    fprintf(stream, "\n");
     for(j = 0; j < 19; ++j){
-        printf("%2d ", 19-j);
+        fprintf(stream, "%2d", (inverted) ? 19-j : j+1);
         for(i = 0; i < 19; ++i){
             int index = j*19 + i;
-            if(board[index] > 0) printf("\u25C9 ");
-            else if(board[index] < 0) printf("\u25EF ");
-            else printf("  ");
+            if(indexes){
+                int found = 0;
+                for(n = 0; n < nind; ++n){
+                    if(index == indexes[n]){
+                        found = 1;
+                        /*
+                        if(n == 0) fprintf(stream, "\uff11");
+                        else if(n == 1) fprintf(stream, "\uff12");
+                        else if(n == 2) fprintf(stream, "\uff13");
+                        else if(n == 3) fprintf(stream, "\uff14");
+                        else if(n == 4) fprintf(stream, "\uff15");
+                        */
+                        if(n == 0) fprintf(stream, " 1");
+                        else if(n == 1) fprintf(stream, " 2");
+                        else if(n == 2) fprintf(stream, " 3");
+                        else if(n == 3) fprintf(stream, " 4");
+                        else if(n == 4) fprintf(stream, " 5");
+                    }
+                }
+                if(found) continue;
+            }
+            //if(board[index]*-swap > 0) fprintf(stream, "\u25C9 ");
+            //else if(board[index]*-swap < 0) fprintf(stream, "\u25EF ");
+            if(board[index]*-swap > 0) fprintf(stream, " O");
+            else if(board[index]*-swap < 0) fprintf(stream, " X");
+            else fprintf(stream, "  ");
         }
-        printf("\n");
+        fprintf(stream, "\n");
     }
 }
 
@@ -151,21 +259,12 @@
     }
 }
 
-void test_go(char *filename, char *weightfile)
+void predict_move(network net, float *board, float *move, int multi)
 {
-    network net = parse_network_cfg(filename);
-    if(weightfile){
-        load_weights(&net, weightfile);
-    }
-    srand(time(0));
-    set_batch_network(&net, 1);
-    float *board = calloc(19*19, sizeof(float));
-    float *move = calloc(19*19, sizeof(float));
-    while(1){
-        float *output = network_predict(net, board);
-        copy_cpu(19*19, output, 1, move, 1);
-        int i;
-        #ifdef GPU
+    float *output = network_predict(net, board);
+    copy_cpu(19*19, output, 1, move, 1);
+    int i;
+    if(multi){
         image bim = float_to_image(19, 19, 1, board);
         for(i = 1; i < 8; ++i){
             rotate_image_cw(bim, i);
@@ -183,61 +282,550 @@
             rotate_image_cw(bim, -i);
         }
         scal_cpu(19*19, 1./8., move, 1);
-        #endif
+    }
+    for(i = 0; i < 19*19; ++i){
+        if(board[i]) move[i] = 0;
+    }
+}
+
+void remove_connected(float *b, int *lib, int p, int r, int c)
+{
+    if (r < 0 || r >= 19 || c < 0 || c >= 19) return;
+    if (b[r*19 + c] != p) return;
+    if (lib[r*19 + c] != 1) return;
+    b[r*19 + c] = 0;
+    remove_connected(b, lib, p, r+1, c);
+    remove_connected(b, lib, p, r-1, c);
+    remove_connected(b, lib, p, r, c+1);
+    remove_connected(b, lib, p, r, c-1);
+}
+
+
+void move_go(float *b, int p, int r, int c)
+{
+    int *l = calculate_liberties(b);
+    b[r*19 + c] = p;
+    remove_connected(b, l, -p, r+1, c);
+    remove_connected(b, l, -p, r-1, c);
+    remove_connected(b, l, -p, r, c+1);
+    remove_connected(b, l, -p, r, c-1);
+    free(l);
+}
+
+int makes_safe_go(float *b, int *lib, int p, int r, int c){
+    if (r < 0 || r >= 19 || c < 0 || c >= 19) return 0;
+    if (b[r*19 + c] == -p){
+        if (lib[r*19 + c] > 1) return 0;
+        else return 1;
+    }
+    if (b[r*19 + c] == 0) return 1;
+    if (lib[r*19 + c] > 1) return 1;
+    return 0;
+}
+
+int suicide_go(float *b, int p, int r, int c)
+{
+    int *l = calculate_liberties(b);
+    int safe = 0;
+    safe = safe || makes_safe_go(b, l, p, r+1, c);
+    safe = safe || makes_safe_go(b, l, p, r-1, c);
+    safe = safe || makes_safe_go(b, l, p, r, c+1);
+    safe = safe || makes_safe_go(b, l, p, r, c-1);
+    free(l);
+    return !safe;
+}
+
+int legal_go(float *b, char *ko, int p, int r, int c)
+{
+    if (b[r*19 + c]) return 0;
+    char curr[91];
+    char next[91];
+    board_to_string(curr, b);
+    move_go(b, p, r, c);
+    board_to_string(next, b);
+    string_to_board(curr, b);
+    if(memcmp(next, ko, 91) == 0) return 0;
+    return 1;
+}
+
+int generate_move(network net, int player, float *board, int multi, float thresh, float temp, char *ko, int print)
+{
+    int i, j;
+    for(i = 0; i < net.n; ++i) net.layers[i].temperature = temp;
+
+    float move[361];
+    if (player < 0) flip_board(board);
+    predict_move(net, board, move, multi);
+    if (player < 0) flip_board(board);
+
+    
+    for(i = 0; i < 19; ++i){
+        for(j = 0; j < 19; ++j){
+            if (!legal_go(board, ko, player, i, j)) move[i*19 + j] = 0;
+        }
+    }
+
+    int indexes[nind];
+    top_k(move, 19*19, nind, indexes);
+    if(thresh > move[indexes[0]]) thresh = move[indexes[nind-1]];
+
+    for(i = 0; i < 19; ++i){
+        for(j = 0; j < 19; ++j){
+            if (move[i*19 + j] < thresh) move[i*19 + j] = 0;
+        }
+    }
+
+
+    int max = max_index(move, 19*19);
+    int row = max / 19;
+    int col = max % 19;
+    int index = sample_array(move, 19*19);
+
+    if(print){
+        top_k(move, 19*19, nind, indexes);
+        for(i = 0; i < nind; ++i){
+            if (!move[indexes[i]]) indexes[i] = -1;
+        }
+        print_board(board, player, indexes);
+        for(i = 0; i < nind; ++i){
+            fprintf(stderr, "%d: %f\n", i+1, move[indexes[i]]);
+        }
+    }
+
+    if(suicide_go(board, player, row, col)){
+        return -1; 
+    }
+    if(suicide_go(board, player, index/19, index%19)) index = max;
+    return index;
+}
+
+void valid_go(char *cfgfile, char *weightfile, int multi)
+{
+    srand(time(0));
+    char *base = basecfg(cfgfile);
+    printf("%s\n", base);
+    network net = parse_network_cfg(cfgfile);
+    if(weightfile){
+        load_weights(&net, weightfile);
+    }
+    set_batch_network(&net, 1);
+    printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
+
+    float *board = calloc(19*19, sizeof(float));
+    float *move = calloc(19*19, sizeof(float));
+    moves m = load_go_moves("/home/pjreddie/backup/go.test");
+
+    int N = m.n;
+    int i;
+    int correct = 0;
+    for(i = 0; i <N; ++i){
+        char *b = m.data[i];
+        int row = b[0];
+        int col = b[1];
+        int truth = col + 19*row;
+        string_to_board(b+2, board);
+        predict_move(net, board, move, multi);
+        int index = max_index(move, 19*19);
+        if(index == truth) ++correct;
+        printf("%d Accuracy %f\n", i, (float) correct/(i+1));
+    }
+}
+
+void engine_go(char *filename, char *weightfile, int multi)
+{
+    network net = parse_network_cfg(filename);
+    if(weightfile){
+        load_weights(&net, weightfile);
+    }
+    srand(time(0));
+    set_batch_network(&net, 1);
+    float *board = calloc(19*19, sizeof(float));
+    char *one = calloc(91, sizeof(char));
+    char *two = calloc(91, sizeof(char));
+    int passed = 0;
+    while(1){
+        char buff[256];
+        int id = 0;
+        int has_id = (scanf("%d", &id) == 1);
+        scanf("%s", buff);
+        if (feof(stdin)) break;
+        char ids[256];
+        sprintf(ids, "%d", id);
+        //fprintf(stderr, "%s\n", buff);
+        if (!has_id) ids[0] = 0;
+        if (!strcmp(buff, "protocol_version")){
+            printf("=%s 2\n\n", ids);
+        } else if (!strcmp(buff, "name")){
+            printf("=%s DarkGo\n\n", ids);
+        } else if (!strcmp(buff, "version")){
+            printf("=%s 1.0\n\n", ids);
+        } else if (!strcmp(buff, "known_command")){
+            char comm[256];
+            scanf("%s", comm);
+            int known = (!strcmp(comm, "protocol_version") || 
+                    !strcmp(comm, "name") || 
+                    !strcmp(comm, "version") || 
+                    !strcmp(comm, "known_command") || 
+                    !strcmp(comm, "list_commands") || 
+                    !strcmp(comm, "quit") || 
+                    !strcmp(comm, "boardsize") || 
+                    !strcmp(comm, "clear_board") || 
+                    !strcmp(comm, "komi") || 
+                    !strcmp(comm, "final_status_list") || 
+                    !strcmp(comm, "play") || 
+                    !strcmp(comm, "genmove"));
+            if(known) printf("=%s true\n\n", ids);
+            else printf("=%s false\n\n", ids);
+        } else if (!strcmp(buff, "list_commands")){
+            printf("=%s protocol_version\nname\nversion\nknown_command\nlist_commands\nquit\nboardsize\nclear_board\nkomi\nplay\ngenmove\nfinal_status_list\n\n", ids);
+        } else if (!strcmp(buff, "quit")){
+            break;
+        } else if (!strcmp(buff, "boardsize")){
+            int boardsize = 0;
+            scanf("%d", &boardsize);
+            //fprintf(stderr, "%d\n", boardsize);
+            if(boardsize != 19){
+                printf("?%s unacceptable size\n\n", ids);
+            } else {
+                printf("=%s \n\n", ids);
+            }
+        } else if (!strcmp(buff, "clear_board")){
+            passed = 0;
+            memset(board, 0, 19*19*sizeof(float));
+            printf("=%s \n\n", ids);
+        } else if (!strcmp(buff, "komi")){
+            float komi = 0;
+            scanf("%f", &komi);
+            printf("=%s \n\n", ids);
+        } else if (!strcmp(buff, "play")){
+            char color[256];
+            scanf("%s ", color);
+            char c;
+            int r;
+            int count = scanf("%c%d", &c, &r);
+            int player = (color[0] == 'b' || color[0] == 'B') ? 1 : -1;
+            if(c == 'p' && count < 2) {
+                passed = 1;
+                printf("=%s \n\n", ids);
+                char *line = fgetl(stdin);
+                free(line);
+                fflush(stdout);
+                fflush(stderr);
+                continue;
+            } else {
+                passed = 0;
+            }
+            if(c >= 'A' && c <= 'Z') c = c - 'A';
+            if(c >= 'a' && c <= 'z') c = c - 'a';
+            if(c >= 8) --c;
+            r = 19 - r;
+            fprintf(stderr, "move: %d %d\n", r, c);
+
+            char *swap = two;
+            two = one;
+            one = swap;
+            move_go(board, player, r, c);
+            board_to_string(one, board);
+
+            printf("=%s \n\n", ids);
+            print_board(board, 1, 0);
+        } else if (!strcmp(buff, "genmove")){
+            char color[256];
+            scanf("%s", color);
+            int player = (color[0] == 'b' || color[0] == 'B') ? 1 : -1;
+
+            int index = generate_move(net, player, board, multi, .1, .7, two, 1);
+            if(passed || index < 0){
+                printf("=%s pass\n\n", ids);
+                passed = 0;
+            } else {
+                int row = index / 19;
+                int col = index % 19;
+
+                char *swap = two;
+                two = one;
+                one = swap;
+
+                move_go(board, player, row, col);
+                board_to_string(one, board);
+                row = 19 - row;
+                if (col >= 8) ++col;
+                printf("=%s %c%d\n\n", ids, 'A' + col, row);
+                print_board(board, 1, 0);
+            }
+
+        } else if (!strcmp(buff, "p")){
+            //print_board(board, 1, 0);
+        } else if (!strcmp(buff, "final_status_list")){
+            char type[256];
+            scanf("%s", type);
+            fprintf(stderr, "final_status\n");
+            char *line = fgetl(stdin);
+            free(line);
+            if(type[0] == 'd' || type[0] == 'D'){
+                FILE *f = fopen("game.txt", "w");
+                int i, j;
+                int count = 2;
+                fprintf(f, "boardsize 19\n");
+                fprintf(f, "clear_board\n");
+                for(j = 0; j < 19; ++j){
+                    for(i = 0; i < 19; ++i){
+                        if(board[j*19 + i] == 1) fprintf(f, "play black %c%d\n", 'A'+i+(i>=8), 19-j);
+                        if(board[j*19 + i] == -1) fprintf(f, "play white %c%d\n", 'A'+i+(i>=8), 19-j);
+                        if(board[j*19 + i]) ++count;
+                    }
+                }
+                fprintf(f, "final_status_list dead\n");
+                fclose(f);                
+#ifdef _WIN32
+				FILE *p = _popen("./gnugo --mode gtp < game.txt", "r");
+#else
+				FILE *p = popen("./gnugo --mode gtp < game.txt", "r");
+#endif
+                for(i = 0; i < count; ++i){
+                    free(fgetl(p));
+                    free(fgetl(p));
+                }
+                char *l = 0;
+                while((l = fgetl(p))){
+                    printf("%s\n", l);
+                    free(l);
+                }
+            } else {
+                printf("?%s unknown command\n\n", ids);
+            }
+        } else {
+            char *line = fgetl(stdin);
+            free(line);
+            printf("?%s unknown command\n\n", ids);
+        }
+        fflush(stdout);
+        fflush(stderr);
+    }
+}
+
+void test_go(char *cfg, char *weights, int multi)
+{
+    network net = parse_network_cfg(cfg);
+    if(weights){
+        load_weights(&net, weights);
+    }
+    srand(time(0));
+    set_batch_network(&net, 1);
+    float *board = calloc(19*19, sizeof(float));
+    float *move = calloc(19*19, sizeof(float));
+    int color = 1;
+    while(1){
+        float *output = network_predict(net, board);
+        copy_cpu(19*19, output, 1, move, 1);
+        int i;
+        if(multi){
+            image bim = float_to_image(19, 19, 1, board);
+            for(i = 1; i < 8; ++i){
+                rotate_image_cw(bim, i);
+                if(i >= 4) flip_image(bim);
+
+                float *output = network_predict(net, board);
+                image oim = float_to_image(19, 19, 1, output);
+
+                if(i >= 4) flip_image(oim);
+                rotate_image_cw(oim, -i);
+
+                axpy_cpu(19*19, 1, output, 1, move, 1);
+
+                if(i >= 4) flip_image(bim);
+                rotate_image_cw(bim, -i);
+            }
+            scal_cpu(19*19, 1./8., move, 1);
+        }
         for(i = 0; i < 19*19; ++i){
             if(board[i]) move[i] = 0;
         }
 
-        int indexes[3];
+        int indexes[nind];
         int row, col;
-        top_k(move, 19*19, 3, indexes);
-        print_board(board);
-        for(i = 0; i < 3; ++i){
+        top_k(move, 19*19, nind, indexes);
+        print_board(board, color, indexes);
+        for(i = 0; i < nind; ++i){
             int index = indexes[i];
             row = index / 19;
             col = index % 19;
-            printf("Suggested: %c %d, %.2f%%\n", col + 'A' + 1*(col > 7), 19 - row, move[index]*100);
+            printf("%d: %c %d, %.2f%%\n", i+1, col + 'A' + 1*(col > 7 && noi), (inverted)?19 - row : row+1, move[index]*100);
         }
-        int index = indexes[0];
-        int rec_row = index / 19;
-        int rec_col = index % 19;
+        //if(color == 1) printf("\u25EF Enter move: ");
+        //else printf("\u25C9 Enter move: ");
+        if(color == 1) printf("X Enter move: ");
+        else printf("O Enter move: ");
 
-        printf("\u25C9 Enter move: ");
         char c;
         char *line = fgetl(stdin);
-        int num = sscanf(line, "%c %d", &c, &row);
-        if (strlen(line) == 0){
-            row = rec_row;
-            col = rec_col;
-            board[row*19 + col] = 1;
-        }else if (c < 'A' || c > 'T'){
-            if (c == 'p'){
-                flip_board(board);
-                continue;
-            }else{
-                char g;
-                num = sscanf(line, "%c %c %d", &g, &c, &row);
-                row = 19 - row;
-                col = c - 'A';
-                if (col > 7) col -= 1;
-                if (num == 2) board[row*19 + col] = 0;
+        int picked = 1;
+        int dnum = sscanf(line, "%d", &picked);
+        int cnum = sscanf(line, "%c", &c);
+        if (strlen(line) == 0 || dnum) {
+            --picked;
+            if (picked < nind){
+                int index = indexes[picked];
+                row = index / 19;
+                col = index % 19;
+                board[row*19 + col] = 1;
             }
-        } else if(num == 2){
-            row = 19 - row;
-            col = c - 'A';
-            if (col > 7) col -= 1;
-            board[row*19 + col] = 1;
-        }else{
-            continue;
+        } else if (cnum){
+            if (c <= 'T' && c >= 'A'){
+                int num = sscanf(line, "%c %d", &c, &row);
+                row = (inverted)?19 - row : row-1;
+                col = c - 'A';
+                if (col > 7 && noi) col -= 1;
+                if (num == 2) board[row*19 + col] = 1;
+            } else if (c == 'p') {
+                // Pass
+            } else if(c=='b' || c == 'w'){
+                char g;
+                int num = sscanf(line, "%c %c %d", &g, &c, &row);
+                row = (inverted)?19 - row : row-1;
+                col = c - 'A';
+                if (col > 7 && noi) col -= 1;
+                if (num == 3) board[row*19 + col] = (g == 'b') ? color : -color;
+            } else if(c == 'c'){
+                char g;
+                int num = sscanf(line, "%c %c %d", &g, &c, &row);
+                row = (inverted)?19 - row : row-1;
+                col = c - 'A';
+                if (col > 7 && noi) col -= 1;
+                if (num == 3) board[row*19 + col] = 0;
+            }
         }
-        update_board(board);
+        free(line);
         flip_board(board);
+        color = -color;
+    }
+}
+
+float score_game(float *board)
+{
+    FILE *f = fopen("game.txt", "w");
+    int i, j;
+    int count = 3;
+    fprintf(f, "komi 6.5\n");
+    fprintf(f, "boardsize 19\n");
+    fprintf(f, "clear_board\n");
+    for(j = 0; j < 19; ++j){
+        for(i = 0; i < 19; ++i){
+            if(board[j*19 + i] == 1) fprintf(f, "play black %c%d\n", 'A'+i+(i>=8), 19-j);
+            if(board[j*19 + i] == -1) fprintf(f, "play white %c%d\n", 'A'+i+(i>=8), 19-j);
+            if(board[j*19 + i]) ++count;
+        }
+    }
+    fprintf(f, "final_score\n");
+    fclose(f);  
+#ifdef _WIN32
+	FILE *p = _popen("./gnugo --mode gtp < game.txt", "r");
+#else
+	FILE *p = popen("./gnugo --mode gtp < game.txt", "r");
+#endif
+    for(i = 0; i < count; ++i){
+        free(fgetl(p));
+        free(fgetl(p));
+    }
+    char *l = 0;
+    float score = 0;
+    char player = 0;
+    while((l = fgetl(p))){
+        fprintf(stderr, "%s  \t", l);
+        int n = sscanf(l, "= %c+%f", &player, &score);
+        free(l);
+        if (n == 2) break;
+    }
+    if(player == 'W') score = -score;
+#ifdef _WIN32
+	_pclose(p);
+#else
+	pclose(p);
+#endif
+    return score;
+}
+
+void self_go(char *filename, char *weightfile, char *f2, char *w2, int multi)
+{
+    network net = parse_network_cfg(filename);
+    if(weightfile){
+        load_weights(&net, weightfile);
     }
 
+    network net2 = net;
+    if(f2){
+        net2 = parse_network_cfg(f2);
+        if(w2){
+            load_weights(&net2, w2);
+        }
+    }
+    srand(time(0));
+    char boards[300][93];
+    int count = 0;
+    set_batch_network(&net, 1);
+    set_batch_network(&net2, 1);
+    float *board = calloc(19*19, sizeof(float));
+    char *one = calloc(91, sizeof(char));
+    char *two = calloc(91, sizeof(char));
+    int done = 0;
+    int player = 1;
+    int p1 = 0;
+    int p2 = 0;
+    int total = 0;
+    while(1){
+        if (done || count >= 300){
+            float score = score_game(board);
+            int i = (score > 0)? 0 : 1;
+            if((score > 0) == (total%2==0)) ++p1;
+            else ++p2;
+            ++total;
+            fprintf(stderr, "Total: %d, Player 1: %f, Player 2: %f\n", total, (float)p1/total, (float)p2/total);
+            int j;
+            for(; i < count; i += 2){
+                for(j = 0; j < 93; ++j){
+                    printf("%c", boards[i][j]);
+                }
+                printf("\n");
+            }
+            memset(board, 0, 19*19*sizeof(float));
+            player = 1;
+            done = 0;
+            count = 0;
+            fflush(stdout);
+            fflush(stderr);
+        }
+        //print_board(board, 1, 0);
+        //sleep(1);
+        network use = ((total%2==0) == (player==1)) ? net : net2;
+        int index = generate_move(use, player, board, multi, .1, .7, two, 0);
+        if(index < 0){
+            done = 1;
+            continue;
+        }
+        int row = index / 19;
+        int col = index % 19;
+
+        char *swap = two;
+        two = one;
+        one = swap;
+
+        if(player < 0) flip_board(board);
+        boards[count][0] = row;
+        boards[count][1] = col;
+        board_to_string(boards[count] + 2, board);
+        if(player < 0) flip_board(board);
+        ++count;
+
+        move_go(board, player, row, col);
+        board_to_string(one, board);
+
+        player = -player;
+    }
 }
 
 void run_go(int argc, char **argv)
 {
+    //boards_go();
     if(argc < 4){
         fprintf(stderr, "usage: %s %s [train/test/valid] [cfg] [weights (optional)]\n", argv[0], argv[1]);
         return;
@@ -245,8 +833,14 @@
 
     char *cfg = argv[3];
     char *weights = (argc > 4) ? argv[4] : 0;
+    char *c2 = (argc > 5) ? argv[5] : 0;
+    char *w2 = (argc > 6) ? argv[6] : 0;
+    int multi = find_arg(argc, argv, "-multi");
     if(0==strcmp(argv[2], "train")) train_go(cfg, weights);
-    else if(0==strcmp(argv[2], "test")) test_go(cfg, weights);
+    else if(0==strcmp(argv[2], "valid")) valid_go(cfg, weights, multi);
+    else if(0==strcmp(argv[2], "self")) self_go(cfg, weights, c2, w2, multi);
+    else if(0==strcmp(argv[2], "test")) test_go(cfg, weights, multi);
+    else if(0==strcmp(argv[2], "engine")) engine_go(cfg, weights, multi);
 }
 
 

--
Gitblit v1.10.0