Joseph Redmon
2016-03-14 68213b835b9f15cb449ad2037a8b51c17a3de07b
src/softmax_layer_kernels.cu
@@ -8,7 +8,7 @@
#include "blas.h"
}
__global__ void forward_softmax_layer_kernel(int n, int batch, float *input, float *output)
__global__ void forward_softmax_layer_kernel(int n, int batch, float *input, float temp, float *output)
{
    int b = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    if(b >= batch) return;
@@ -21,11 +21,11 @@
        largest = (val>largest) ? val : largest;
    }
    for(i = 0; i < n; ++i){
        sum += exp(input[i+b*n]-largest);
        sum += exp(input[i+b*n]/temp-largest/temp);
    }
    sum = (sum != 0) ? largest+log(sum) : largest-100;
    sum = (sum != 0) ? largest/temp+log(sum) : largest-100;
    for(i = 0; i < n; ++i){
        output[i+b*n] = exp(input[i+b*n]-sum);
        output[i+b*n] = exp(input[i+b*n]/temp-sum);
    }
}
@@ -38,7 +38,7 @@
{
    int inputs = layer.inputs / layer.groups;
    int batch = layer.batch * layer.groups;
    forward_softmax_layer_kernel<<<cuda_gridsize(batch), BLOCK>>>(inputs, batch, state.input, layer.output_gpu);
    forward_softmax_layer_kernel<<<cuda_gridsize(batch), BLOCK>>>(inputs, batch, state.input, layer.temperature, layer.output_gpu);
    check_error(cudaPeekAtLastError());
}