From 9b1774bd39d65614cdbd2d4e3815086298008911 Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Wed, 06 Nov 2013 18:37:37 +0000
Subject: [PATCH] Connected layers work forward and backward!
---
src/network.c | 73 ++++++++++++++++++++++++++++++++++++
1 files changed, 72 insertions(+), 1 deletions(-)
diff --git a/src/network.c b/src/network.c
index e55535c..0a74b63 100644
--- a/src/network.c
+++ b/src/network.c
@@ -8,7 +8,7 @@
void run_network(image input, network net)
{
int i;
- double *input_d = 0;
+ double *input_d = input.data;
for(i = 0; i < net.n; ++i){
if(net.types[i] == CONVOLUTIONAL){
convolutional_layer layer = *(convolutional_layer *)net.layers[i];
@@ -30,6 +30,77 @@
}
}
+void update_network(network net, double step)
+{
+ int i;
+ for(i = 0; i < net.n; ++i){
+ if(net.types[i] == CONVOLUTIONAL){
+ convolutional_layer layer = *(convolutional_layer *)net.layers[i];
+ update_convolutional_layer(layer, step);
+ }
+ else if(net.types[i] == MAXPOOL){
+ //maxpool_layer layer = *(maxpool_layer *)net.layers[i];
+ }
+ else if(net.types[i] == CONNECTED){
+ connected_layer layer = *(connected_layer *)net.layers[i];
+ update_connected_layer(layer, step);
+ }
+ }
+}
+
+void learn_network(image input, network net)
+{
+ int i;
+ image prev;
+ double *prev_p;
+ for(i = net.n-1; i >= 0; --i){
+ if(i == 0){
+ prev = input;
+ prev_p = prev.data;
+ } else if(net.types[i-1] == CONVOLUTIONAL){
+ convolutional_layer layer = *(convolutional_layer *)net.layers[i-1];
+ prev = layer.output;
+ prev_p = prev.data;
+ } else if(net.types[i-1] == MAXPOOL){
+ maxpool_layer layer = *(maxpool_layer *)net.layers[i-1];
+ prev = layer.output;
+ prev_p = prev.data;
+ } else if(net.types[i-1] == CONNECTED){
+ connected_layer layer = *(connected_layer *)net.layers[i-1];
+ prev_p = layer.output;
+ }
+
+ if(net.types[i] == CONVOLUTIONAL){
+ convolutional_layer layer = *(convolutional_layer *)net.layers[i];
+ learn_convolutional_layer(prev, layer);
+ }
+ else if(net.types[i] == MAXPOOL){
+ //maxpool_layer layer = *(maxpool_layer *)net.layers[i];
+ }
+ else if(net.types[i] == CONNECTED){
+ connected_layer layer = *(connected_layer *)net.layers[i];
+ learn_connected_layer(prev_p, layer);
+ }
+ }
+}
+
+double *get_network_output(network net)
+{
+ int i = net.n-1;
+ if(net.types[i] == CONVOLUTIONAL){
+ convolutional_layer layer = *(convolutional_layer *)net.layers[i];
+ return layer.output.data;
+ }
+ else if(net.types[i] == MAXPOOL){
+ maxpool_layer layer = *(maxpool_layer *)net.layers[i];
+ return layer.output.data;
+ }
+ else if(net.types[i] == CONNECTED){
+ connected_layer layer = *(connected_layer *)net.layers[i];
+ return layer.output;
+ }
+ return 0;
+}
image get_network_image(network net)
{
int i;
--
Gitblit v1.10.0