| | |
| | | #include "opencv2/highgui/highgui_c.h" |
| | | #endif |
| | | |
| | | int inverted = 0; |
| | | int noi = 0; |
| | | |
| | | void train_go(char *cfgfile, char *weightfile) |
| | | { |
| | | data_seed = time(0); |
| | |
| | | |
| | | char *backup_directory = "/home/pjreddie/backup/"; |
| | | |
| | | data train = load_go("/home/pjreddie/backup/go.train"); |
| | | |
| | | char buff[256]; |
| | | sprintf(buff, "/home/pjreddie/go.train.%02d", rand()%10); |
| | | data train = load_go(buff); |
| | | |
| | | int N = train.X.rows; |
| | | int epoch = (*net.seen)/N; |
| | | while(get_current_batch(net) < net.max_batches || net.max_batches == 0){ |
| | |
| | | char buff[256]; |
| | | sprintf(buff, "%s/%s_%d.weights",backup_directory,base, epoch); |
| | | save_weights(net, buff); |
| | | |
| | | free_data(train); |
| | | sprintf(buff, "/home/pjreddie/go.train.%02d", epoch%10); |
| | | train = load_go(buff); |
| | | } |
| | | if(get_current_batch(net)%100 == 0){ |
| | | char buff[256]; |
| | |
| | | save_weights(net, buff); |
| | | } |
| | | } |
| | | char buff[256]; |
| | | sprintf(buff, "%s/%s.weights", backup_directory, base); |
| | | save_weights(net, buff); |
| | | |
| | |
| | | free(l); |
| | | } |
| | | |
| | | void print_board(float *board) |
| | | void print_board(float *board, int swap, int *indexes) |
| | | { |
| | | int i,j; |
| | | int i,j,n; |
| | | printf("\n\n"); |
| | | printf(" "); |
| | | for(i = 0; i < 19; ++i){ |
| | | printf("%c ", 'A' + i + 1*(i > 7)); |
| | | printf("%c ", 'A' + i + 1*(i > 7 && noi)); |
| | | } |
| | | printf("\n"); |
| | | for(j = 0; j < 19; ++j){ |
| | | printf("%2d ", 19-j); |
| | | printf("%2d ", (inverted) ? 19-j : j+1); |
| | | for(i = 0; i < 19; ++i){ |
| | | int index = j*19 + i; |
| | | if(board[index] > 0) printf("\u25C9 "); |
| | | else if(board[index] < 0) printf("\u25EF "); |
| | | if(indexes){ |
| | | int found = 0; |
| | | for(n = 0; n < 3; ++n){ |
| | | if(index == indexes[n]){ |
| | | found = 1; |
| | | if(n == 0) printf("\uff11"); |
| | | else if(n == 1) printf("\uff12"); |
| | | else if(n == 2) printf("\uff13"); |
| | | } |
| | | } |
| | | if(found) continue; |
| | | } |
| | | if(board[index]*-swap > 0) printf("\u25C9 "); |
| | | else if(board[index]*-swap < 0) printf("\u25EF "); |
| | | else printf(" "); |
| | | } |
| | | printf("\n"); |
| | |
| | | set_batch_network(&net, 1); |
| | | float *board = calloc(19*19, sizeof(float)); |
| | | float *move = calloc(19*19, sizeof(float)); |
| | | int color = 1; |
| | | while(1){ |
| | | float *output = network_predict(net, board); |
| | | copy_cpu(19*19, output, 1, move, 1); |
| | | int i; |
| | | #ifdef GPU |
| | | #ifdef GPU |
| | | image bim = float_to_image(19, 19, 1, board); |
| | | for(i = 1; i < 8; ++i){ |
| | | rotate_image_cw(bim, i); |
| | |
| | | rotate_image_cw(bim, -i); |
| | | } |
| | | scal_cpu(19*19, 1./8., move, 1); |
| | | #endif |
| | | #endif |
| | | for(i = 0; i < 19*19; ++i){ |
| | | if(board[i]) move[i] = 0; |
| | | } |
| | |
| | | int indexes[3]; |
| | | int row, col; |
| | | top_k(move, 19*19, 3, indexes); |
| | | print_board(board); |
| | | print_board(board, color, indexes); |
| | | for(i = 0; i < 3; ++i){ |
| | | int index = indexes[i]; |
| | | row = index / 19; |
| | | col = index % 19; |
| | | printf("Suggested: %c %d, %.2f%%\n", col + 'A' + 1*(col > 7), 19 - row, move[index]*100); |
| | | printf("Suggested: %c %d, %.2f%%\n", col + 'A' + 1*(col > 7 && noi), (inverted)?19 - row : row-1, move[index]*100); |
| | | } |
| | | int index = indexes[0]; |
| | | int rec_row = index / 19; |
| | | int rec_col = index % 19; |
| | | |
| | | printf("\u25C9 Enter move: "); |
| | | if(color == 1) printf("\u25EF Enter move: "); |
| | | else printf("\u25C9 Enter move: "); |
| | | |
| | | char c; |
| | | char *line = fgetl(stdin); |
| | | int num = sscanf(line, "%c %d", &c, &row); |
| | |
| | | }else if (c < 'A' || c > 'T'){ |
| | | if (c == 'p'){ |
| | | flip_board(board); |
| | | color = -color; |
| | | free(line); |
| | | continue; |
| | | } else if(c=='b' || c == 'w'){ |
| | | char g; |
| | | num = sscanf(line, "%c %c %d", &g, &c, &row); |
| | | row = (inverted)?19 - row : row-1; |
| | | col = c - 'A'; |
| | | if (col > 7 && noi) col -= 1; |
| | | if (num == 3) board[row*19 + col] = (g == 'b') ? color : -color; |
| | | }else{ |
| | | char g; |
| | | num = sscanf(line, "%c %c %d", &g, &c, &row); |
| | | row = 19 - row; |
| | | row = (inverted)?19 - row : row-1; |
| | | col = c - 'A'; |
| | | if (col > 7) col -= 1; |
| | | if (num == 2) board[row*19 + col] = 0; |
| | | if (col > 7 && noi) col -= 1; |
| | | if (num == 3) board[row*19 + col] = 0; |
| | | } |
| | | } else if(num == 2){ |
| | | row = 19 - row; |
| | | row = (inverted)?19 - row : row-1; |
| | | col = c - 'A'; |
| | | if (col > 7) col -= 1; |
| | | if (col > 7 && noi) col -= 1; |
| | | board[row*19 + col] = 1; |
| | | }else{ |
| | | free(line); |
| | | continue; |
| | | } |
| | | free(line); |
| | | update_board(board); |
| | | flip_board(board); |
| | | color = -color; |
| | | } |
| | | |
| | | } |