From 73f7aacf35ec9b1d0f9de9ddf38af0889f213e99 Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Tue, 20 Sep 2016 18:34:49 +0000
Subject: [PATCH] better multigpu
---
src/network.c | 4 +++-
1 files changed, 3 insertions(+), 1 deletions(-)
diff --git a/src/network.c b/src/network.c
index c9a198f..72c8943 100644
--- a/src/network.c
+++ b/src/network.c
@@ -1,5 +1,6 @@
#include <stdio.h>
#include <time.h>
+#include <assert.h>
#include "network.h"
#include "image.h"
#include "data.h"
@@ -318,11 +319,11 @@
float train_network_datum(network net, float *x, float *y)
{
- *net.seen += net.batch;
#ifdef GPU
if(gpu_index >= 0) return train_network_datum_gpu(net, x, y);
#endif
network_state state;
+ *net.seen += net.batch;
state.index = 0;
state.net = net;
state.input = x;
@@ -356,6 +357,7 @@
float train_network(network net, data d)
{
+ assert(d.X.rows % net.batch == 0);
int batch = net.batch;
int n = d.X.rows / batch;
float *X = calloc(batch*d.X.cols, sizeof(float));
--
Gitblit v1.10.0