Joseph Redmon
2016-05-13 881d6ee9b6625ee502cb4f27d9b017a3da78caa7
src/cuda.c
@@ -7,10 +7,12 @@
#include "blas.h"
#include "assert.h"
#include <stdlib.h>
#include <time.h>
void check_error(cudaError_t status)
{
    cudaError_t status2 = cudaGetLastError();
    if (status != cudaSuccess)
    {   
        const char *s = cudaGetErrorString(status);
@@ -20,6 +22,15 @@
        snprintf(buffer, 256, "CUDA Error: %s", s);
        error(buffer);
    } 
    if (status2 != cudaSuccess)
    {
        const char *s = cudaGetErrorString(status);
        char buffer[256];
        printf("CUDA Error Prev: %s\n", s);
        assert(0);
        snprintf(buffer, 256, "CUDA Error Prev: %s", s);
        error(buffer);
    }
}
dim3 cuda_gridsize(size_t n){
@@ -35,6 +46,19 @@
    return d;
}
#ifdef CUDNN
cudnnHandle_t cudnn_handle()
{
    static int init = 0;
    static cudnnHandle_t handle;
    if(!init) {
        cudnnCreate(&handle);
        init = 1;
    }
    return handle;
}
#endif
cublasHandle_t blas_handle()
{
    static int init = 0;
@@ -46,7 +70,7 @@
    return handle;
}
float *cuda_make_array(float *x, int n)
float *cuda_make_array(float *x, size_t n)
{
    float *x_gpu;
    size_t size = sizeof(float)*n;
@@ -56,10 +80,24 @@
        status = cudaMemcpy(x_gpu, x, size, cudaMemcpyHostToDevice);
        check_error(status);
    }
    if(!x_gpu) error("Cuda malloc failed\n");
    return x_gpu;
}
float cuda_compare(float *x_gpu, float *x, int n, char *s)
void cuda_random(float *x_gpu, size_t n)
{
    static curandGenerator_t gen;
    static int init = 0;
    if(!init){
        curandCreateGenerator(&gen, CURAND_RNG_PSEUDO_DEFAULT);
        curandSetPseudoRandomGeneratorSeed(gen, time(0));
        init = 1;
    }
    curandGenerateUniform(gen, x_gpu, n);
    check_error(cudaPeekAtLastError());
}
float cuda_compare(float *x_gpu, float *x, size_t n, char *s)
{
    float *tmp = calloc(n, sizeof(float));
    cuda_pull_array(x_gpu, tmp, n);
@@ -72,7 +110,7 @@
    return err;
}
int *cuda_make_int_array(int n)
int *cuda_make_int_array(size_t n)
{
    int *x_gpu;
    size_t size = sizeof(int)*n;
@@ -87,14 +125,14 @@
    check_error(status);
}
void cuda_push_array(float *x_gpu, float *x, int n)
void cuda_push_array(float *x_gpu, float *x, size_t n)
{
    size_t size = sizeof(float)*n;
    cudaError_t status = cudaMemcpy(x_gpu, x, size, cudaMemcpyHostToDevice);
    check_error(status);
}
void cuda_pull_array(float *x_gpu, float *x, int n)
void cuda_pull_array(float *x_gpu, float *x, size_t n)
{
    size_t size = sizeof(float)*n;
    cudaError_t status = cudaMemcpy(x, x_gpu, size, cudaMemcpyDeviceToHost);