From b5938098d12d5c9fe48e7cc71ae3d75b7306833f Mon Sep 17 00:00:00 2001
From: Alexey <AlexeyAB@users.noreply.github.com>
Date: Mon, 02 Jan 2017 12:33:31 +0000
Subject: [PATCH] Update Readme.md - pragma-libs in How to compile

---
 src/convolutional_kernels.cu |  288 +++++++++++++++++++++++++++++++++------------------------
 1 files changed, 166 insertions(+), 122 deletions(-)

diff --git a/src/convolutional_kernels.cu b/src/convolutional_kernels.cu
index 5f24ca5..ae9df8f 100644
--- a/src/convolutional_kernels.cu
+++ b/src/convolutional_kernels.cu
@@ -4,6 +4,7 @@
 
 extern "C" {
 #include "convolutional_layer.h"
+#include "batchnorm_layer.h"
 #include "gemm.h"
 #include "blas.h"
 #include "im2col.h"
@@ -12,219 +13,262 @@
 #include "cuda.h"
 }
 
-__global__ void scale_bias_kernel(float *output, float *biases, int n, int size)
+__global__ void binarize_kernel(float *x, int n, float *binary)
 {
-    int offset = blockIdx.x * blockDim.x + threadIdx.x;
-    int filter = blockIdx.y;
-    int batch = blockIdx.z;
-
-    if(offset < size) output[(batch*n+filter)*size + offset] *= biases[filter];
+    int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+    if (i >= n) return;
+    binary[i] = (x[i] >= 0) ? 1 : -1;
 }
 
-void scale_bias_gpu(float *output, float *biases, int batch, int n, int size)
+void binarize_gpu(float *x, int n, float *binary)
 {
-    dim3 dimGrid((size-1)/BLOCK + 1, n, batch);
-    dim3 dimBlock(BLOCK, 1, 1);
-
-    scale_bias_kernel<<<dimGrid, dimBlock>>>(output, biases, n, size);
+    binarize_kernel<<<cuda_gridsize(n), BLOCK>>>(x, n, binary);
     check_error(cudaPeekAtLastError());
 }
 
-__global__ void backward_scale_kernel(float *x_norm, float *delta, int batch, int n, int size, float *scale_updates)
+__global__ void binarize_input_kernel(float *input, int n, int size, float *binary)
 {
-    __shared__ float part[BLOCK];
-    int i,b;
-    int filter = blockIdx.x;
-    int p = threadIdx.x;
-    float sum = 0;
-    for(b = 0; b < batch; ++b){
-        for(i = 0; i < size; i += BLOCK){
-            int index = p + i + size*(filter + n*b);
-            sum += (p+i < size) ? delta[index]*x_norm[index] : 0;
-        }
+    int s = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+    if (s >= size) return;
+    int i = 0;
+    float mean = 0;
+    for(i = 0; i < n; ++i){
+        mean += abs(input[i*size + s]);
     }
-    part[p] = sum;
-    __syncthreads();
-    if (p == 0) {
-        for(i = 0; i < BLOCK; ++i) scale_updates[filter] += part[i];
+    mean = mean / n;
+    for(i = 0; i < n; ++i){
+        binary[i*size + s] = (input[i*size + s] > 0) ? mean : -mean;
     }
 }
 
-void backward_scale_gpu(float *x_norm, float *delta, int batch, int n, int size, float *scale_updates)
+void binarize_input_gpu(float *input, int n, int size, float *binary)
 {
-    backward_scale_kernel<<<n, BLOCK>>>(x_norm, delta, batch, n, size, scale_updates);
+    binarize_input_kernel<<<cuda_gridsize(size), BLOCK>>>(input, n, size, binary);
     check_error(cudaPeekAtLastError());
 }
 
-__global__ void add_bias_kernel(float *output, float *biases, int n, int size)
+
+__global__ void binarize_weights_kernel(float *weights, int n, int size, float *binary)
 {
-    int offset = blockIdx.x * blockDim.x + threadIdx.x;
-    int filter = blockIdx.y;
-    int batch = blockIdx.z;
-
-    if(offset < size) output[(batch*n+filter)*size + offset] += biases[filter];
-}
-
-void add_bias_gpu(float *output, float *biases, int batch, int n, int size)
-{
-    dim3 dimGrid((size-1)/BLOCK + 1, n, batch);
-    dim3 dimBlock(BLOCK, 1, 1);
-
-    add_bias_kernel<<<dimGrid, dimBlock>>>(output, biases, n, size);
-    check_error(cudaPeekAtLastError());
-}
-
-__global__ void backward_bias_kernel(float *bias_updates, float *delta, int batch, int n, int size)
-{
-    __shared__ float part[BLOCK];
-    int i,b;
-    int filter = blockIdx.x;
-    int p = threadIdx.x;
-    float sum = 0;
-    for(b = 0; b < batch; ++b){
-        for(i = 0; i < size; i += BLOCK){
-            int index = p + i + size*(filter + n*b);
-            sum += (p+i < size) ? delta[index] : 0;
-        }
+    int f = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+    if (f >= n) return;
+    int i = 0;
+    float mean = 0;
+    for(i = 0; i < size; ++i){
+        mean += abs(weights[f*size + i]);
     }
-    part[p] = sum;
-    __syncthreads();
-    if (p == 0) {
-        for(i = 0; i < BLOCK; ++i) bias_updates[filter] += part[i];
+    mean = mean / size;
+    for(i = 0; i < size; ++i){
+        binary[f*size + i] = (weights[f*size + i] > 0) ? mean : -mean;
+        //binary[f*size + i] = weights[f*size + i];
     }
 }
 
-void backward_bias_gpu(float *bias_updates, float *delta, int batch, int n, int size)
+void binarize_weights_gpu(float *weights, int n, int size, float *binary)
 {
-    backward_bias_kernel<<<n, BLOCK>>>(bias_updates, delta, batch, n, size);
+    binarize_weights_kernel<<<cuda_gridsize(n), BLOCK>>>(weights, n, size, binary);
     check_error(cudaPeekAtLastError());
 }
 
 void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
 {
+    fill_ongpu(l.outputs*l.batch, 0, l.output_gpu, 1);
+    if(l.binary){
+        binarize_weights_gpu(l.weights_gpu, l.n, l.c*l.size*l.size, l.binary_weights_gpu);
+        swap_binary(&l);
+    }
+
+    if(l.xnor){
+        binarize_weights_gpu(l.weights_gpu, l.n, l.c*l.size*l.size, l.binary_weights_gpu);
+        swap_binary(&l);
+        binarize_gpu(state.input, l.c*l.h*l.w*l.batch, l.binary_input_gpu);
+        state.input = l.binary_input_gpu;
+    }
+
+#ifdef CUDNN
+    float one = 1;
+    cudnnConvolutionForward(cudnn_handle(),
+                &one,
+                l.srcTensorDesc,
+                state.input,
+                l.weightDesc,
+                l.weights_gpu,
+                l.convDesc,
+                l.fw_algo,
+                state.workspace,
+                l.workspace_size,
+                &one,
+                l.dstTensorDesc,
+                l.output_gpu);
+
+#else
     int i;
     int m = l.n;
     int k = l.size*l.size*l.c;
-    int n = convolutional_out_height(l)*
-        convolutional_out_width(l);
-
-    fill_ongpu(l.outputs*l.batch, 0, l.output_gpu, 1);
+    int n = l.out_w*l.out_h;
     for(i = 0; i < l.batch; ++i){
-        im2col_ongpu(state.input + i*l.c*l.h*l.w, l.c,  l.h,  l.w,  l.size,  l.stride, l.pad, l.col_image_gpu);
-        float * a = l.filters_gpu;
-        float * b = l.col_image_gpu;
+        im2col_ongpu(state.input + i*l.c*l.h*l.w, l.c,  l.h,  l.w,  l.size,  l.stride, l.pad, state.workspace);
+        float * a = l.weights_gpu;
+        float * b = state.workspace;
         float * c = l.output_gpu;
         gemm_ongpu(0,0,m,n,k,1.,a,k,b,n,1.,c+i*m*n,n);
     }
+#endif
 
-    if(l.batch_normalize){
-        if(state.train){
-            fast_mean_gpu(l.output_gpu, l.batch, l.n, l.out_h*l.out_w, l.spatial_mean_gpu, l.mean_gpu);   
-            fast_variance_gpu(l.output_gpu, l.mean_gpu, l.batch, l.n, l.out_h*l.out_w, l.spatial_variance_gpu, l.variance_gpu);
-
-            scal_ongpu(l.n, .95, l.rolling_mean_gpu, 1);
-            axpy_ongpu(l.n, .05, l.mean_gpu, 1, l.rolling_mean_gpu, 1);
-            scal_ongpu(l.n, .95, l.rolling_variance_gpu, 1);
-            axpy_ongpu(l.n, .05, l.variance_gpu, 1, l.rolling_variance_gpu, 1);
-
-            // cuda_pull_array(l.variance_gpu, l.mean, l.n);
-            // printf("%f\n", l.mean[0]);
-
-            copy_ongpu(l.outputs*l.batch, l.output_gpu, 1, l.x_gpu, 1);
-            normalize_gpu(l.output_gpu, l.mean_gpu, l.variance_gpu, l.batch, l.n, l.out_h*l.out_w);
-            copy_ongpu(l.outputs*l.batch, l.output_gpu, 1, l.x_norm_gpu, 1);
-        } else {
-            normalize_gpu(l.output_gpu, l.rolling_mean_gpu, l.rolling_variance_gpu, l.batch, l.n, l.out_h*l.out_w);
-        }
-
-        scale_bias_gpu(l.output_gpu, l.scales_gpu, l.batch, l.n, l.out_h*l.out_w);
+    if (l.batch_normalize) {
+        forward_batchnorm_layer_gpu(l, state);
     }
-    add_bias_gpu(l.output_gpu, l.biases_gpu, l.batch, l.n, n);
+    add_bias_gpu(l.output_gpu, l.biases_gpu, l.batch, l.n, l.out_w*l.out_h);
 
-    activate_array_ongpu(l.output_gpu, m*n*l.batch, l.activation);
+    activate_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation);
+    //if(l.dot > 0) dot_error_gpu(l);
+    if(l.binary || l.xnor) swap_binary(&l);
 }
 
 void backward_convolutional_layer_gpu(convolutional_layer l, network_state state)
 {
-    int i;
-    int m = l.n;
-    int n = l.size*l.size*l.c;
-    int k = convolutional_out_height(l)*
-        convolutional_out_width(l);
+    gradient_array_ongpu(l.output_gpu, l.outputs*l.batch, l.activation, l.delta_gpu);
 
-    gradient_array_ongpu(l.output_gpu, m*k*l.batch, l.activation, l.delta_gpu);
-
-    backward_bias_gpu(l.bias_updates_gpu, l.delta_gpu, l.batch, l.n, k);
+    backward_bias_gpu(l.bias_updates_gpu, l.delta_gpu, l.batch, l.n, l.out_w*l.out_h);
 
     if(l.batch_normalize){
-        backward_scale_gpu(l.x_norm_gpu, l.delta_gpu, l.batch, l.n, l.out_w*l.out_h, l.scale_updates_gpu);
+        backward_batchnorm_layer_gpu(l, state);
+        //axpy_ongpu(l.outputs*l.batch, -state.net.decay, l.x_gpu, 1, l.delta_gpu, 1);
+    } else {
+        //axpy_ongpu(l.outputs*l.batch, -state.net.decay, l.output_gpu, 1, l.delta_gpu, 1);
+    }
+    float *original_input = state.input;
 
-        scale_bias_gpu(l.delta_gpu, l.scales_gpu, l.batch, l.n, l.out_h*l.out_w);
+    if(l.xnor) state.input = l.binary_input_gpu;
+#ifdef CUDNN
+    float one = 1;
+    cudnnConvolutionBackwardFilter(cudnn_handle(),
+            &one,
+            l.srcTensorDesc,
+            state.input,
+            l.ddstTensorDesc,
+            l.delta_gpu,
+            l.convDesc,
+            l.bf_algo,
+            state.workspace,
+            l.workspace_size,
+            &one,
+            l.dweightDesc,
+            l.weight_updates_gpu);
 
-        fast_mean_delta_gpu(l.delta_gpu, l.variance_gpu, l.batch, l.n, l.out_w*l.out_h, l.spatial_mean_delta_gpu, l.mean_delta_gpu);
-        fast_variance_delta_gpu(l.x_gpu, l.delta_gpu, l.mean_gpu, l.variance_gpu, l.batch, l.n, l.out_w*l.out_h, l.spatial_variance_delta_gpu, l.variance_delta_gpu);
-        normalize_delta_gpu(l.x_gpu, l.mean_gpu, l.variance_gpu, l.mean_delta_gpu, l.variance_delta_gpu, l.batch, l.n, l.out_w*l.out_h, l.delta_gpu);
+    if(state.delta){
+        if(l.binary || l.xnor) swap_binary(&l);
+        cudnnConvolutionBackwardData(cudnn_handle(),
+                &one,
+                l.weightDesc,
+                l.weights_gpu,
+                l.ddstTensorDesc,
+                l.delta_gpu,
+                l.convDesc,
+                l.bd_algo,
+                state.workspace,
+                l.workspace_size,
+                &one,
+                l.dsrcTensorDesc,
+                state.delta);
+        if(l.binary || l.xnor) swap_binary(&l);
+        if(l.xnor) gradient_array_ongpu(original_input, l.batch*l.c*l.h*l.w, HARDTAN, state.delta);
     }
 
+#else
+    int m = l.n;
+    int n = l.size*l.size*l.c;
+    int k = l.out_w*l.out_h;
+
+    int i;
     for(i = 0; i < l.batch; ++i){
         float * a = l.delta_gpu;
-        float * b = l.col_image_gpu;
-        float * c = l.filter_updates_gpu;
+        float * b = state.workspace;
+        float * c = l.weight_updates_gpu;
 
-        im2col_ongpu(state.input + i*l.c*l.h*l.w, l.c,  l.h,  l.w,  l.size,  l.stride, l.pad, l.col_image_gpu);
+        im2col_ongpu(state.input + i*l.c*l.h*l.w, l.c,  l.h,  l.w,  l.size,  l.stride, l.pad, state.workspace);
         gemm_ongpu(0,1,m,n,k,1,a + i*m*k,k,b,k,1,c,n);
 
         if(state.delta){
-            float * a = l.filters_gpu;
+            if(l.binary || l.xnor) swap_binary(&l);
+            float * a = l.weights_gpu;
             float * b = l.delta_gpu;
-            float * c = l.col_image_gpu;
+            float * c = state.workspace;
 
             gemm_ongpu(1,0,n,k,m,1,a,n,b + i*k*m,k,0,c,k);
 
-            col2im_ongpu(l.col_image_gpu, l.c,  l.h,  l.w,  l.size,  l.stride, l.pad, state.delta + i*l.c*l.h*l.w);
+            col2im_ongpu(state.workspace, l.c,  l.h,  l.w,  l.size,  l.stride, l.pad, state.delta + i*l.c*l.h*l.w);
+            if(l.binary || l.xnor) {
+                swap_binary(&l);
+            }
+            if(l.xnor) gradient_array_ongpu(original_input + i*l.c*l.h*l.w, l.c*l.h*l.w, HARDTAN, state.delta + i*l.c*l.h*l.w);
         }
     }
+#endif
 }
 
 void pull_convolutional_layer(convolutional_layer layer)
 {
-    cuda_pull_array(layer.filters_gpu, layer.filters, layer.c*layer.n*layer.size*layer.size);
+    cuda_pull_array(layer.weights_gpu, layer.weights, layer.c*layer.n*layer.size*layer.size);
     cuda_pull_array(layer.biases_gpu, layer.biases, layer.n);
-    cuda_pull_array(layer.filter_updates_gpu, layer.filter_updates, layer.c*layer.n*layer.size*layer.size);
+    cuda_pull_array(layer.weight_updates_gpu, layer.weight_updates, layer.c*layer.n*layer.size*layer.size);
     cuda_pull_array(layer.bias_updates_gpu, layer.bias_updates, layer.n);
     if (layer.batch_normalize){
         cuda_pull_array(layer.scales_gpu, layer.scales, layer.n);
         cuda_pull_array(layer.rolling_mean_gpu, layer.rolling_mean, layer.n);
         cuda_pull_array(layer.rolling_variance_gpu, layer.rolling_variance, layer.n);
     }
+    if (layer.adam){
+        cuda_pull_array(layer.m_gpu, layer.m, layer.c*layer.n*layer.size*layer.size);
+        cuda_pull_array(layer.v_gpu, layer.v, layer.c*layer.n*layer.size*layer.size);
+    }
 }
 
 void push_convolutional_layer(convolutional_layer layer)
 {
-    cuda_push_array(layer.filters_gpu, layer.filters, layer.c*layer.n*layer.size*layer.size);
+    cuda_push_array(layer.weights_gpu, layer.weights, layer.c*layer.n*layer.size*layer.size);
     cuda_push_array(layer.biases_gpu, layer.biases, layer.n);
-    cuda_push_array(layer.filter_updates_gpu, layer.filter_updates, layer.c*layer.n*layer.size*layer.size);
+    cuda_push_array(layer.weight_updates_gpu, layer.weight_updates, layer.c*layer.n*layer.size*layer.size);
     cuda_push_array(layer.bias_updates_gpu, layer.bias_updates, layer.n);
     if (layer.batch_normalize){
         cuda_push_array(layer.scales_gpu, layer.scales, layer.n);
         cuda_push_array(layer.rolling_mean_gpu, layer.rolling_mean, layer.n);
         cuda_push_array(layer.rolling_variance_gpu, layer.rolling_variance, layer.n);
     }
+    if (layer.adam){
+        cuda_push_array(layer.m_gpu, layer.m, layer.c*layer.n*layer.size*layer.size);
+        cuda_push_array(layer.v_gpu, layer.v, layer.c*layer.n*layer.size*layer.size);
+    }
 }
 
 void update_convolutional_layer_gpu(convolutional_layer layer, int batch, float learning_rate, float momentum, float decay)
 {
     int size = layer.size*layer.size*layer.c*layer.n;
-
     axpy_ongpu(layer.n, learning_rate/batch, layer.bias_updates_gpu, 1, layer.biases_gpu, 1);
     scal_ongpu(layer.n, momentum, layer.bias_updates_gpu, 1);
 
-    axpy_ongpu(layer.n, learning_rate/batch, layer.scale_updates_gpu, 1, layer.scales_gpu, 1);
-    scal_ongpu(layer.n, momentum, layer.scale_updates_gpu, 1);
+    if(layer.scales_gpu){
+        axpy_ongpu(layer.n, learning_rate/batch, layer.scale_updates_gpu, 1, layer.scales_gpu, 1);
+        scal_ongpu(layer.n, momentum, layer.scale_updates_gpu, 1);
+    }
 
-    axpy_ongpu(size, -decay*batch, layer.filters_gpu, 1, layer.filter_updates_gpu, 1);
-    axpy_ongpu(size, learning_rate/batch, layer.filter_updates_gpu, 1, layer.filters_gpu, 1);
-    scal_ongpu(size, momentum, layer.filter_updates_gpu, 1);
+    if(layer.adam){
+        scal_ongpu(size, layer.B1, layer.m_gpu, 1);
+        scal_ongpu(size, layer.B2, layer.v_gpu, 1);
+
+        axpy_ongpu(size, -decay*batch, layer.weights_gpu, 1, layer.weight_updates_gpu, 1);
+
+        axpy_ongpu(size, -(1-layer.B1), layer.weight_updates_gpu, 1, layer.m_gpu, 1);
+        mul_ongpu(size, layer.weight_updates_gpu, 1, layer.weight_updates_gpu, 1);
+        axpy_ongpu(size, (1-layer.B2), layer.weight_updates_gpu, 1, layer.v_gpu, 1);
+
+        adam_gpu(size, layer.weights_gpu, layer.m_gpu, layer.v_gpu, layer.B1, layer.B2, learning_rate/batch, layer.eps, layer.t+1);
+        fill_ongpu(size, 0, layer.weight_updates_gpu, 1);
+    }else{
+        axpy_ongpu(size, -decay*batch, layer.weights_gpu, 1, layer.weight_updates_gpu, 1);
+        axpy_ongpu(size, learning_rate/batch, layer.weight_updates_gpu, 1, layer.weights_gpu, 1);
+        scal_ongpu(size, momentum, layer.weight_updates_gpu, 1);
+    }
 }
 
 

--
Gitblit v1.10.0