From 91f95c715bff84094fc18bad6a8f938291b9b0f5 Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Mon, 24 Oct 2016 20:32:49 +0000
Subject: [PATCH] tree things, tree stuff

---
 src/blas_kernels.cu |   22 +++++++++++++---------
 1 files changed, 13 insertions(+), 9 deletions(-)

diff --git a/src/blas_kernels.cu b/src/blas_kernels.cu
index 59ec005..b4d520e 100644
--- a/src/blas_kernels.cu
+++ b/src/blas_kernels.cu
@@ -693,31 +693,35 @@
 }
 
 
-__global__ void softmax_kernel(int n, int batch, float *input, float temp, float *output)
+__device__ void softmax_device(int n, float *input, float temp, float *output)
 {
-    int b = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
-    if(b >= batch) return;
-
     int i;
     float sum = 0;
     float largest = -INFINITY;
     for(i = 0; i < n; ++i){
-        int val = input[i+b*n];
+        int val = input[i];
         largest = (val>largest) ? val : largest;
     }
     for(i = 0; i < n; ++i){
-        sum += exp(input[i+b*n]/temp-largest/temp);
+        sum += exp(input[i]/temp-largest/temp);
     }
     sum = (sum != 0) ? largest/temp+log(sum) : largest-100;
     for(i = 0; i < n; ++i){
-        output[i+b*n] = exp(input[i+b*n]/temp-sum);
+        output[i] = exp(input[i]/temp-sum);
     }
 }
 
-extern "C" void softmax_gpu(float *input, int n, int groups, float temp, float *output, cudaStream_t stream)
+__global__ void softmax_kernel(int n, int offset, int batch, float *input, float temp, float *output)
+{
+    int b = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+    if(b >= batch) return;
+    softmax_device(n, input + b*offset, temp, output + b*offset);
+}
+
+extern "C" void softmax_gpu(float *input, int n, int offset, int groups, float temp, float *output)
 {
     int inputs = n;
     int batch = groups;
-    softmax_kernel<<<cuda_gridsize(batch), BLOCK, 0, stream>>>(inputs, batch, input, temp, output);
+    softmax_kernel<<<cuda_gridsize(batch), BLOCK>>>(inputs, offset, batch, input, temp, output);
     check_error(cudaPeekAtLastError());
 }

--
Gitblit v1.10.0