From a6b2511a566f77a0838dc1dd0d5f3e3c49a8faa0 Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Sat, 25 Jun 2016 23:13:54 +0000
Subject: [PATCH] Add CPU forward pass and state buffers for GRU layer
---
src/gru_layer.c | 93 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
1 file changed, 93 insertions(+), 0 deletions(-)
diff --git a/src/gru_layer.c b/src/gru_layer.c
index 1c41cbf..4c720ce 100644
--- a/src/gru_layer.c
+++ b/src/gru_layer.c
@@ -76,6 +76,15 @@
l.outputs = outputs;
l.output = calloc(outputs*batch*steps, sizeof(float));
l.delta = calloc(outputs*batch*steps, sizeof(float));
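+ // per-time-step recurrent state and gate work buffers (outputs*batch floats each)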
+ l.state = calloc(outputs*batch, sizeof(float));
+ l.prev_state = calloc(outputs*batch, sizeof(float));
+ l.forgot_state = calloc(outputs*batch, sizeof(float));
+ l.forgot_delta = calloc(outputs*batch, sizeof(float));
+
+ l.r_cpu = calloc(outputs*batch, sizeof(float));
+ l.z_cpu = calloc(outputs*batch, sizeof(float));
+ l.h_cpu = calloc(outputs*batch, sizeof(float));

#ifdef GPU
l.forgot_state_gpu = cuda_make_array(l.output, batch*outputs);
@@ -101,6 +110,90 @@

void forward_gru_layer(layer l, network_state state)
{
+ network_state s = {0};
+ s.train = state.train;
+ int i;
+ layer input_z_layer = *(l.input_z_layer);
+ layer input_r_layer = *(l.input_r_layer);
+ layer input_h_layer = *(l.input_h_layer);
+
+ layer state_z_layer = *(l.state_z_layer);
+ layer state_r_layer = *(l.state_r_layer);
+ layer state_h_layer = *(l.state_h_layer);
+
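+ // clear the per-step deltas of all six gate sub-layers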
+ fill_cpu(l.outputs * l.batch * l.steps, 0, input_z_layer.delta, 1);
+ fill_cpu(l.outputs * l.batch * l.steps, 0, input_r_layer.delta, 1);
+ fill_cpu(l.outputs * l.batch * l.steps, 0, input_h_layer.delta, 1);
+
+ fill_cpu(l.outputs * l.batch * l.steps, 0, state_z_layer.delta, 1);
+ fill_cpu(l.outputs * l.batch * l.steps, 0, state_r_layer.delta, 1);
+ fill_cpu(l.outputs * l.batch * l.steps, 0, state_h_layer.delta, 1);
+ if(state.train) {
+ fill_cpu(l.outputs * l.batch * l.steps, 0, l.delta, 1);
+ copy_cpu(l.outputs*l.batch, l.state, 1, l.prev_state, 1);
+ }
+
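+ // one GRU step per iteration:
+ // z = logistic(Wz*x + Uz*h), r = logistic(Wr*x + Ur*h)
+ // h~ = act(Wh*x + Uh*(r .* h)), h' = z .* h + (1-z) .* h~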
+ for (i = 0; i < l.steps; ++i) {
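+ // recurrent gate contributions Uz*h and Ur*h from the previous state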
+ s.input = l.state;
+ forward_connected_layer(state_z_layer, s);
+ forward_connected_layer(state_r_layer, s);
+
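+ // input gate contributions Wz*x, Wr*x, and Wh*x for this step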
+ s.input = state.input;
+ forward_connected_layer(input_z_layer, s);
+ forward_connected_layer(input_r_layer, s);
+ forward_connected_layer(input_h_layer, s);
+
+
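+ // sum the input and recurrent terms for each gate, then squash with the logistic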
+ copy_cpu(l.outputs*l.batch, input_z_layer.output, 1, l.z_cpu, 1);
+ axpy_cpu(l.outputs*l.batch, 1, state_z_layer.output, 1, l.z_cpu, 1);
+
+ copy_cpu(l.outputs*l.batch, input_r_layer.output, 1, l.r_cpu, 1);
+ axpy_cpu(l.outputs*l.batch, 1, state_r_layer.output, 1, l.r_cpu, 1);
+
+ activate_array(l.z_cpu, l.outputs*l.batch, LOGISTIC);
+ activate_array(l.r_cpu, l.outputs*l.batch, LOGISTIC);
+
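+ // forgot_state = r .* h : the reset gate masks the previous state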
+ copy_cpu(l.outputs*l.batch, l.state, 1, l.forgot_state, 1);
+ mul_cpu(l.outputs*l.batch, l.r_cpu, 1, l.forgot_state, 1);
+
+ s.input = l.forgot_state;
+ forward_connected_layer(state_h_layer, s);
+
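+ // candidate h~ = Wh*x + Uh*(r .* h)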
+ copy_cpu(l.outputs*l.batch, input_h_layer.output, 1, l.h_cpu, 1);
+ axpy_cpu(l.outputs*l.batch, 1, state_h_layer.output, 1, l.h_cpu, 1);
+
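+ // USET selects the standard tanh candidate activation; the default build uses logistic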
+ #ifdef USET
+ activate_array(l.h_cpu, l.outputs*l.batch, TANH);
+ #else
+ activate_array(l.h_cpu, l.outputs*l.batch, LOGISTIC);
+ #endif
+
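+ // h' = z .* h + (1-z) .* h~ via weighted_sum_cpu(a, b, s, n, c): c = s*a + (1-s)*b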
+ weighted_sum_cpu(l.state, l.h_cpu, l.z_cpu, l.outputs*l.batch, l.output);
+
+ copy_cpu(l.outputs*l.batch, l.output, 1, l.state, 1);
+
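+ // advance the data pointers and all six gate sub-layers to the next time step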
+ state.input += l.inputs*l.batch;
+ l.output += l.outputs*l.batch;
+ increment_layer(&input_z_layer, 1);
+ increment_layer(&input_r_layer, 1);
+ increment_layer(&input_h_layer, 1);
+
+ increment_layer(&state_z_layer, 1);
+ increment_layer(&state_r_layer, 1);
+ increment_layer(&state_h_layer, 1);
+ }
}

void backward_gru_layer(layer l, network_state state)
--
Gitblit v1.10.0