From d8adaf8ea6a31a380f6bf1fe65e88b661d3bb51e Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Fri, 21 Oct 2016 20:16:43 +0000
Subject: [PATCH] tree stuff
---
src/softmax_layer.c | 85 ++++++++++++++++++++++++++++++++----------
1 files changed, 65 insertions(+), 20 deletions(-)
diff --git a/src/softmax_layer.c b/src/softmax_layer.c
index 20bc07f..2a34cae 100644
--- a/src/softmax_layer.c
+++ b/src/softmax_layer.c
@@ -32,31 +32,25 @@
return l;
}
-void softmax_array(float *input, int n, float temp, float *output)
-{
- int i;
- float sum = 0;
- float largest = -FLT_MAX;
- for(i = 0; i < n; ++i){
- if(input[i] > largest) largest = input[i];
- }
- for(i = 0; i < n; ++i){
- sum += exp(input[i]/temp-largest/temp);
- }
- if(sum) sum = largest/temp+log(sum);
- else sum = largest-100;
- for(i = 0; i < n; ++i){
- output[i] = exp(input[i]/temp-sum);
- }
-}
-
void forward_softmax_layer(const softmax_layer l, network_state state)
{
int b;
int inputs = l.inputs / l.groups;
int batch = l.batch * l.groups;
- for(b = 0; b < batch; ++b){
- softmax_array(state.input+b*inputs, inputs, l.temperature, l.output+b*inputs);
+ if(l.softmax_tree){
+ for(b = 0; b < batch; ++b){
+ int i;
+ int count = 0;
+ for(i = 0; i < l.softmax_tree->groups; ++i){
+ int group_size = l.softmax_tree->group_size[i];
+ softmax(state.input+b*inputs + count, group_size, l.temperature, l.output+b*inputs + count);
+ count += group_size;
+ }
+ }
+ } else {
+ for(b = 0; b < batch; ++b){
+ softmax(state.input+b*inputs, inputs, l.temperature, l.output+b*inputs);
+ }
}
}
@@ -68,3 +62,54 @@
}
}
+#ifdef GPU
+
+void pull_softmax_layer_output(const softmax_layer layer)
+{
+ cuda_pull_array(layer.output_gpu, layer.output, layer.inputs*layer.batch);
+}
+
+void forward_softmax_layer_gpu(const softmax_layer l, network_state state)
+{
+ int inputs = l.inputs / l.groups;
+ int batch = l.batch * l.groups;
+ int b;
+ if(l.softmax_tree){
+ if(0){
+ float *buff = calloc(inputs * batch, sizeof(float));
+ cuda_pull_array(state.input, buff, batch * inputs);
+ state.input = buff;
+ forward_softmax_layer(l, state);
+ cuda_push_array(l.output_gpu, l.output, batch*inputs);
+ free(buff);
+ } else {
+ int i;
+ const int nstreams = 32;
+ cudaStream_t streams[nstreams];
+ for (i = 0; i < nstreams; ++i) {
+ cudaStreamCreate(&streams[i]);
+ }
+ for (b = 0; b < batch; ++b) {
+ int i;
+ int count = 0;
+ for (i = 0; i < l.softmax_tree->groups; ++i) {
+ int group_size = l.softmax_tree->group_size[i];
+ softmax_gpu(state.input+b*inputs + count, group_size, 1, l.temperature, l.output_gpu+b*inputs + count, streams[(b*l.softmax_tree->groups + i) % nstreams]);
+ count += group_size;
+ }
+ }
+ for(i = 0; i < nstreams; ++i){
+ cudaStreamDestroy(streams[i]);
+ }
+ }
+ } else {
+ softmax_gpu(state.input, inputs, batch, l.temperature, l.output_gpu, 0);
+ }
+}
+
+void backward_softmax_layer_gpu(const softmax_layer layer, network_state state)
+{
+ axpy_ongpu(layer.batch*layer.inputs, 1, layer.delta_gpu, 1, state.delta, 1);
+}
+
+#endif
--
Gitblit v1.10.0