From ae43c2bc32fbb838bfebeeaf2c2b058ccab5c83c Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@burninator.cs.washington.edu>
Date: Thu, 23 Jun 2016 05:31:14 +0000
Subject: [PATCH] hi
---
src/go.c | 289 ++++++++++++++++++++++++++++++++++++++++++++++++++-------
1 files changed, 252 insertions(+), 37 deletions(-)
diff --git a/src/go.c b/src/go.c
index 8d0cf52..91beaf1 100644
--- a/src/go.c
+++ b/src/go.c
@@ -98,6 +98,7 @@
int col = b[1];
labels[col + 19*(row + i*19)] = 1;
string_to_board(b+2, boards+i*19*19);
+ boards[col + 19*(row + i*19)] = 0;
int flip = rand()%2;
int rotate = rand()%4;
@@ -132,6 +133,7 @@
float *board = calloc(19*19*net.batch, sizeof(float));
float *move = calloc(19*19*net.batch, sizeof(float));
moves m = load_go_moves("/home/pjreddie/go.train");
+ //moves m = load_go_moves("games.txt");
int N = m.n;
int epoch = (*net.seen)/N;
@@ -215,7 +217,7 @@
}
fprintf(stream, "\n");
for(j = 0; j < 19; ++j){
- fprintf(stream, "%2d ", (inverted) ? 19-j : j+1);
+ fprintf(stream, "%2d", (inverted) ? 19-j : j+1);
for(i = 0; i < 19; ++i){
int index = j*19 + i;
if(indexes){
@@ -223,17 +225,26 @@
for(n = 0; n < nind; ++n){
if(index == indexes[n]){
found = 1;
+ /*
if(n == 0) fprintf(stream, "\uff11");
else if(n == 1) fprintf(stream, "\uff12");
else if(n == 2) fprintf(stream, "\uff13");
else if(n == 3) fprintf(stream, "\uff14");
else if(n == 4) fprintf(stream, "\uff15");
+ */
+ if(n == 0) fprintf(stream, " 1");
+ else if(n == 1) fprintf(stream, " 2");
+ else if(n == 2) fprintf(stream, " 3");
+ else if(n == 3) fprintf(stream, " 4");
+ else if(n == 4) fprintf(stream, " 5");
}
}
if(found) continue;
}
- if(board[index]*-swap > 0) fprintf(stream, "\u25C9 ");
- else if(board[index]*-swap < 0) fprintf(stream, "\u25EF ");
+ //if(board[index]*-swap > 0) fprintf(stream, "\u25C9 ");
+ //else if(board[index]*-swap < 0) fprintf(stream, "\u25EF ");
+ if(board[index]*-swap > 0) fprintf(stream, " O");
+ else if(board[index]*-swap < 0) fprintf(stream, " X");
else fprintf(stream, " ");
}
fprintf(stream, "\n");
@@ -337,6 +348,90 @@
return 1;
}
+/*
+ * Chooses a move for `player` from the network's move predictions.
+ *
+ * net    - network to query; every layer's temperature is set to `temp` first.
+ * player - side to move; when player < 0 the board is flipped around the
+ *          predict_move call and flipped back right after it.
+ * board  - position to move on; predictions for it are indexed row*19 + col.
+ * multi  - forwarded unchanged to predict_move.
+ * thresh - predictions below this value are zeroed before sampling. If it is
+ *          larger than move[indexes[0]] as returned by top_k, it is lowered
+ *          to move[indexes[nind-1]] instead.
+ * temp   - temperature assigned to each layer of `net`.
+ * ko     - forwarded to legal_go for every point (presumably the position a
+ *          move may not recreate; confirm against legal_go).
+ * print  - when non-zero, shows the board through print_board and writes the
+ *          nind top candidates to stderr.
+ *
+ * Returns -1 when the highest-valued remaining move is rejected by suicide_go.
+ * Otherwise returns a board index (row*19 + col) drawn with sample_array,
+ * falling back to the highest-valued move if the sampled one is rejected by
+ * suicide_go.
+ *
+ * NOTE: `nind` is not declared in this function; it has to come from file scope.
+ */
+int generate_move(network net, int player, float *board, int multi, float thresh, float temp, char *ko, int print)
+{
+    int i, j;
+    for(i = 0; i < net.n; ++i) net.layers[i].temperature = temp;
+
+    float move[361];
+    if (player < 0) flip_board(board);
+    predict_move(net, board, move, multi);
+    if (player < 0) flip_board(board);
+
+    /* Zero every point that legal_go refuses for this player. */
+    for(i = 0; i < 19; ++i){
+        for(j = 0; j < 19; ++j){
+            if (!legal_go(board, ko, player, i, j)) move[i*19 + j] = 0;
+        }
+    }
+
+    int indexes[nind];
+    top_k(move, 19*19, nind, indexes);
+    if(thresh > move[indexes[0]]) thresh = move[indexes[nind-1]];
+
+    /* Drop everything under the (possibly lowered) threshold. */
+    for(i = 0; i < 19; ++i){
+        for(j = 0; j < 19; ++j){
+            if (move[i*19 + j] < thresh) move[i*19 + j] = 0;
+        }
+    }
+
+    int max = max_index(move, 19*19);
+    int row = max / 19;
+    int col = max % 19;
+    int index = sample_array(move, 19*19);
+
+    if(print){
+        top_k(move, 19*19, nind, indexes);
+        for(i = 0; i < nind; ++i){
+            /* Candidates whose value was zeroed are hidden from print_board. */
+            if (!move[indexes[i]]) indexes[i] = -1;
+        }
+        print_board(board, player, indexes);
+        for(i = 0; i < nind; ++i){
+            /* A slot marked -1 had value 0; report that instead of reading move[-1]. */
+            float p = (indexes[i] >= 0) ? move[indexes[i]] : 0;
+            fprintf(stderr, "%d: %f\n", i+1, p);
+        }
+    }
+
+    if(suicide_go(board, player, row, col)){
+        return -1;
+    }
+    if(suicide_go(board, player, index/19, index%19)) index = max;
+    return index;
+}
+
+/*
+ * Measures how often the network's highest-valued move equals the recorded
+ * move for the positions loaded from /home/pjreddie/backup/go.test.
+ *
+ * cfgfile    - network configuration passed to parse_network_cfg.
+ * weightfile - optional weights; loaded only when non-NULL.
+ * multi      - forwarded unchanged to predict_move.
+ *
+ * Each record is read as: byte 0 = row, byte 1 = col, board string from
+ * byte 2 on. After every record the running accuracy is printed to stdout.
+ */
+void valid_go(char *cfgfile, char *weightfile, int multi)
+{
+    data_seed = time(0);
+    srand(time(0));
+    char *base = basecfg(cfgfile);
+    printf("%s\n", base);
+    network net = parse_network_cfg(cfgfile);
+    if(weightfile){
+        load_weights(&net, weightfile);
+    }
+    set_batch_network(&net, 1);
+    printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
+
+    float *board = calloc(19*19, sizeof(float));
+    float *move = calloc(19*19, sizeof(float));
+    moves m = load_go_moves("/home/pjreddie/backup/go.test");
+
+    int N = m.n;
+    int i;
+    int correct = 0;
+    for(i = 0; i <N; ++i){
+        char *b = m.data[i];
+        int row = b[0];
+        int col = b[1];
+        int truth = col + 19*row;
+        string_to_board(b+2, board);
+        predict_move(net, board, move, multi);
+        int index = max_index(move, 19*19);
+        if(index == truth) ++correct;
+        printf("%d Accuracy %f\n", i, (float) correct/(i+1));
+    }
+
+    /* Release the scratch buffers allocated above. `base` and `m` are left
+     * alone because their ownership is not visible from this function. */
+    free(board);
+    free(move);
+}
+
void engine_go(char *filename, char *weightfile, int multi)
{
network net = parse_network_cfg(filename);
@@ -346,12 +441,10 @@
srand(time(0));
set_batch_network(&net, 1);
float *board = calloc(19*19, sizeof(float));
- float *move = calloc(19*19, sizeof(float));
char *one = calloc(91, sizeof(char));
char *two = calloc(91, sizeof(char));
int passed = 0;
while(1){
- print_board(board, 1, 0);
char buff[256];
int id = 0;
int has_id = (scanf("%d", &id) == 1);
@@ -436,42 +529,34 @@
board_to_string(one, board);
printf("=%s \n\n", ids);
+ print_board(board, 1, 0);
} else if (!strcmp(buff, "genmove")){
char color[256];
scanf("%s", color);
int player = (color[0] == 'b' || color[0] == 'B') ? 1 : -1;
- if(player < 0) flip_board(board);
- predict_move(net, board, move, multi);
- if(player < 0) flip_board(board);
-
- int i, j;
- for(i = 0; i < 19; ++i){
- for(j = 0; j < 19; ++j){
- if (!legal_go(board, two, player, i, j)) move[i*19 + j] = 0;
- }
- }
- int index = max_index(move, 19*19);
- int row = index / 19;
- char col = index % 19;
-
- char *swap = two;
- two = one;
- one = swap;
-
- if(passed || suicide_go(board, player, row, col)){
+ int index = generate_move(net, player, board, multi, .1, .7, two, 1);
+ if(passed || index < 0){
printf("=%s pass\n\n", ids);
passed = 0;
} else {
+ int row = index / 19;
+ int col = index % 19;
+
+ char *swap = two;
+ two = one;
+ one = swap;
+
move_go(board, player, row, col);
board_to_string(one, board);
-
row = 19 - row;
if (col >= 8) ++col;
printf("=%s %c%d\n\n", ids, 'A' + col, row);
+ print_board(board, 1, 0);
}
+
} else if (!strcmp(buff, "p")){
- print_board(board, 1, 0);
+ //print_board(board, 1, 0);
} else if (!strcmp(buff, "final_status_list")){
char type[256];
scanf("%s", type);
@@ -479,7 +564,30 @@
char *line = fgetl(stdin);
free(line);
if(type[0] == 'd' || type[0] == 'D'){
- printf("=%s \n\n", ids);
+ FILE *f = fopen("game.txt", "w");
+ int i, j;
+ int count = 2;
+ fprintf(f, "boardsize 19\n");
+ fprintf(f, "clear_board\n");
+ for(j = 0; j < 19; ++j){
+ for(i = 0; i < 19; ++i){
+ if(board[j*19 + i] == 1) fprintf(f, "play black %c%d\n", 'A'+i+(i>=8), 19-j);
+ if(board[j*19 + i] == -1) fprintf(f, "play white %c%d\n", 'A'+i+(i>=8), 19-j);
+ if(board[j*19 + i]) ++count;
+ }
+ }
+ fprintf(f, "final_status_list dead\n");
+ fclose(f);
+ FILE *p = popen("./gnugo --mode gtp < game.txt", "r");
+ for(i = 0; i < count; ++i){
+ free(fgetl(p));
+ free(fgetl(p));
+ }
+ char *l = 0;
+ while((l = fgetl(p))){
+ printf("%s\n", l);
+ free(l);
+ }
} else {
printf("?%s unknown command\n\n", ids);
}
@@ -541,8 +649,10 @@
col = index % 19;
printf("%d: %c %d, %.2f%%\n", i+1, col + 'A' + 1*(col > 7 && noi), (inverted)?19 - row : row+1, move[index]*100);
}
- if(color == 1) printf("\u25EF Enter move: ");
- else printf("\u25C9 Enter move: ");
+ //if(color == 1) printf("\u25EF Enter move: ");
+ //else printf("\u25C9 Enter move: ");
+ if(color == 1) printf("X Enter move: ");
+ else printf("O Enter move: ");
char c;
char *line = fgetl(stdin);
@@ -588,17 +698,118 @@
}
}
-void boards_go()
+/*
+ * Scores a finished position by replaying it into an external GNU Go process.
+ *
+ * Writes a GTP script to game.txt (komi 6.5, boardsize 19, clear_board, one
+ * "play" command per stone, then final_score), runs
+ * "./gnugo --mode gtp < game.txt" and parses its "= <player>+<score>" line.
+ *
+ * board - 361 cells indexed row*19 + col; 1 is written as a black stone and
+ *         -1 as a white stone.
+ *
+ * Returns the parsed score, negated when the reported winner is 'W'.
+ * Returns 0 when no score line could be obtained, including when game.txt or
+ * the gnugo pipe cannot be opened.
+ */
+float score_game(float *board)
{
- moves m = load_go_moves("/home/pjreddie/go.train");
- int i;
- float board[361];
- for(i = 0; i < 10; ++i){
- printf("%d %d\n", m.data[i][0], m.data[i][1]);
- string_to_board(m.data[i]+2, board);
- print_board(board, 1, 0);
+    FILE *f = fopen("game.txt", "w");
+    if(!f){
+        fprintf(stderr, "Couldn't open file: %s\n", "game.txt");
+        return 0;
+    }
+    int i, j;
+    /* Number of commands written before final_score: the three setup
+     * commands below plus one per stone. */
+    int count = 3;
+    fprintf(f, "komi 6.5\n");
+    fprintf(f, "boardsize 19\n");
+    fprintf(f, "clear_board\n");
+    for(j = 0; j < 19; ++j){
+        for(i = 0; i < 19; ++i){
+            /* (i>=8) skips the letter 'I' in the column name. */
+            if(board[j*19 + i] == 1) fprintf(f, "play black %c%d\n", 'A'+i+(i>=8), 19-j);
+            if(board[j*19 + i] == -1) fprintf(f, "play white %c%d\n", 'A'+i+(i>=8), 19-j);
+            if(board[j*19 + i]) ++count;
+        }
+    }
+    fprintf(f, "final_score\n");
+    fclose(f);
+    FILE *p = popen("./gnugo --mode gtp < game.txt", "r");
+    if(!p){
+        fprintf(stderr, "Couldn't run: %s\n", "./gnugo --mode gtp < game.txt");
+        return 0;
+    }
+    /* Discard two output lines per command already issued (presumably the
+     * reply line and the blank line that follows it). */
+    for(i = 0; i < count; ++i){
+        free(fgetl(p));
+        free(fgetl(p));
+    }
+    char *l = 0;
+    float score = 0;
+    char player = 0;
+    while((l = fgetl(p))){
+        fprintf(stderr, "%s \t", l);
+        int n = sscanf(l, "= %c+%f", &player, &score);
+        free(l);
+        if (n == 2) break;
+    }
+    if(player == 'W') score = -score;
+    pclose(p);
+    return score;
+}
+
+/*
+ * Plays two networks against each other forever, printing the winner's move
+ * records of every finished game to stdout and running win rates to stderr.
+ *
+ * filename, weightfile - configuration and optional weights of the first network.
+ * f2, w2               - optional configuration and weights of the second
+ *                        network; without f2 the second network is a copy of
+ *                        the first network struct.
+ * multi                - forwarded unchanged to generate_move.
+ *
+ * A game ends when generate_move returns a negative index or after 300
+ * recorded moves. The first network plays player 1 in even-numbered games and
+ * player -1 in odd-numbered ones. This function never returns.
+ */
+void self_go(char *filename, char *weightfile, char *f2, char *w2, int multi)
+{
+    network net = parse_network_cfg(filename);
+    if(weightfile){
+        load_weights(&net, weightfile);
}
+    network net2 = net;
+    if(f2){
+        net2 = parse_network_cfg(f2);
+        if(w2){
+            load_weights(&net2, w2);
+        }
+    }
+    srand(time(0));
+    /* One record per move: byte 0 = row, byte 1 = col, board string from byte 2. */
+    char boards[300][93];
+    int count = 0;
+    set_batch_network(&net, 1);
+    set_batch_network(&net2, 1);
+    float *board = calloc(19*19, sizeof(float));
+    /* `one` holds the latest position string and `two` the one before it;
+     * `two` is what generate_move receives as its ko argument. */
+    char *one = calloc(91, sizeof(char));
+    char *two = calloc(91, sizeof(char));
+    int done = 0;
+    int player = 1;
+    int p1 = 0;
+    int p2 = 0;
+    int total = 0;
+    while(1){
+        if (done || count >= 300){
+            float score = score_game(board);
+            /* Player 1 moved at even record indexes, so start at 0 when the
+             * score is positive and at 1 otherwise, then step by two. */
+            int i = (score > 0)? 0 : 1;
+            /* p1 counts wins of the first network, which is player 1 only in
+             * even-numbered games. */
+            if((score > 0) == (total%2==0)) ++p1;
+            else ++p2;
+            ++total;
+            fprintf(stderr, "Total: %d, Player 1: %f, Player 2: %f\n", total, (float)p1/total, (float)p2/total);
+            int j;
+            for(; i < count; i += 2){
+                for(j = 0; j < 93; ++j){
+                    printf("%c", boards[i][j]);
+                }
+                printf("\n");
+            }
+            memset(board, 0, 19*19*sizeof(float));
+            /* Also forget the previous game's positions so its final boards
+             * are not used as ko history for the new game. */
+            memset(one, 0, 91*sizeof(char));
+            memset(two, 0, 91*sizeof(char));
+            player = 1;
+            done = 0;
+            count = 0;
+            fflush(stdout);
+            fflush(stderr);
+        }
+        //print_board(board, 1, 0);
+        //sleep(1);
+        network use = ((total%2==0) == (player==1)) ? net : net2;
+        int index = generate_move(use, player, board, multi, .1, .7, two, 0);
+        if(index < 0){
+            done = 1;
+            continue;
+        }
+        int row = index / 19;
+        int col = index % 19;
+
+        char *swap = two;
+        two = one;
+        one = swap;
+
+        /* Record the position as flipped for player -1, then restore it. */
+        if(player < 0) flip_board(board);
+        boards[count][0] = row;
+        boards[count][1] = col;
+        board_to_string(boards[count] + 2, board);
+        if(player < 0) flip_board(board);
+        ++count;
+
+        move_go(board, player, row, col);
+        board_to_string(one, board);
+
+        player = -player;
+    }
}
void run_go(int argc, char **argv)
@@ -611,8 +822,12 @@
char *cfg = argv[3];
char *weights = (argc > 4) ? argv[4] : 0;
+ char *c2 = (argc > 5) ? argv[5] : 0;
+ char *w2 = (argc > 6) ? argv[6] : 0;
int multi = find_arg(argc, argv, "-multi");
if(0==strcmp(argv[2], "train")) train_go(cfg, weights);
+ else if(0==strcmp(argv[2], "valid")) valid_go(cfg, weights, multi);
+ else if(0==strcmp(argv[2], "self")) self_go(cfg, weights, c2, w2, multi);
else if(0==strcmp(argv[2], "test")) test_go(cfg, weights, multi);
else if(0==strcmp(argv[2], "engine")) engine_go(cfg, weights, multi);
}
--
Gitblit v1.10.0