| | |
| | | } |
| | | } |
| | | |
| | | void test_go(char *filename, char *weightfile) |
| | | void test_go(char *filename, char *weightfile, int multi) |
| | | { |
| | | network net = parse_network_cfg(filename); |
| | | if(weightfile){ |
| | |
| | | float *output = network_predict(net, board); |
| | | copy_cpu(19*19, output, 1, move, 1); |
| | | int i; |
| | | #ifdef GPU |
| | | if(multi){ |
| | | image bim = float_to_image(19, 19, 1, board); |
| | | for(i = 1; i < 8; ++i){ |
| | | rotate_image_cw(bim, i); |
| | |
| | | rotate_image_cw(bim, -i); |
| | | } |
| | | scal_cpu(19*19, 1./8., move, 1); |
| | | #endif |
| | | } |
| | | for(i = 0; i < 19*19; ++i){ |
| | | if(board[i]) move[i] = 0; |
| | | } |
| | |
| | | |
| | | char *cfg = argv[3]; |
| | | char *weights = (argc > 4) ? argv[4] : 0; |
| | | int multi = find_arg(argc, argv, "-multi"); |
| | | if(0==strcmp(argv[2], "train")) train_go(cfg, weights); |
| | | else if(0==strcmp(argv[2], "test")) test_go(cfg, weights); |
| | | else if(0==strcmp(argv[2], "test")) test_go(cfg, weights, multi); |
| | | } |
| | | |
| | | |