From 845ab7579685b6702c92c1088ec11e71bde51f3c Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Fri, 05 Aug 2016 22:27:07 +0000
Subject: [PATCH] some more stuff

---
 src/convolutional_layer.c |  127 +++++++++++++++++++++--------------------
 1 files changed, 65 insertions(+), 62 deletions(-)

diff --git a/src/convolutional_layer.c b/src/convolutional_layer.c
index e8ae49c..006dc4c 100644
--- a/src/convolutional_layer.c
+++ b/src/convolutional_layer.c
@@ -104,36 +104,37 @@
 
 size_t get_workspace_size(layer l){
 #ifdef CUDNN
-    size_t most = 0;
-    size_t s = 0;
-    cudnnGetConvolutionForwardWorkspaceSize(cudnn_handle(),
-            l.srcTensorDesc,
-            l.filterDesc,
-            l.convDesc,
-            l.dstTensorDesc,
-            l.fw_algo,
-            &s);
-    if (s > most) most = s;
-    cudnnGetConvolutionBackwardFilterWorkspaceSize(cudnn_handle(),
-            l.srcTensorDesc,
-            l.ddstTensorDesc,
-            l.convDesc,
-            l.dfilterDesc,
-            l.bf_algo,
-            &s);
-    if (s > most) most = s;
-    cudnnGetConvolutionBackwardDataWorkspaceSize(cudnn_handle(),
-            l.filterDesc,
-            l.ddstTensorDesc,
-            l.convDesc,
-            l.dsrcTensorDesc,
-            l.bd_algo,
-            &s);
-    if (s > most) most = s;
-    return most;
-#else
+    if(gpu_index >= 0){
+        size_t most = 0;
+        size_t s = 0;
+        cudnnGetConvolutionForwardWorkspaceSize(cudnn_handle(),
+                l.srcTensorDesc,
+                l.filterDesc,
+                l.convDesc,
+                l.dstTensorDesc,
+                l.fw_algo,
+                &s);
+        if (s > most) most = s;
+        cudnnGetConvolutionBackwardFilterWorkspaceSize(cudnn_handle(),
+                l.srcTensorDesc,
+                l.ddstTensorDesc,
+                l.convDesc,
+                l.dfilterDesc,
+                l.bf_algo,
+                &s);
+        if (s > most) most = s;
+        cudnnGetConvolutionBackwardDataWorkspaceSize(cudnn_handle(),
+                l.filterDesc,
+                l.ddstTensorDesc,
+                l.convDesc,
+                l.dsrcTensorDesc,
+                l.bd_algo,
+                &s);
+        if (s > most) most = s;
+        return most;
+    }
+    #endif
     return (size_t)l.out_h*l.out_w*l.size*l.size*l.c*sizeof(float);
-#endif
 }
 
 #ifdef GPU
@@ -240,49 +241,51 @@
     }
 
 #ifdef GPU
-    l.filters_gpu = cuda_make_array(l.filters, c*n*size*size);
-    l.filter_updates_gpu = cuda_make_array(l.filter_updates, c*n*size*size);
+    if(gpu_index >= 0){
+        l.filters_gpu = cuda_make_array(l.filters, c*n*size*size);
+        l.filter_updates_gpu = cuda_make_array(l.filter_updates, c*n*size*size);
 
-    l.biases_gpu = cuda_make_array(l.biases, n);
-    l.bias_updates_gpu = cuda_make_array(l.bias_updates, n);
+        l.biases_gpu = cuda_make_array(l.biases, n);
+        l.bias_updates_gpu = cuda_make_array(l.bias_updates, n);
 
-    l.scales_gpu = cuda_make_array(l.scales, n);
-    l.scale_updates_gpu = cuda_make_array(l.scale_updates, n);
+        l.scales_gpu = cuda_make_array(l.scales, n);
+        l.scale_updates_gpu = cuda_make_array(l.scale_updates, n);
 
-    l.delta_gpu = cuda_make_array(l.delta, l.batch*out_h*out_w*n);
-    l.output_gpu = cuda_make_array(l.output, l.batch*out_h*out_w*n);
+        l.delta_gpu = cuda_make_array(l.delta, l.batch*out_h*out_w*n);
+        l.output_gpu = cuda_make_array(l.output, l.batch*out_h*out_w*n);
 
-    if(binary){
-        l.binary_filters_gpu = cuda_make_array(l.filters, c*n*size*size);
-    }
-    if(xnor){
-        l.binary_filters_gpu = cuda_make_array(l.filters, c*n*size*size);
-        l.binary_input_gpu = cuda_make_array(0, l.inputs*l.batch);
-    }
+        if(binary){
+            l.binary_filters_gpu = cuda_make_array(l.filters, c*n*size*size);
+        }
+        if(xnor){
+            l.binary_filters_gpu = cuda_make_array(l.filters, c*n*size*size);
+            l.binary_input_gpu = cuda_make_array(0, l.inputs*l.batch);
+        }
 
-    if(batch_normalize){
-        l.mean_gpu = cuda_make_array(l.mean, n);
-        l.variance_gpu = cuda_make_array(l.variance, n);
+        if(batch_normalize){
+            l.mean_gpu = cuda_make_array(l.mean, n);
+            l.variance_gpu = cuda_make_array(l.variance, n);
 
-        l.rolling_mean_gpu = cuda_make_array(l.mean, n);
-        l.rolling_variance_gpu = cuda_make_array(l.variance, n);
+            l.rolling_mean_gpu = cuda_make_array(l.mean, n);
+            l.rolling_variance_gpu = cuda_make_array(l.variance, n);
 
-        l.mean_delta_gpu = cuda_make_array(l.mean, n);
-        l.variance_delta_gpu = cuda_make_array(l.variance, n);
+            l.mean_delta_gpu = cuda_make_array(l.mean, n);
+            l.variance_delta_gpu = cuda_make_array(l.variance, n);
 
-        l.x_gpu = cuda_make_array(l.output, l.batch*out_h*out_w*n);
-        l.x_norm_gpu = cuda_make_array(l.output, l.batch*out_h*out_w*n);
-    }
+            l.x_gpu = cuda_make_array(l.output, l.batch*out_h*out_w*n);
+            l.x_norm_gpu = cuda_make_array(l.output, l.batch*out_h*out_w*n);
+        }
 #ifdef CUDNN
-    cudnnCreateTensorDescriptor(&l.srcTensorDesc);
-    cudnnCreateTensorDescriptor(&l.dstTensorDesc);
-    cudnnCreateFilterDescriptor(&l.filterDesc);
-    cudnnCreateTensorDescriptor(&l.dsrcTensorDesc);
-    cudnnCreateTensorDescriptor(&l.ddstTensorDesc);
-    cudnnCreateFilterDescriptor(&l.dfilterDesc);
-    cudnnCreateConvolutionDescriptor(&l.convDesc);
-    cudnn_convolutional_setup(&l);
+        cudnnCreateTensorDescriptor(&l.srcTensorDesc);
+        cudnnCreateTensorDescriptor(&l.dstTensorDesc);
+        cudnnCreateFilterDescriptor(&l.filterDesc);
+        cudnnCreateTensorDescriptor(&l.dsrcTensorDesc);
+        cudnnCreateTensorDescriptor(&l.ddstTensorDesc);
+        cudnnCreateFilterDescriptor(&l.dfilterDesc);
+        cudnnCreateConvolutionDescriptor(&l.convDesc);
+        cudnn_convolutional_setup(&l);
 #endif
+    }
 #endif
     l.workspace_size = get_workspace_size(l);
     l.activation = activation;

--
Gitblit v1.10.0