From 0dff437a692a5f875dd0293c628ac9172f697c69 Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Mon, 28 Mar 2016 02:10:10 +0000
Subject: [PATCH] fixed old

---
 src/go.c |  472 ++++++++++++++++++++++++++++++++++++++++++++++++++--------
 1 files changed, 401 insertions(+), 71 deletions(-)

diff --git a/src/go.c b/src/go.c
index 9d31539..8d0cf52 100644
--- a/src/go.c
+++ b/src/go.c
@@ -12,6 +12,107 @@
 int noi = 1;
 static const int nind = 5;
 
+typedef struct {
+    char **data;  // data[i] is one raw move record (load_go_moves stores fgetgo() buffers here)
+    int n;        // number of records in data once load_go_moves has returned
+} moves;
+// Read one 94-byte record from fp. Returns a malloc'd buffer (caller frees) or 0 at EOF or on a short read.
+char *fgetgo(FILE *fp)
+{
+    if(feof(fp)) return 0;
+    size_t size = 94;  // bytes per record; code in this file uses byte 0, byte 1 and 91 bytes from offset 2
+    char *line = malloc(size*sizeof(char));  // NOTE(review): result is not checked before fread; buffer is not NUL-terminated
+    if(size != fread(line, sizeof(char), size, fp)){
+        free(line);  // incomplete record: drop it
+        return 0;
+    }
+
+    return line;
+}
+// Load every fgetgo() record of filename, print the count, return the list. Exits if the file cannot be opened.
+moves load_go_moves(char *filename)
+{
+    moves m = {0, 0};
+    int cap = 128;  // allocated slots in m.data; doubled whenever the array is full
+    char *line = 0;
+    FILE *fp = fopen(filename, "rb");
+    if(!fp){ fprintf(stderr, "Couldn't open file: %s\n", filename); exit(EXIT_FAILURE); }
+    m.data = calloc(cap, sizeof(char*));
+    while(m.data && (line = fgetgo(fp))){
+        if(m.n >= cap){
+            char **grown = realloc(m.data, 2*(size_t)cap*sizeof(char*));
+            if(!grown){ free(line); break; }  // out of memory: keep the records loaded so far
+            m.data = grown;
+            cap *= 2;
+        }
+        m.data[m.n++] = line;  // the list takes ownership of the record buffer
+    }
+    fclose(fp);  // release the stream once all records have been read
+    printf("%d\n", m.n);
+    return m;
+}
+// Unpack 19*19 cells from s (2 bits per cell, 4 cells per byte, 91 bytes read) into board as 1, -1 or 0.
+void string_to_board(char *s, float *board)
+{
+    int i, j;
+    //memset(board, 0, 1*19*19*sizeof(float));
+    int count = 0;  // index of the next board cell to write
+    for(i = 0; i < 91; ++i){  // 91 bytes could hold 364 cells; only the first 361 are written
+        char c = s[i];
+        for(j = 0; j < 4; ++j){
+            int me = (c >> (2*j)) & 1;       // low bit of the j-th 2-bit pair
+            int you = (c >> (2*j + 1)) & 1;  // high bit of the j-th 2-bit pair
+            if (me) board[count] = 1;        // "me" takes precedence if both bits are set
+            else if (you) board[count] = -1;
+            else board[count] = 0;
+            ++count;
+            if(count >= 19*19) break;  // reached in the last byte (i == 90), so the outer loop ends too
+        }
+    }
+}
+// Pack a 19*19 board into s: cell k sets bit 2*(k%4) of s[k/4] if it equals 1, bit 2*(k%4)+1 if it equals -1.
+void board_to_string(char *s, float *board)
+{
+    int i, j;
+    memset(s, 0, (19*19/4+1)*sizeof(char));  // clears 91 bytes, so s must have room for at least 91
+    int count = 0;  // index of the next board cell to read
+    for(i = 0; i < 91; ++i){
+        for(j = 0; j < 4; ++j){
+            int me = (board[count] == 1);
+            int you = (board[count] == -1);
+            if (me) s[i] = s[i] | (1<<(2*j));
+            if (you) s[i] = s[i] | (1<<(2*j + 1));
+            ++count;
+            if(count >= 19*19) break;  // reached in the last byte (i == 90), so no cell past 360 is read
+        }
+    }
+}
+// Fill boards and labels (19*19 floats per sample) with n records drawn at random from m, randomly flipped/rotated.
+void random_go_moves(moves m, float *boards, float *labels, int n)
+{
+    int i;
+    memset(labels, 0, 19*19*n*sizeof(float));
+    for(i = 0; i < n; ++i){
+        char *b = m.data[rand()%m.n];  // NOTE(review): m.n == 0 would be a modulo by zero
+        int row = b[0];
+        int col = b[1];
+        labels[col + 19*(row + i*19)] = 1;  // single 1 at (row, col) of sample i; row/col are not range-checked
+        string_to_board(b+2, boards+i*19*19);  // the packed board starts at byte 2 of the record
+        // Apply the same random flip and rotation to the board and to its label.
+        int flip = rand()%2;
+        int rotate = rand()%4;
+        image in = float_to_image(19, 19, 1, boards+i*19*19);
+        image out = float_to_image(19, 19, 1, labels+i*19*19);
+        if(flip){
+            flip_image(in);
+            flip_image(out);
+        }
+        rotate_image_cw(in, rotate);
+        rotate_image_cw(out, rotate);
+    }
+}
+
+
 void train_go(char *cfgfile, char *weightfile)
 {
     data_seed = time(0);
@@ -27,79 +128,61 @@
 
     char *backup_directory = "/home/pjreddie/backup/";
 
-
     char buff[256];
-    sprintf(buff, "/home/pjreddie/go.train.%02d", rand()%10);
-    data train = load_go(buff);
+    float *board = calloc(19*19*net.batch, sizeof(float));
+    float *move = calloc(19*19*net.batch, sizeof(float));
+    moves m = load_go_moves("/home/pjreddie/go.train");
 
-    int N = train.X.rows;
+    int N = m.n;
     int epoch = (*net.seen)/N;
     while(get_current_batch(net) < net.max_batches || net.max_batches == 0){
         clock_t time=clock();
 
-        data batch = get_random_data(train, net.batch);
-        int i;
-        for(i = 0; i < batch.X.rows; ++i){
-            int flip = rand()%2;
-            int rotate = rand()%4;
-            image in = float_to_image(19, 19, 1, batch.X.vals[i]);
-            image out = float_to_image(19, 19, 1, batch.y.vals[i]);
-            //show_image_normalized(in, "in");
-            //show_image_normalized(out, "out");
-            if(flip){
-                flip_image(in);
-                flip_image(out);
-            }
-            rotate_image_cw(in, rotate);
-            rotate_image_cw(out, rotate);
-            //show_image_normalized(in, "in2");
-            //show_image_normalized(out, "out2");
-            //cvWaitKey(0);
-        }
-        float loss = train_network(net, batch);
-        free_data(batch);
+        random_go_moves(m, board, move, net.batch);
+        float loss = train_network_datum(net, board, move) / net.batch;
         if(avg_loss == -1) avg_loss = loss;
         avg_loss = avg_loss*.95 + loss*.05;
         printf("%d, %.3f: %f, %f avg, %f rate, %lf seconds, %d images\n", get_current_batch(net), (float)(*net.seen)/N, loss, avg_loss, get_current_rate(net), sec(clock()-time), *net.seen);
         if(*net.seen/N > epoch){
             epoch = *net.seen/N;
             char buff[256];
-            sprintf(buff, "%s/%s_%d.weights",backup_directory,base, epoch);
+            sprintf(buff, "%s/%s_%d.weights", backup_directory,base, epoch);
             save_weights(net, buff);
 
-            free_data(train);
-            sprintf(buff, "/home/pjreddie/go.train.%02d", epoch%10);
-            train = load_go(buff);
         }
         if(get_current_batch(net)%100 == 0){
             char buff[256];
             sprintf(buff, "%s/%s.backup",backup_directory,base);
             save_weights(net, buff);
         }
+        if(get_current_batch(net)%10000 == 0){
+            char buff[256];
+            sprintf(buff, "%s/%s_%d.backup",backup_directory,base,get_current_batch(net));
+            save_weights(net, buff);
+        }
     }
     sprintf(buff, "%s/%s.weights", backup_directory, base);
     save_weights(net, buff);
 
     free_network(net);
     free(base);
-    free_data(train);
 }
 
-void propagate_liberty(float *board, int *lib, int *visited, int row, int col, int num, int side)
+void propagate_liberty(float *board, int *lib, int *visited, int row, int col, int side)  // recursive fill from (row, col) over cells equal to side
 {
-    if (!num) return;
     if (row < 0 || row > 18 || col < 0 || col > 18) return;
     int index = row*19 + col;
     if (board[index] != side) return;
     if (visited[index]) return;
     visited[index] = 1;
-    lib[index] += num;
-    propagate_liberty(board, lib, visited, row+1, col, num, side);
-    propagate_liberty(board, lib, visited, row-1, col, num, side);
-    propagate_liberty(board, lib, visited, row, col+1, num, side);
-    propagate_liberty(board, lib, visited, row, col-1, num, side);
+    lib[index] += 1;  // each cell reached that was not already marked in visited gets exactly one increment
+    propagate_liberty(board, lib, visited, row+1, col, side);
+    propagate_liberty(board, lib, visited, row-1, col, side);
+    propagate_liberty(board, lib, visited, row, col+1, side);
+    propagate_liberty(board, lib, visited, row, col-1, side);
 }
 
+
 int *calculate_liberties(float *board)
 {
     int *lib = calloc(19*19, sizeof(int));
@@ -109,41 +192,30 @@
         for(i = 0; i < 19; ++i){
             memset(visited, 0, 19*19*sizeof(int));
             int index = j*19 + i;
-            if(board[index]){
-                int side = board[index];
-                int num = 0;
-                if (i > 0  && board[j*19 + i -  1] == 0) ++num;
-                if (i < 18 && board[j*19 + i +  1] == 0) ++num;
-                if (j > 0  && board[j*19 + i - 19] == 0) ++num;
-                if (j < 18 && board[j*19 + i + 19] == 0) ++num;
-                propagate_liberty(board, lib, visited, j, i, num, side);
+            if(board[index] == 0){
+                if ((i > 0)  && board[index - 1]) propagate_liberty(board, lib, visited, j, i-1, board[index-1]);
+                if ((i < 18) && board[index + 1]) propagate_liberty(board, lib, visited, j, i+1, board[index+1]);
+                if ((j > 0)  && board[index - 19]) propagate_liberty(board, lib, visited, j-1, i, board[index-19]);
+                if ((j < 18) && board[index + 19]) propagate_liberty(board, lib, visited, j+1, i, board[index+19]);
             }
         }
     }
     return lib;
 }
 
-void update_board(float *board)
-{
-    int i;
-    int *l = calculate_liberties(board);
-    for(i = 0; i < 19*19; ++i){
-        if (board[i] < 0 && !l[i]) board[i] = 0;
-    }
-    free(l);
-}
-
 void print_board(float *board, int swap, int *indexes)
 {
+    //FILE *stream = stdout;
+    FILE *stream = stderr;
     int i,j,n;
-    printf("\n\n");
-    printf("   ");
+    fprintf(stream, "\n\n");
+    fprintf(stream, "   ");
     for(i = 0; i < 19; ++i){
-        printf("%c ", 'A' + i + 1*(i > 7 && noi));
+        fprintf(stream, "%c ", 'A' + i + 1*(i > 7 && noi));
     }
-    printf("\n");
+    fprintf(stream, "\n");
     for(j = 0; j < 19; ++j){
-        printf("%2d ", (inverted) ? 19-j : j+1);
+        fprintf(stream, "%2d ", (inverted) ? 19-j : j+1);
         for(i = 0; i < 19; ++i){
             int index = j*19 + i;
             if(indexes){
@@ -151,20 +223,20 @@
                 for(n = 0; n < nind; ++n){
                     if(index == indexes[n]){
                         found = 1;
-                        if(n == 0) printf("\uff11");
-                        else if(n == 1) printf("\uff12");
-                        else if(n == 2) printf("\uff13");
-                        else if(n == 3) printf("\uff14");
-                        else if(n == 4) printf("\uff15");
+                        if(n == 0) fprintf(stream, "\uff11");
+                        else if(n == 1) fprintf(stream, "\uff12");
+                        else if(n == 2) fprintf(stream, "\uff13");
+                        else if(n == 3) fprintf(stream, "\uff14");
+                        else if(n == 4) fprintf(stream, "\uff15");
                     }
                 }
                 if(found) continue;
             }
-            if(board[index]*-swap > 0) printf("\u25C9 ");
-            else if(board[index]*-swap < 0) printf("\u25EF ");
-            else printf("  ");
+            if(board[index]*-swap > 0) fprintf(stream, "\u25C9 ");
+            else if(board[index]*-swap < 0) fprintf(stream, "\u25EF ");
+            else fprintf(stream, "  ");
         }
-        printf("\n");
+        fprintf(stream, "\n");
     }
 }
 
@@ -176,7 +248,96 @@
     }
 }
 
-void test_go(char *filename, char *weightfile, int multi)
+void predict_move(network net, float *board, float *move, int multi)  // writes 19*19 scores for board into move
+{
+    float *output = network_predict(net, board);
+    copy_cpu(19*19, output, 1, move, 1);
+    int i;
+    if(multi){  // also predict on 7 rotated/flipped variants of the board and combine all 8 predictions
+        image bim = float_to_image(19, 19, 1, board);  // presumably wraps board without copying (board is re-read below) -- confirm
+        for(i = 1; i < 8; ++i){
+            rotate_image_cw(bim, i);
+            if(i >= 4) flip_image(bim);
+            // Predict on the transformed board.
+            float *output = network_predict(net, board);
+            image oim = float_to_image(19, 19, 1, output);
+            // Undo the transform on the prediction in reverse order: flip first, then rotate back.
+            if(i >= 4) flip_image(oim);
+            rotate_image_cw(oim, -i);
+            // Add this prediction into move (assuming BLAS-style axpy: move += 1*output).
+            axpy_cpu(19*19, 1, output, 1, move, 1);
+            // Put board back into its original orientation before the next iteration.
+            if(i >= 4) flip_image(bim);
+            rotate_image_cw(bim, -i);
+        }
+        scal_cpu(19*19, 1./8., move, 1);  // scale the sum of 8 predictions by 1/8 (assuming BLAS-style scal)
+    }
+    for(i = 0; i < 19*19; ++i){
+        if(board[i]) move[i] = 0;  // occupied points get score 0
+    }
+}
+// Recursively clear cells equal to p that are connected to (r, c) and whose lib[] entry is exactly 1.
+void remove_connected(float *b, int *lib, int p, int r, int c)
+{
+    if (r < 0 || r >= 19 || c < 0 || c >= 19) return;  // off the board
+    if (b[r*19 + c] != p) return;  // not a p cell; this also stops the recursion at cells already cleared
+    if (lib[r*19 + c] != 1) return;
+    b[r*19 + c] = 0;
+    remove_connected(b, lib, p, r+1, c);
+    remove_connected(b, lib, p, r-1, c);
+    remove_connected(b, lib, p, r, c+1);
+    remove_connected(b, lib, p, r, c-1);
+}
+
+// Set (r, c) to p, then run remove_connected() for -p on the four neighbouring points.
+void move_go(float *b, int p, int r, int c)
+{
+    int *l = calculate_liberties(b);  // computed before the new stone is written to b
+    b[r*19 + c] = p;  // r and c are not range-checked here
+    remove_connected(b, l, -p, r+1, c);
+    remove_connected(b, l, -p, r-1, c);
+    remove_connected(b, l, -p, r, c+1);
+    remove_connected(b, l, -p, r, c-1);
+    free(l);
+}
+// Per-neighbour test used by suicide_go(): returns 1 if point (r, c) counts as "safe" for player p, else 0.
+int makes_safe_go(float *b, int *lib, int p, int r, int c){
+    if (r < 0 || r >= 19 || c < 0 || c >= 19) return 0;  // off-board points never count
+    if (b[r*19 + c] == -p){  // cell holds -p
+        if (lib[r*19 + c] > 1) return 0;
+        else return 1;  // lib entry is 1 or less
+    }
+    if (b[r*19 + c] == 0) return 1;  // empty point
+    if (lib[r*19 + c] > 1) return 1;  // remaining case: cell is neither -p nor 0
+    return 0;
+}
+// Returns 1 when none of the four neighbours of (r, c) passes makes_safe_go() for player p, else 0.
+int suicide_go(float *b, int p, int r, int c)
+{
+    int *l = calculate_liberties(b);  // b itself is not modified by this function
+    int safe = 0;
+    safe = safe || makes_safe_go(b, l, p, r+1, c);  // || short-circuits once one neighbour is safe
+    safe = safe || makes_safe_go(b, l, p, r-1, c);
+    safe = safe || makes_safe_go(b, l, p, r, c+1);
+    safe = safe || makes_safe_go(b, l, p, r, c-1);
+    free(l);
+    return !safe;
+}
+// Returns 0 if (r, c) is occupied or if playing there produces the packed position ko; otherwise returns 1.
+int legal_go(float *b, char *ko, int p, int r, int c)
+{
+    if (b[r*19 + c]) return 0;
+    char curr[91];  // packed snapshot of b before the trial move
+    char next[91];  // packed position after the trial move
+    board_to_string(curr, b);
+    move_go(b, p, r, c);  // plays the trial move in place on b
+    board_to_string(next, b);
+    string_to_board(curr, b);  // restore b from the snapshot
+    if(memcmp(next, ko, 91) == 0) return 0;  // ko must point to at least 91 bytes
+    return 1;
+}
+
+void engine_go(char *filename, char *weightfile, int multi)
 {
     network net = parse_network_cfg(filename);
     if(weightfile){
@@ -186,6 +347,162 @@
     set_batch_network(&net, 1);
     float *board = calloc(19*19, sizeof(float));
     float *move = calloc(19*19, sizeof(float));
+    char *one = calloc(91, sizeof(char));
+    char *two = calloc(91, sizeof(char));
+    int passed = 0;
+    while(1){
+        print_board(board, 1, 0);
+        char buff[256];
+        int id = 0;
+        int has_id = (scanf("%d", &id) == 1);
+        scanf("%s", buff);
+        if (feof(stdin)) break;
+        char ids[256];
+        sprintf(ids, "%d", id);
+        //fprintf(stderr, "%s\n", buff);
+        if (!has_id) ids[0] = 0;
+        if (!strcmp(buff, "protocol_version")){
+            printf("=%s 2\n\n", ids);
+        } else if (!strcmp(buff, "name")){
+            printf("=%s DarkGo\n\n", ids);
+        } else if (!strcmp(buff, "version")){
+            printf("=%s 1.0\n\n", ids);
+        } else if (!strcmp(buff, "known_command")){
+            char comm[256];
+            scanf("%s", comm);
+            int known = (!strcmp(comm, "protocol_version") || 
+                    !strcmp(comm, "name") || 
+                    !strcmp(comm, "version") || 
+                    !strcmp(comm, "known_command") || 
+                    !strcmp(comm, "list_commands") || 
+                    !strcmp(comm, "quit") || 
+                    !strcmp(comm, "boardsize") || 
+                    !strcmp(comm, "clear_board") || 
+                    !strcmp(comm, "komi") || 
+                    !strcmp(comm, "final_status_list") || 
+                    !strcmp(comm, "play") || 
+                    !strcmp(comm, "genmove"));
+            if(known) printf("=%s true\n\n", ids);
+            else printf("=%s false\n\n", ids);
+        } else if (!strcmp(buff, "list_commands")){
+            printf("=%s protocol_version\nname\nversion\nknown_command\nlist_commands\nquit\nboardsize\nclear_board\nkomi\nplay\ngenmove\nfinal_status_list\n\n", ids);
+        } else if (!strcmp(buff, "quit")){
+            break;
+        } else if (!strcmp(buff, "boardsize")){
+            int boardsize = 0;
+            scanf("%d", &boardsize);
+            //fprintf(stderr, "%d\n", boardsize);
+            if(boardsize != 19){
+                printf("?%s unacceptable size\n\n", ids);
+            } else {
+                printf("=%s \n\n", ids);
+            }
+        } else if (!strcmp(buff, "clear_board")){
+            passed = 0;
+            memset(board, 0, 19*19*sizeof(float));
+            printf("=%s \n\n", ids);
+        } else if (!strcmp(buff, "komi")){
+            float komi = 0;
+            scanf("%f", &komi);
+            printf("=%s \n\n", ids);
+        } else if (!strcmp(buff, "play")){
+            char color[256];
+            scanf("%s ", color);
+            char c;
+            int r;
+            int count = scanf("%c%d", &c, &r);
+            int player = (color[0] == 'b' || color[0] == 'B') ? 1 : -1;
+            if(c == 'p' && count < 2) {
+                passed = 1;
+                printf("=%s \n\n", ids);
+                char *line = fgetl(stdin);
+                free(line);
+                fflush(stdout);
+                fflush(stderr);
+                continue;
+            } else {
+                passed = 0;
+            }
+            if(c >= 'A' && c <= 'Z') c = c - 'A';
+            if(c >= 'a' && c <= 'z') c = c - 'a';
+            if(c >= 8) --c;
+            r = 19 - r;
+            fprintf(stderr, "move: %d %d\n", r, c);
+
+            char *swap = two;
+            two = one;
+            one = swap;
+            move_go(board, player, r, c);
+            board_to_string(one, board);
+
+            printf("=%s \n\n", ids);
+        } else if (!strcmp(buff, "genmove")){
+            char color[256];
+            scanf("%s", color);
+            int player = (color[0] == 'b' || color[0] == 'B') ? 1 : -1;
+
+            if(player < 0) flip_board(board);
+            predict_move(net, board, move, multi);
+            if(player < 0) flip_board(board);
+
+            int i, j;
+            for(i = 0; i < 19; ++i){
+                for(j = 0; j < 19; ++j){
+                    if (!legal_go(board, two, player, i, j)) move[i*19 + j] = 0;
+                }
+            }
+            int index = max_index(move, 19*19);
+            int row = index / 19;
+            char col = index % 19;
+
+            char *swap = two;
+            two = one;
+            one = swap;
+
+            if(passed || suicide_go(board, player, row, col)){
+                printf("=%s pass\n\n", ids);
+                passed = 0;
+            } else {
+                move_go(board, player, row, col);
+                board_to_string(one, board);
+
+                row = 19 - row;
+                if (col >= 8) ++col;
+                printf("=%s %c%d\n\n", ids, 'A' + col, row);
+            }
+        } else if (!strcmp(buff, "p")){
+            print_board(board, 1, 0);
+        } else if (!strcmp(buff, "final_status_list")){
+            char type[256];
+            scanf("%s", type);
+            fprintf(stderr, "final_status\n");
+            char *line = fgetl(stdin);
+            free(line);
+            if(type[0] == 'd' || type[0] == 'D'){
+                printf("=%s \n\n", ids);
+            } else {
+                printf("?%s unknown command\n\n", ids);
+            }
+        } else {
+            char *line = fgetl(stdin);
+            free(line);
+            printf("?%s unknown command\n\n", ids);
+        }
+        fflush(stdout);
+        fflush(stderr);
+    }
+}
+
+void test_go(char *cfg, char *weights, int multi)
+{
+    network net = parse_network_cfg(cfg);
+    if(weights){
+        load_weights(&net, weights);
+    }
+    srand(time(0));
+    set_batch_network(&net, 1);
+    float *board = calloc(19*19, sizeof(float));
+    float *move = calloc(19*19, sizeof(float));
     int color = 1;
     while(1){
         float *output = network_predict(net, board);
@@ -266,15 +583,27 @@
             }
         }
         free(line);
-        update_board(board);
         flip_board(board);
         color = -color;
     }
+}
+// Load /home/pjreddie/go.train and print the two move bytes and the board of its first 10 records.
+void boards_go()
+{
+    moves m = load_go_moves("/home/pjreddie/go.train");
+    int i;
+    float board[361];
+    for(i = 0; i < 10; ++i){  // NOTE(review): assumes at least 10 records were loaded; m.n is not checked
+        printf("%d %d\n", m.data[i][0], m.data[i][1]);
+        string_to_board(m.data[i]+2, board);  // the packed board starts at byte 2 of the record
+        print_board(board, 1, 0);
+    }
 
 }
 
 void run_go(int argc, char **argv)
 {
+    //boards_go();
     if(argc < 4){
         fprintf(stderr, "usage: %s %s [train/test/valid] [cfg] [weights (optional)]\n", argv[0], argv[1]);
         return;
@@ -285,6 +614,7 @@
     int multi = find_arg(argc, argv, "-multi");
     if(0==strcmp(argv[2], "train")) train_go(cfg, weights);
     else if(0==strcmp(argv[2], "test")) test_go(cfg, weights, multi);
+    else if(0==strcmp(argv[2], "engine")) engine_go(cfg, weights, multi);
 }
 
 

--
Gitblit v1.10.0