From 4ab366a805a7678642539465d68ef906b4599aeb Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Mon, 22 Dec 2014 22:35:37 +0000
Subject: [PATCH] some fixes, some other experiments

---
 src/network.c |   33 +++++++++++++++++++++++++++++++--
 1 files changed, 31 insertions(+), 2 deletions(-)

diff --git a/src/network.c b/src/network.c
index 829bb6e..42253dc 100644
--- a/src/network.c
+++ b/src/network.c
@@ -74,6 +74,7 @@
             if(!train) continue;
             dropout_layer layer = *(dropout_layer *)net.layers[i];
             forward_dropout_layer(layer, input);
+            input = layer.output;
         }
         else if(net.types[i] == FREEWEIGHT){
             if(!train) continue;
@@ -102,7 +103,8 @@
         }
         else if(net.types[i] == CONNECTED){
             connected_layer layer = *(connected_layer *)net.layers[i];
-            update_connected_layer(layer);
+            secret_update_connected_layer((connected_layer *)net.layers[i]);
+            //update_connected_layer(layer);
         }
     }
 }
@@ -119,7 +121,8 @@
         softmax_layer layer = *(softmax_layer *)net.layers[i];
         return layer.output;
     } else if(net.types[i] == DROPOUT){
-        return get_network_output_layer(net, i-1);
+        dropout_layer layer = *(dropout_layer *)net.layers[i];
+        return layer.output;
     } else if(net.types[i] == FREEWEIGHT){
         return get_network_output_layer(net, i-1);
     } else if(net.types[i] == CONNECTED){
@@ -153,6 +156,7 @@
         softmax_layer layer = *(softmax_layer *)net.layers[i];
         return layer.delta;
     } else if(net.types[i] == DROPOUT){
+        if(i == 0) return 0;
         return get_network_delta_layer(net, i-1);
     } else if(net.types[i] == FREEWEIGHT){
         return get_network_delta_layer(net, i-1);
@@ -645,6 +649,31 @@
     }
 }
 
+void compare_networks(network n1, network n2, data test)
+{
+    matrix g1 = network_predict_data(n1, test);
+    matrix g2 = network_predict_data(n2, test);
+    int i;
+    int a,b,c,d;
+    a = b = c = d = 0;
+    for(i = 0; i < g1.rows; ++i){
+        int truth = max_index(test.y.vals[i], test.y.cols);
+        int p1 = max_index(g1.vals[i], g1.cols);
+        int p2 = max_index(g2.vals[i], g2.cols);
+        if(p1 == truth){
+            if(p2 == truth) ++d;
+            else ++c;
+        }else{
+            if(p2 == truth) ++b;
+            else ++a;
+        }
+    }
+    printf("%5d %5d\n%5d %5d\n", a, b, c, d);
+    float num = pow((abs(b - c) - 1.), 2.);
+    float den = b + c;
+    printf("%f\n", num/den); 
+}
+
 float network_accuracy(network net, data d)
 {
     matrix guess = network_predict_data(net, d);

--
Gitblit v1.10.0