From 336a19f14550ab1adbb0d9599284ac525fb6a5e0 Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Thu, 08 Sep 2016 05:44:41 +0000
Subject: [PATCH] slight changes to default demo settings

---
 src/cuda.c |   50 +++++++++++++++++++++++++++++++++++++++++++++-----
 1 files changed, 45 insertions(+), 5 deletions(-)

diff --git a/src/cuda.c b/src/cuda.c
index 8849fb1..327813d 100644
--- a/src/cuda.c
+++ b/src/cuda.c
@@ -5,19 +5,32 @@
 #include "cuda.h"
 #include "utils.h"
 #include "blas.h"
+#include "assert.h"
 #include <stdlib.h>
+#include <time.h>
 
 
 void check_error(cudaError_t status)
 {
+    cudaError_t status2 = cudaGetLastError();
     if (status != cudaSuccess)
     {   
         const char *s = cudaGetErrorString(status);
         char buffer[256];
         printf("CUDA Error: %s\n", s);
+        assert(0);
         snprintf(buffer, 256, "CUDA Error: %s", s);
         error(buffer);
     } 
+    if (status2 != cudaSuccess)
+    {   
+        const char *s = cudaGetErrorString(status);
+        char buffer[256];
+        printf("CUDA Error Prev: %s\n", s);
+        assert(0);
+        snprintf(buffer, 256, "CUDA Error Prev: %s", s);
+        error(buffer);
+    } 
 }
 
 dim3 cuda_gridsize(size_t n){
@@ -33,6 +46,19 @@
     return d;
 }
 
+#ifdef CUDNN
+cudnnHandle_t cudnn_handle()
+{
+    static int init = 0;
+    static cudnnHandle_t handle;
+    if(!init) {
+        cudnnCreate(&handle);
+        init = 1;
+    }
+    return handle;
+}
+#endif
+
 cublasHandle_t blas_handle()
 {
     static int init = 0;
@@ -44,7 +70,7 @@
     return handle;
 }
 
-float *cuda_make_array(float *x, int n)
+float *cuda_make_array(float *x, size_t n)
 {
     float *x_gpu;
     size_t size = sizeof(float)*n;
@@ -54,10 +80,24 @@
         status = cudaMemcpy(x_gpu, x, size, cudaMemcpyHostToDevice);
         check_error(status);
     }
+    if(!x_gpu) error("Cuda malloc failed\n");
     return x_gpu;
 }
 
-float cuda_compare(float *x_gpu, float *x, int n, char *s)
+void cuda_random(float *x_gpu, size_t n)
+{
+    static curandGenerator_t gen;
+    static int init = 0;
+    if(!init){
+        curandCreateGenerator(&gen, CURAND_RNG_PSEUDO_DEFAULT);
+        curandSetPseudoRandomGeneratorSeed(gen, time(0));
+        init = 1;
+    }
+    curandGenerateUniform(gen, x_gpu, n);
+    check_error(cudaPeekAtLastError());
+}
+
+float cuda_compare(float *x_gpu, float *x, size_t n, char *s)
 {
     float *tmp = calloc(n, sizeof(float));
     cuda_pull_array(x_gpu, tmp, n);
@@ -70,7 +110,7 @@
     return err;
 }
 
-int *cuda_make_int_array(int n)
+int *cuda_make_int_array(size_t n)
 {
     int *x_gpu;
     size_t size = sizeof(int)*n;
@@ -85,14 +125,14 @@
     check_error(status);
 }
 
-void cuda_push_array(float *x_gpu, float *x, int n)
+void cuda_push_array(float *x_gpu, float *x, size_t n)
 {
     size_t size = sizeof(float)*n;
     cudaError_t status = cudaMemcpy(x_gpu, x, size, cudaMemcpyHostToDevice);
     check_error(status);
 }
 
-void cuda_pull_array(float *x_gpu, float *x, int n)
+void cuda_pull_array(float *x_gpu, float *x, size_t n)
 {
     size_t size = sizeof(float)*n;
     cudaError_t status = cudaMemcpy(x, x_gpu, size, cudaMemcpyDeviceToHost);

--
Gitblit v1.10.0