From f11480833d19c0a7e9e1f7b45a19ba5bb5d63666 Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Sun, 02 Aug 2015 00:26:53 +0000
Subject: [PATCH] Headers are important
---
src/network.c | 24 ++++++++++++++++--------
1 files changed, 16 insertions(+), 8 deletions(-)
diff --git a/src/network.c b/src/network.c
index 5b52da9..ff5cd61 100644
--- a/src/network.c
+++ b/src/network.c
@@ -4,6 +4,7 @@
#include "image.h"
#include "data.h"
#include "utils.h"
+#include "blas.h"
#include "crop_layer.h"
#include "connected_layer.h"
@@ -125,13 +126,20 @@
float get_network_cost(network net)
{
- if(net.layers[net.n-1].type == COST){
- return net.layers[net.n-1].output[0];
+ int i;
+ float sum = 0;
+ int count = 0;
+ for(i = 0; i < net.n; ++i){
+ if(net.layers[net.n-1].type == COST){
+ sum += net.layers[net.n-1].output[0];
+ ++count;
+ }
+ if(net.layers[net.n-1].type == DETECTION){
+ sum += net.layers[net.n-1].cost[0];
+ ++count;
+ }
}
- if(net.layers[net.n-1].type == DETECTION){
- return net.layers[net.n-1].cost[0];
- }
- return 0;
+ return sum/count;
}
int get_predicted_class_network(network net)
@@ -184,9 +192,9 @@
float train_network_datum(network net, float *x, float *y)
{
- #ifdef GPU
+#ifdef GPU
if(gpu_index >= 0) return train_network_datum_gpu(net, x, y);
- #endif
+#endif
network_state state;
state.input = x;
state.delta = 0;
--
Gitblit v1.10.0