From d7286c273211ffeb1f56594f863d1ee9922be6d4 Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Thu, 07 Nov 2013 00:09:41 +0000
Subject: [PATCH] Loading may or may not work. But probably.

---
 src/network.c |  119 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 files changed, 118 insertions(+), 1 deletions(-)

diff --git a/src/network.c b/src/network.c
index e55535c..53184d9 100644
--- a/src/network.c
+++ b/src/network.c
@@ -5,10 +5,19 @@
 #include "convolutional_layer.h"
 #include "maxpool_layer.h"
 
+network make_network(int n)
+{
+    network net;
+    net.n = n;
+    net.layers = calloc(net.n, sizeof(void *));
+    net.types = calloc(net.n, sizeof(LAYER_TYPE));
+    return net;
+}
+
 void run_network(image input, network net)
 {
     int i;
-    double *input_d = 0;
+    double *input_d = input.data;
     for(i = 0; i < net.n; ++i){
         if(net.types[i] == CONVOLUTIONAL){
             convolutional_layer layer = *(convolutional_layer *)net.layers[i];
@@ -30,6 +39,114 @@
     }
 }
 
+void update_network(network net, double step)
+{
+    int i;
+    for(i = 0; i < net.n; ++i){
+        if(net.types[i] == CONVOLUTIONAL){
+            convolutional_layer layer = *(convolutional_layer *)net.layers[i];
+            update_convolutional_layer(layer, step);
+        }
+        else if(net.types[i] == MAXPOOL){
+            //maxpool_layer layer = *(maxpool_layer *)net.layers[i];
+        }
+        else if(net.types[i] == CONNECTED){
+            connected_layer layer = *(connected_layer *)net.layers[i];
+            update_connected_layer(layer, step);
+        }
+    }
+}
+
+void learn_network(image input, network net)
+{
+    int i;
+    image prev;
+    double *prev_p;
+    for(i = net.n-1; i >= 0; --i){
+        if(i == 0){
+            prev = input;
+            prev_p = prev.data;
+        } else if(net.types[i-1] == CONVOLUTIONAL){
+            convolutional_layer layer = *(convolutional_layer *)net.layers[i-1];
+            prev = layer.output;
+            prev_p = prev.data;
+        } else if(net.types[i-1] == MAXPOOL){
+            maxpool_layer layer = *(maxpool_layer *)net.layers[i-1];
+            prev = layer.output;
+            prev_p = prev.data;
+        } else if(net.types[i-1] == CONNECTED){
+            connected_layer layer = *(connected_layer *)net.layers[i-1];
+            prev_p = layer.output;
+        }
+
+        if(net.types[i] == CONVOLUTIONAL){
+            convolutional_layer layer = *(convolutional_layer *)net.layers[i];
+            learn_convolutional_layer(prev, layer);
+        }
+        else if(net.types[i] == MAXPOOL){
+            //maxpool_layer layer = *(maxpool_layer *)net.layers[i];
+        }
+        else if(net.types[i] == CONNECTED){
+            connected_layer layer = *(connected_layer *)net.layers[i];
+            learn_connected_layer(prev_p, layer);
+        }
+    }
+}
+
+
+double *get_network_output_layer(network net, int i)
+{
+    if(net.types[i] == CONVOLUTIONAL){
+        convolutional_layer layer = *(convolutional_layer *)net.layers[i];
+        return layer.output.data;
+    }
+    else if(net.types[i] == MAXPOOL){
+        maxpool_layer layer = *(maxpool_layer *)net.layers[i];
+        return layer.output.data;
+    }
+    else if(net.types[i] == CONNECTED){
+        connected_layer layer = *(connected_layer *)net.layers[i];
+        return layer.output;
+    }
+    return 0;
+}
+
+int get_network_output_size_layer(network net, int i)
+{
+    if(net.types[i] == CONVOLUTIONAL){
+        convolutional_layer layer = *(convolutional_layer *)net.layers[i];
+        return layer.output.h*layer.output.w*layer.output.c;
+    }
+    else if(net.types[i] == MAXPOOL){
+        maxpool_layer layer = *(maxpool_layer *)net.layers[i];
+        return layer.output.h*layer.output.w*layer.output.c;
+    }
+    else if(net.types[i] == CONNECTED){
+        connected_layer layer = *(connected_layer *)net.layers[i];
+        return layer.outputs;
+    }
+    return 0;
+}
+
+double *get_network_output(network net)
+{
+    int i = net.n-1;
+    return get_network_output_layer(net, i);
+}
+
+image get_network_image_layer(network net, int i)
+{
+    if(net.types[i] == CONVOLUTIONAL){
+        convolutional_layer layer = *(convolutional_layer *)net.layers[i];
+        return layer.output;
+    }
+    else if(net.types[i] == MAXPOOL){
+        maxpool_layer layer = *(maxpool_layer *)net.layers[i];
+        return layer.output;
+    }
+    return make_image(0,0,0);
+}
+
 image get_network_image(network net)
 {
     int i;

--
Gitblit v1.10.0