From 73f7aacf35ec9b1d0f9de9ddf38af0889f213e99 Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Tue, 20 Sep 2016 18:34:49 +0000
Subject: [PATCH] better multigpu

---
 src/network.c |    4 +++-
 1 files changed, 3 insertions(+), 1 deletions(-)

diff --git a/src/network.c b/src/network.c
index c9a198f..72c8943 100644
--- a/src/network.c
+++ b/src/network.c
@@ -1,5 +1,6 @@
 #include <stdio.h>
 #include <time.h>
+#include <assert.h>
 #include "network.h"
 #include "image.h"
 #include "data.h"
@@ -318,11 +319,11 @@
 
 float train_network_datum(network net, float *x, float *y)
 {
-    *net.seen += net.batch;
 #ifdef GPU
     if(gpu_index >= 0) return train_network_datum_gpu(net, x, y);
 #endif
     network_state state;
+    *net.seen += net.batch;
     state.index = 0;
     state.net = net;
     state.input = x;
@@ -356,6 +357,7 @@
 
 float train_network(network net, data d)
 {
+    assert(d.X.rows % net.batch == 0);
     int batch = net.batch;
     int n = d.X.rows / batch;
     float *X = calloc(batch*d.X.cols, sizeof(float));

--
Gitblit v1.10.0