From 1b94df24fde6dea36d85b1ea7873a83e1a213277 Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Thu, 17 Jul 2014 16:05:07 +0000
Subject: [PATCH] Midway through lots of fixes, checkpoint

---
 src/network.c |   26 ++++++++++++++------------
 1 files changed, 14 insertions(+), 12 deletions(-)

diff --git a/src/network.c b/src/network.c
index ef80110..6855c55 100644
--- a/src/network.c
+++ b/src/network.c
@@ -272,7 +272,9 @@
     for(i = 0; i < get_network_output_size(net)*net.batch; ++i){
         //if(i %get_network_output_size(net) == 0) printf("\n");
         //printf("%5.2f %5.2f, ", out[i], truth[i]);
+        //if(i == get_network_output_size(net)) printf("\n");
         delta[i] = truth[i] - out[i];
+        //printf("%f, ", delta[i]);
         sum += delta[i]*delta[i];
     }
     //printf("\n");
@@ -382,20 +384,20 @@
 }
 float train_network_batch(network net, data d, int n, float step, float momentum,float decay)
 {
-    int i;
-    int correct = 0;
+    int i,j;
+    float sum = 0;
+    int batch = 2;
     for(i = 0; i < n; ++i){
-        int index = rand()%d.X.rows;
-        float *x = d.X.vals[index];
-        float *y = d.y.vals[index];
-        forward_network(net, x, 1);
-        int class = get_predicted_class_network(net);
-        backward_network(net, x, y);
-        correct += (y[class]?1:0);
+        for(j = 0; j < batch; ++j){
+            int index = rand()%d.X.rows;
+            float *x = d.X.vals[index];
+            float *y = d.y.vals[index];
+            forward_network(net, x, 1);
+            sum += backward_network(net, x, y);
+        }
+        update_network(net, step, momentum, decay);
     }
-    update_network(net, step, momentum, decay);
-    return (float)correct/n;
-
+    return (float)sum/(n*batch);
 }
 
 

--
Gitblit v1.10.0