Joseph Redmon
2016-10-24 91f95c715bff84094fc18bad6a8f938291b9b0f5
tree things, tree stuff
5 files modified
53 ■■■■■ changed files
Makefile patch | view | raw | blame | history
src/blas.h 2 ●●● patch | view | raw | blame | history
src/blas_kernels.cu 22 ●●●●● patch | view | raw | blame | history
src/network_kernels.cu 4 ●●●● patch | view | raw | blame | history
src/softmax_layer.c 25 ●●●●● patch | view | raw | blame | history
Makefile
src/blas.h
@@ -77,7 +77,7 @@
void reorg_ongpu(float *x, int w, int h, int c, int batch, int stride, int forward, float *out);
void softmax_gpu(float *input, int n, int groups, float temp, float *output, cudaStream_t stream);
void softmax_gpu(float *input, int n, int offset, int groups, float temp, float *output);
#endif
#endif
src/blas_kernels.cu
@@ -693,31 +693,35 @@
}
__global__ void softmax_kernel(int n, int batch, float *input, float temp, float *output)
__device__ void softmax_device(int n, float *input, float temp, float *output)
{
    int b = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    if(b >= batch) return;
    int i;
    float sum = 0;
    float largest = -INFINITY;
    for(i = 0; i < n; ++i){
        int val = input[i+b*n];
        int val = input[i];
        largest = (val>largest) ? val : largest;
    }
    for(i = 0; i < n; ++i){
        sum += exp(input[i+b*n]/temp-largest/temp);
        sum += exp(input[i]/temp-largest/temp);
    }
    sum = (sum != 0) ? largest/temp+log(sum) : largest-100;
    for(i = 0; i < n; ++i){
        output[i+b*n] = exp(input[i+b*n]/temp-sum);
        output[i] = exp(input[i]/temp-sum);
    }
}
extern "C" void softmax_gpu(float *input, int n, int groups, float temp, float *output, cudaStream_t stream)
__global__ void softmax_kernel(int n, int offset, int batch, float *input, float temp, float *output)
{
    int b = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
    if(b >= batch) return;
    softmax_device(n, input + b*offset, temp, output + b*offset);
}
extern "C" void softmax_gpu(float *input, int n, int offset, int groups, float temp, float *output)
{
    int inputs = n;
    int batch = groups;
    softmax_kernel<<<cuda_gridsize(batch), BLOCK, 0, stream>>>(inputs, batch, input, temp, output);
    softmax_kernel<<<cuda_gridsize(batch), BLOCK>>>(inputs, offset, batch, input, temp, output);
    check_error(cudaPeekAtLastError());
}
src/network_kernels.cu
@@ -134,6 +134,7 @@
    free(ptr);
    cuda_set_device(args.net.gpu_index);
    *args.err = train_network(args.net, args.d);
    printf("%d\n", args.net.gpu_index);
    return 0;
}
@@ -359,11 +360,14 @@
        //printf("%f\n", errors[i]);
        sum += errors[i];
    }
    //cudaDeviceSynchronize();
    if (get_current_batch(nets[0]) % interval == 0) {
        printf("Syncing... ");
        fflush(stdout);
        sync_nets(nets, n, interval);
        printf("Done!\n");
    }
    //cudaDeviceSynchronize();
    free(threads);
    free(errors);
    return (float)sum/(n);
src/softmax_layer.c
@@ -73,37 +73,16 @@
{
    int inputs = l.inputs / l.groups;
    int batch = l.batch * l.groups;
    int b;
    if(l.softmax_tree){
        if(0){
            float *buff = calloc(inputs * batch, sizeof(float));
            cuda_pull_array(state.input, buff, batch * inputs);
            state.input = buff;
            forward_softmax_layer(l, state);
            cuda_push_array(l.output_gpu, l.output, batch*inputs);
            free(buff);
        } else {
            int i;
            const int nstreams = 32;
            cudaStream_t streams[nstreams];
            for (i = 0; i < nstreams; ++i) {
                cudaStreamCreate(&streams[i]);
            }
            for (b = 0; b < batch; ++b) {
                int i;
                int count = 0;
                for (i = 0; i < l.softmax_tree->groups; ++i) {
                    int group_size = l.softmax_tree->group_size[i];
                    softmax_gpu(state.input+b*inputs + count, group_size, 1, l.temperature, l.output_gpu+b*inputs + count, streams[(b*l.softmax_tree->groups + i) % nstreams]);
            softmax_gpu(state.input+count, group_size, inputs, batch, l.temperature, l.output_gpu + count);
                    count += group_size;
                }
            }
            for(i = 0; i < nstreams; ++i){
                cudaStreamDestroy(streams[i]);
            }
        }
    } else {
        softmax_gpu(state.input, inputs, batch, l.temperature, l.output_gpu, 0);
        softmax_gpu(state.input, inputs, inputs, batch, l.temperature, l.output_gpu);
    }
}