From 2ce6460c79e06caa33eab3991ee3e7fd9f0909d6 Mon Sep 17 00:00:00 2001
From: AlexeyAB <alexeyab84@gmail.com>
Date: Tue, 10 Apr 2018 18:08:24 +0000
Subject: [PATCH] Remove truth only if smaller than 1x1 pix during training

---
 src/gru_layer.c |   88 ++++++++++++++++++++++++++++++++++++++++++++
 1 files changed, 88 insertions(+), 0 deletions(-)

diff --git a/src/gru_layer.c b/src/gru_layer.c
index 1c41cbf..b78e868 100644
--- a/src/gru_layer.c
+++ b/src/gru_layer.c
@@ -76,8 +76,24 @@
     l.outputs = outputs;
     l.output = calloc(outputs*batch*steps, sizeof(float));
     l.delta = calloc(outputs*batch*steps, sizeof(float));
+    l.state = calloc(outputs*batch, sizeof(float));
+    l.prev_state = calloc(outputs*batch, sizeof(float));
+    l.forgot_state = calloc(outputs*batch, sizeof(float));
+    l.forgot_delta = calloc(outputs*batch, sizeof(float));
+
+    l.r_cpu = calloc(outputs*batch, sizeof(float));
+    l.z_cpu = calloc(outputs*batch, sizeof(float));
+    l.h_cpu = calloc(outputs*batch, sizeof(float));
+
+    l.forward = forward_gru_layer;
+    l.backward = backward_gru_layer;
+    l.update = update_gru_layer;
 
 #ifdef GPU
+    l.forward_gpu = forward_gru_layer_gpu;
+    l.backward_gpu = backward_gru_layer_gpu;
+    l.update_gpu = update_gru_layer_gpu;
+
     l.forgot_state_gpu = cuda_make_array(l.output, batch*outputs);
     l.forgot_delta_gpu = cuda_make_array(l.output, batch*outputs);
     l.prev_state_gpu = cuda_make_array(l.output, batch*outputs);
@@ -101,6 +117,78 @@
 
 void forward_gru_layer(layer l, network_state state)
 {
+    network_state s = {0};
+    s.train = state.train;
+    int i;
+    layer input_z_layer = *(l.input_z_layer);
+    layer input_r_layer = *(l.input_r_layer);
+    layer input_h_layer = *(l.input_h_layer);
+
+    layer state_z_layer = *(l.state_z_layer);
+    layer state_r_layer = *(l.state_r_layer);
+    layer state_h_layer = *(l.state_h_layer);
+
+    fill_cpu(l.outputs * l.batch * l.steps, 0, input_z_layer.delta, 1);
+    fill_cpu(l.outputs * l.batch * l.steps, 0, input_r_layer.delta, 1);
+    fill_cpu(l.outputs * l.batch * l.steps, 0, input_h_layer.delta, 1);
+
+    fill_cpu(l.outputs * l.batch * l.steps, 0, state_z_layer.delta, 1);
+    fill_cpu(l.outputs * l.batch * l.steps, 0, state_r_layer.delta, 1);
+    fill_cpu(l.outputs * l.batch * l.steps, 0, state_h_layer.delta, 1);
+    if(state.train) {
+        fill_cpu(l.outputs * l.batch * l.steps, 0, l.delta, 1);
+        copy_cpu(l.outputs*l.batch, l.state, 1, l.prev_state, 1);
+    }
+
+    for (i = 0; i < l.steps; ++i) {
+        s.input = l.state;
+        forward_connected_layer(state_z_layer, s);
+        forward_connected_layer(state_r_layer, s);
+
+        s.input = state.input;
+        forward_connected_layer(input_z_layer, s);
+        forward_connected_layer(input_r_layer, s);
+        forward_connected_layer(input_h_layer, s);
+
+
+        copy_cpu(l.outputs*l.batch, input_z_layer.output, 1, l.z_cpu, 1);
+        axpy_cpu(l.outputs*l.batch, 1, state_z_layer.output, 1, l.z_cpu, 1);
+
+        copy_cpu(l.outputs*l.batch, input_r_layer.output, 1, l.r_cpu, 1);
+        axpy_cpu(l.outputs*l.batch, 1, state_r_layer.output, 1, l.r_cpu, 1);
+
+        activate_array(l.z_cpu, l.outputs*l.batch, LOGISTIC);
+        activate_array(l.r_cpu, l.outputs*l.batch, LOGISTIC);
+
+        copy_cpu(l.outputs*l.batch, l.state, 1, l.forgot_state, 1);
+        mul_cpu(l.outputs*l.batch, l.r_cpu, 1, l.forgot_state, 1);
+
+        s.input = l.forgot_state;
+        forward_connected_layer(state_h_layer, s);
+
+        copy_cpu(l.outputs*l.batch, input_h_layer.output, 1, l.h_cpu, 1);
+        axpy_cpu(l.outputs*l.batch, 1, state_h_layer.output, 1, l.h_cpu, 1);
+
+        #ifdef USET
+        activate_array(l.h_cpu, l.outputs*l.batch, TANH);
+        #else
+        activate_array(l.h_cpu, l.outputs*l.batch, LOGISTIC);
+        #endif
+
+        weighted_sum_cpu(l.state, l.h_cpu, l.z_cpu, l.outputs*l.batch, l.output);
+
+        copy_cpu(l.outputs*l.batch, l.output, 1, l.state, 1);
+
+        state.input += l.inputs*l.batch;
+        l.output += l.outputs*l.batch;
+        increment_layer(&input_z_layer, 1);
+        increment_layer(&input_r_layer, 1);
+        increment_layer(&input_h_layer, 1);
+
+        increment_layer(&state_z_layer, 1);
+        increment_layer(&state_r_layer, 1);
+        increment_layer(&state_h_layer, 1);
+    }
 }
 
 void backward_gru_layer(layer l, network_state state)

--
Gitblit v1.10.0