From 67794a52a1ca19275f186dbc21cb45c1a45d6b92 Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Wed, 16 Mar 2016 11:44:44 +0000
Subject: [PATCH] more go

---
 src/go.c |   33 +++++++++++++++++----------------
 1 files changed, 17 insertions(+), 16 deletions(-)

diff --git a/src/go.c b/src/go.c
index 6607e7a..9d31539 100644
--- a/src/go.c
+++ b/src/go.c
@@ -176,7 +176,7 @@
     }
 }
 
-void test_go(char *filename, char *weightfile)
+void test_go(char *filename, char *weightfile, int multi)
 {
     network net = parse_network_cfg(filename);
     if(weightfile){
@@ -191,25 +191,25 @@
         float *output = network_predict(net, board);
         copy_cpu(19*19, output, 1, move, 1);
         int i;
-#ifdef GPU
-        image bim = float_to_image(19, 19, 1, board);
-        for(i = 1; i < 8; ++i){
-            rotate_image_cw(bim, i);
-            if(i >= 4) flip_image(bim);
+        if(multi){
+            image bim = float_to_image(19, 19, 1, board);
+            for(i = 1; i < 8; ++i){
+                rotate_image_cw(bim, i);
+                if(i >= 4) flip_image(bim);
 
-            float *output = network_predict(net, board);
-            image oim = float_to_image(19, 19, 1, output);
+                float *output = network_predict(net, board);
+                image oim = float_to_image(19, 19, 1, output);
 
-            if(i >= 4) flip_image(oim);
-            rotate_image_cw(oim, -i);
+                if(i >= 4) flip_image(oim);
+                rotate_image_cw(oim, -i);
 
-            axpy_cpu(19*19, 1, output, 1, move, 1);
+                axpy_cpu(19*19, 1, output, 1, move, 1);
 
-            if(i >= 4) flip_image(bim);
-            rotate_image_cw(bim, -i);
+                if(i >= 4) flip_image(bim);
+                rotate_image_cw(bim, -i);
+            }
+            scal_cpu(19*19, 1./8., move, 1);
         }
-        scal_cpu(19*19, 1./8., move, 1);
-#endif
         for(i = 0; i < 19*19; ++i){
             if(board[i]) move[i] = 0;
         }
@@ -282,8 +282,9 @@
 
     char *cfg = argv[3];
     char *weights = (argc > 4) ? argv[4] : 0;
+    int multi = find_arg(argc, argv, "-multi");
     if(0==strcmp(argv[2], "train")) train_go(cfg, weights);
-    else if(0==strcmp(argv[2], "test")) test_go(cfg, weights);
+    else if(0==strcmp(argv[2], "test")) test_go(cfg, weights, multi);
 }
 
 

--
Gitblit v1.10.0