From d6162af210d9d5648d33bf0fda40f773ac200df5 Mon Sep 17 00:00:00 2001
From: AlexeyAB <alexeyab84@gmail.com>
Date: Wed, 08 Aug 2018 23:31:36 +0000
Subject: [PATCH] Optimized on CPU: gemm_bin, im2col, activation, transpose

---
 src/convolutional_layer.c |  507 +++++++++++++++++++++++++++++++++++++++++++++++--------
 1 files changed, 431 insertions(+), 76 deletions(-)

diff --git a/src/convolutional_layer.c b/src/convolutional_layer.c
index 888eca3..a820588 100644
--- a/src/convolutional_layer.c
+++ b/src/convolutional_layer.c
@@ -8,6 +8,10 @@
 #include <stdio.h>
 #include <time.h>
 
+#ifdef CUDNN
+#pragma comment(lib, "cudnn.lib")
+#endif
+
 #ifdef AI2
 #include "xnor_layer.h"
 #endif
@@ -133,22 +137,71 @@
 
 #ifdef GPU
 #ifdef CUDNN
-void cudnn_convolutional_setup(layer *l)
+void cudnn_convolutional_setup(layer *l, int cudnn_preference)
 {
-    cudnnSetTensor4dDescriptor(l->dsrcTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->c, l->h, l->w); 
-    cudnnSetTensor4dDescriptor(l->ddstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->out_c, l->out_h, l->out_w); 
-    cudnnSetFilter4dDescriptor(l->dweightDesc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size); 
 
-    cudnnSetTensor4dDescriptor(l->srcTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->c, l->h, l->w); 
-    cudnnSetTensor4dDescriptor(l->dstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->out_c, l->out_h, l->out_w); 
-    cudnnSetFilter4dDescriptor(l->weightDesc, CUDNN_DATA_FLOAT, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size); 
-    cudnnSetConvolution2dDescriptor(l->convDesc, l->pad, l->pad, l->stride, l->stride, 1, 1, CUDNN_CROSS_CORRELATION);
+#ifdef CUDNN_HALF
+    // TRUE_HALF_CONFIG is only supported on architectures with true fp16 support (compute capability 5.3 and 6.0):
+    //   Tegra X1, Jetson TX1, DRIVE CX, DRIVE PX, Quadro GP100, Tesla P100
+    // PSEUDO_HALF_CONFIG is required for Tensor Cores - our case!
+    const cudnnDataType_t data_type = CUDNN_DATA_HALF;
+#else
+    cudnnDataType_t data_type = CUDNN_DATA_FLOAT;
+#endif
+
+#if(CUDNN_MAJOR >= 7)
+    // Tensor Core uses CUDNN_TENSOR_OP_MATH instead of CUDNN_DEFAULT_MATH
+    // For *_ALGO_WINOGRAD_NONFUSED can be used CUDNN_DATA_FLOAT
+    // otherwise Input, Filter and Output descriptors (xDesc, yDesc, wDesc, dxDesc, dyDesc and dwDesc as applicable) have dataType = CUDNN_DATA_HALF
+    // Three techniques for training using Mixed-precision: https://devblogs.nvidia.com/mixed-precision-training-deep-neural-networks/
+    // 1. Accumulation into FP32
+    // 2. Loss Scaling - required only for: activation gradients. We do not use.
+    // 3. FP32 Master Copy of Weights
+    // More: http://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#tensor_ops
+    cudnnSetConvolutionMathType(l->convDesc, CUDNN_TENSOR_OP_MATH);
+#endif
+
+    // INT8_CONFIG, INT8_EXT_CONFIG, INT8x4_CONFIG and INT8x4_EXT_CONFIG are only supported
+    //   on architectures with DP4A support (compute capability 6.1 and later).
+    //cudnnDataType_t data_type = CUDNN_DATA_INT8;
+
+    // backward delta
+    cudnnSetTensor4dDescriptor(l->dsrcTensorDesc, CUDNN_TENSOR_NCHW, data_type, l->batch, l->c, l->h, l->w);
+    cudnnSetTensor4dDescriptor(l->ddstTensorDesc, CUDNN_TENSOR_NCHW, data_type, l->batch, l->out_c, l->out_h, l->out_w);
+    cudnnSetFilter4dDescriptor(l->dweightDesc, data_type, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size);
+
+    // forward
+    cudnnSetTensor4dDescriptor(l->srcTensorDesc, CUDNN_TENSOR_NCHW, data_type, l->batch, l->c, l->h, l->w);
+    cudnnSetTensor4dDescriptor(l->dstTensorDesc, CUDNN_TENSOR_NCHW, data_type, l->batch, l->out_c, l->out_h, l->out_w);
+    cudnnSetFilter4dDescriptor(l->weightDesc, data_type, CUDNN_TENSOR_NCHW, l->n, l->c, l->size, l->size);
+
+    // batch norm
+    cudnnSetTensor4dDescriptor(l->normTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, l->out_c, 1, 1);
+    cudnnSetTensor4dDescriptor(l->normDstTensorDesc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, l->batch, l->out_c, l->out_h, l->out_w);
+
+    cudnnSetTensor4dDescriptor(l->normDstTensorDescF16, CUDNN_TENSOR_NCHW, data_type, l->batch, l->out_c, l->out_h, l->out_w);
+#if(CUDNN_MAJOR >= 6)
+    cudnnSetConvolution2dDescriptor(l->convDesc, l->pad, l->pad, l->stride, l->stride, 1, 1, CUDNN_CROSS_CORRELATION, CUDNN_DATA_FLOAT);    // cudnn >= 6.0
+#else
+    cudnnSetConvolution2dDescriptor(l->convDesc, l->pad, l->pad, l->stride, l->stride, 1, 1, CUDNN_CROSS_CORRELATION);    // cudnn 5.1
+#endif
+    int forward_algo = CUDNN_CONVOLUTION_FWD_PREFER_FASTEST;
+    int backward_algo = CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST;
+    int backward_filter = CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST;
+    if (cudnn_preference == cudnn_smallest)
+    {
+        forward_algo = CUDNN_CONVOLUTION_FWD_NO_WORKSPACE;
+        backward_algo = CUDNN_CONVOLUTION_BWD_DATA_NO_WORKSPACE;
+        backward_filter = CUDNN_CONVOLUTION_BWD_FILTER_NO_WORKSPACE;
+        printf(" CUDNN-slow ");
+    }
+
     cudnnGetConvolutionForwardAlgorithm(cudnn_handle(),
             l->srcTensorDesc,
             l->weightDesc,
             l->convDesc,
             l->dstTensorDesc,
-            CUDNN_CONVOLUTION_FWD_PREFER_FASTEST,
+            forward_algo,
             0,
             &l->fw_algo);
     cudnnGetConvolutionBackwardDataAlgorithm(cudnn_handle(),
@@ -156,7 +209,7 @@
             l->ddstTensorDesc,
             l->convDesc,
             l->dsrcTensorDesc,
-            CUDNN_CONVOLUTION_BWD_DATA_PREFER_FASTEST,
+            backward_algo,
             0,
             &l->bd_algo);
     cudnnGetConvolutionBackwardFilterAlgorithm(cudnn_handle(),
@@ -164,9 +217,41 @@
             l->ddstTensorDesc,
             l->convDesc,
             l->dweightDesc,
-            CUDNN_CONVOLUTION_BWD_FILTER_PREFER_FASTEST,
+            backward_filter,
             0,
             &l->bf_algo);
+
+    if (data_type == CUDNN_DATA_HALF)
+    {
+        // HALF-16 if(data_type == CUDNN_DATA_HALF)
+        l->fw_algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
+        l->bd_algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
+        l->bf_algo = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
+
+        // FLOAT-32 if(data_type == CUDNN_DATA_FLOAT)
+        //l->fw_algo = CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED;
+        //l->bd_algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED;
+        //l->bf_algo = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED;
+
+        int fw = 0, bd = 0, bf = 0;
+        if (l->fw_algo == CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM) fw = 1;
+            //printf("Tensor Cores - Forward enabled: l->fw_algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM \n");
+        if (l->fw_algo == CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED) fw = 2;
+            //printf("Tensor Cores - Forward enabled: l->fw_algo = CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED \n");
+
+        if (l->bd_algo == CUDNN_CONVOLUTION_BWD_DATA_ALGO_1) bd = 1;
+            //printf("Tensor Cores - Backward-data enabled: l->bd_algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1  \n");
+        if (l->bd_algo == CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED) bd = 2;
+            //printf("Tensor Cores - Backward-data enabled: l->bd_algo = CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED \n");
+
+        if (l->bf_algo == CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1) bf = 1;
+            //printf("Tensor Cores - Backward-filter enabled: l->bf_algo = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1   \n");
+        if (l->bf_algo == CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED) bf = 2;
+            //printf("Tensor Cores - Backward-filter enabled: l->bf_algo = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED \n");
+
+        //if (fw == 2 && bd == 2 && bf == 2) printf("TF ");
+        //else if (fw == 1 && bd == 1 && bf == 1) printf("TH ");
+    }
 }
 #endif
 #endif
@@ -206,8 +291,8 @@
     l.outputs = l.out_h * l.out_w * l.out_c;
     l.inputs = l.w * l.h * l.c;
 
-    l.output = calloc(l.batch*out_h * out_w * n, sizeof(float));
-    l.delta  = calloc(l.batch*out_h * out_w * n, sizeof(float));
+    l.output = calloc(l.batch*l.outputs, sizeof(float));
+    l.delta  = calloc(l.batch*l.outputs, sizeof(float));
 
     l.forward = forward_convolutional_layer;
     l.backward = backward_convolutional_layer;
@@ -232,8 +317,18 @@
         l.mean = calloc(n, sizeof(float));
         l.variance = calloc(n, sizeof(float));
 
+        l.mean_delta = calloc(n, sizeof(float));
+        l.variance_delta = calloc(n, sizeof(float));
+
         l.rolling_mean = calloc(n, sizeof(float));
         l.rolling_variance = calloc(n, sizeof(float));
+        l.x = calloc(l.batch*l.outputs, sizeof(float));
+        l.x_norm = calloc(l.batch*l.outputs, sizeof(float));
+    }
+    if(adam){
+        l.adam = 1;
+        l.m = calloc(c*n*size*size, sizeof(float));
+        l.v = calloc(c*n*size*size, sizeof(float));
     }
 
 #ifdef GPU
@@ -243,12 +338,15 @@
 
     if(gpu_index >= 0){
         if (adam) {
-            l.adam = 1;
-            l.m_gpu = cuda_make_array(l.weight_updates, c*n*size*size);
-            l.v_gpu = cuda_make_array(l.weight_updates, c*n*size*size);
+            l.m_gpu = cuda_make_array(l.m, c*n*size*size);
+            l.v_gpu = cuda_make_array(l.v, c*n*size*size);
         }
 
         l.weights_gpu = cuda_make_array(l.weights, c*n*size*size);
+#ifdef CUDNN_HALF
+        l.weights_gpu16 = cuda_make_array(NULL, c*n*size*size / 2); //cuda_make_array(l.weights, c*n*size*size / 2);
+        l.weight_updates_gpu16 = cuda_make_array(NULL, c*n*size*size / 2); //cuda_make_array(l.weight_updates, c*n*size*size / 2);
+#endif
         l.weight_updates_gpu = cuda_make_array(l.weight_updates, c*n*size*size);
 
         l.biases_gpu = cuda_make_array(l.biases, n);
@@ -282,6 +380,9 @@
             l.x_norm_gpu = cuda_make_array(l.output, l.batch*out_h*out_w*n);
         }
 #ifdef CUDNN
+        cudnnCreateTensorDescriptor(&l.normDstTensorDesc);
+        cudnnCreateTensorDescriptor(&l.normDstTensorDescF16);
+        cudnnCreateTensorDescriptor(&l.normTensorDesc);
         cudnnCreateTensorDescriptor(&l.srcTensorDesc);
         cudnnCreateTensorDescriptor(&l.dstTensorDesc);
         cudnnCreateFilterDescriptor(&l.weightDesc);
@@ -289,14 +390,16 @@
         cudnnCreateTensorDescriptor(&l.ddstTensorDesc);
         cudnnCreateFilterDescriptor(&l.dweightDesc);
         cudnnCreateConvolutionDescriptor(&l.convDesc);
-        cudnn_convolutional_setup(&l);
+        cudnn_convolutional_setup(&l, cudnn_fastest);
 #endif
     }
 #endif
     l.workspace_size = get_workspace_size(l);
     l.activation = activation;
 
-    fprintf(stderr, "Convolutional Layer: %d x %d x %d image, %d filters -> %d x %d x %d image\n", h,w,c,n, out_h, out_w, n);
+    //fprintf(stderr, "conv  %5d %2d x%2d /%2d  %4d x%4d x%4d   ->  %4d x%4d x%4d\n", n, size, size, stride, w, h, c, l.out_w, l.out_h, l.out_c);
+    l.bflops = (2.0 * l.n * l.size*l.size*l.c * l.out_h*l.out_w) / 1000000000.;
+    fprintf(stderr, "conv  %5d %2d x%2d /%2d  %4d x%4d x%4d   ->  %4d x%4d x%4d %5.3f BF\n", n, size, size, stride, w, h, c, l.out_w, l.out_h, l.out_c, l.bflops);
 
     return l;
 }
@@ -342,6 +445,8 @@
 
 void resize_convolutional_layer(convolutional_layer *l, int w, int h)
 {
+    int old_w = l->w;
+    int old_h = l->h;
     l->w = w;
     l->h = h;
     int out_w = convolutional_out_width(*l);
@@ -353,22 +458,55 @@
     l->outputs = l->out_h * l->out_w * l->out_c;
     l->inputs = l->w * l->h * l->c;
 
-    l->output = realloc(l->output,
-            l->batch*out_h * out_w * l->n*sizeof(float));
-    l->delta  = realloc(l->delta,
-            l->batch*out_h * out_w * l->n*sizeof(float));
+    l->output = realloc(l->output, l->batch*l->outputs*sizeof(float));
+    l->delta  = realloc(l->delta,  l->batch*l->outputs*sizeof(float));
+    if(l->batch_normalize){
+        l->x = realloc(l->x, l->batch*l->outputs*sizeof(float));
+        l->x_norm  = realloc(l->x_norm, l->batch*l->outputs*sizeof(float));
+    }
+
+    if (l->xnor) {
+        //l->binary_input = realloc(l->inputs*l->batch, sizeof(float));
+    }
 
 #ifdef GPU
-    cuda_free(l->delta_gpu);
-    cuda_free(l->output_gpu);
+    if (old_w < w || old_h < h) {
+        cuda_free(l->delta_gpu);
+        cuda_free(l->output_gpu);
 
-    l->delta_gpu =     cuda_make_array(l->delta, l->batch*out_h*out_w*l->n);
-    l->output_gpu =    cuda_make_array(l->output, l->batch*out_h*out_w*l->n);
+        l->delta_gpu = cuda_make_array(l->delta, l->batch*l->outputs);
+        l->output_gpu = cuda_make_array(l->output, l->batch*l->outputs);
+
+        if (l->batch_normalize) {
+            cuda_free(l->x_gpu);
+            cuda_free(l->x_norm_gpu);
+
+            l->x_gpu = cuda_make_array(l->output, l->batch*l->outputs);
+            l->x_norm_gpu = cuda_make_array(l->output, l->batch*l->outputs);
+        }
+
+        if (l->xnor) {
+            cuda_free(l->binary_input_gpu);
+            l->binary_input_gpu = cuda_make_array(0, l->inputs*l->batch);
+        }
+    }
 #ifdef CUDNN
-    cudnn_convolutional_setup(l);
+    cudnn_convolutional_setup(l, cudnn_fastest);
 #endif
 #endif
     l->workspace_size = get_workspace_size(*l);
+
+#ifdef CUDNN
+    // check for excessive memory consumption
+    size_t free_byte;
+    size_t total_byte;
+    check_error(cudaMemGetInfo(&free_byte, &total_byte));
+    if (l->workspace_size > free_byte || l->workspace_size >= total_byte / 2) {
+        printf(" used slow CUDNN algo without Workspace! Need memory: %zu, available: %zu\n", l->workspace_size, (free_byte < total_byte/2) ? free_byte : total_byte/2);
+        cudnn_convolutional_setup(l, cudnn_smallest);
+        l->workspace_size = get_workspace_size(*l);
+    }
+#endif
 }
 
 void add_bias(float *output, float *biases, int batch, int n, int size)
@@ -405,49 +543,148 @@
     }
 }
 
+void gemm_nn_custom(int M, int N, int K, float ALPHA,
+    float *A, int lda,
+    float *B, int ldb,
+    float *C, int ldc)
+{
+    int i, j, k;
+    for (i = 0; i < M; ++i) {
+        for (k = 0; k < K; ++k) {
+            register float A_PART = ALPHA*A[i*lda + k];
+            //printf("\n weight = %f \n", A_PART);
+            for (j = 0; j < N; ++j) {
+                C[i*ldc + j] += A_PART*B[k*ldb + j];
+            }
+        }
+    }
+}
+
+
+void get_mean_array(float *src, size_t size, size_t filters, float *mean_arr) {
+    size_t i, counter;
+    counter = 0;
+    for (i = 0; i < size; i += size / filters) {
+        mean_arr[counter++] = fabs(src[i]);
+    }
+}
+
+/*
+void float_to_bit(float *src, unsigned char *dst, size_t size) {
+
+    size_t dst_size = size / 8 + 1;
+    memset(dst, 0, dst_size);
+    size_t i, dst_i, dst_shift;
+    for (i = 0; i < size; ++i) {
+        if (src[i] > 0) set_bit(dst, i);
+    }
+}
+*/
+
+void bit_to_float(unsigned char *src, float *dst, size_t size, size_t filters, float *mean_arr) {
+    memset(dst, 0, size *sizeof(float));
+    size_t i,  src_i, src_shift;
+
+    for (i = 0; i < size; ++i) {
+        float mean_val = 1;
+        if(mean_arr != NULL) mean_val = fabs(mean_arr[i / (size / filters)]);
+        if(get_bit(src, i)) dst[i] = mean_val;
+        else dst[i] = -mean_val;
+    }
+}
+
+void binary_align_weights(convolutional_layer *l, size_t lda_align)
+{
+    int m = l->n;
+    int k = l->size*l->size*l->c;
+    size_t new_lda = k + (lda_align - k%lda_align); // (k / 8 + 1) * 8;
+
+    binarize_weights(l->weights, m, k, l->binary_weights);
+
+    size_t align_weights_size = new_lda * m;
+    size_t align_bit_weights_size = align_weights_size / 8;// +1;
+    float *align_weights = calloc(align_weights_size, sizeof(float));
+    l->align_bit_weights = calloc(align_bit_weights_size, sizeof(char));
+
+    size_t i, j;
+    // align A without transpose
+    for (i = 0; i < m; ++i) {
+        for (j = 0; j < k; ++j) {
+            align_weights[i*new_lda + j] = l->binary_weights[i*k + j];
+        }
+    }
+    float_to_bit(align_weights, l->align_bit_weights, align_weights_size);
+
+    l->mean_arr = calloc(l->n, sizeof(float));
+    get_mean_array(align_weights, align_weights_size, l->n, l->mean_arr);
+
+    free(align_weights);
+}
+
+
+size_t binary_transpose_align_input(int k, int n, float *b, char **t_bit_input, size_t ldb_align)
+{
+    size_t new_ldb = k + (ldb_align - k%ldb_align); // (k / 8 + 1) * 8;
+    size_t t_intput_size = new_ldb * n;
+    size_t t_bit_input_size = t_intput_size / 8;// +1;
+    float *t_input = calloc(t_intput_size, sizeof(float));
+    //char *
+    *t_bit_input = calloc(t_bit_input_size, sizeof(char));
+
+    //printf("\n bit_input_size = %d, n = %d, k = %d, ldb = %d \n", bit_input_size, n, k, n);
+    //printf("\n t_bit_input_size = %d, k = %d, n = %d, new_ldb = %d \n", t_bit_input_size, k, n, new_ldb);
+
+    //printf("\n align_weights_size = %d, k = %d, m = %d, lda = %d \n", align_weights_size, k, m, k);
+    //printf("\n align_bit_weights_size = %d, k = %d, m = %d, new_lda = %d \n", align_bit_weights_size, k, m, new_ldb);
+
+    // transpose and align B
+    int i, j;
+    //#pragma omp parallel for
+    /*
+    for (i = 0; i < n; ++i) {
+        for (j = 0; j < k; ++j) {
+            t_input[i*new_ldb + j] = b[j*n + i];
+        }
+    }*/
+    //transpose_block_SSE4x4(float *A, float *B, const int n, const int m, const int lda, const int ldb, const int block_size)
+
+    //transpose_block(b, t_input, k, n, n, new_ldb, 16);
+
+    int blocksize = 1;
+    int mod_k = 1, mod_n = 1;
+    for (i = 2; i < 256; i *= 2)
+        if (k % i == 0) mod_k = i;
+
+    for (i = 2; i < 256; i *= 2)
+        if (n % i == 0) mod_n = i;
+
+    blocksize = (mod_k < mod_n) ? mod_k : mod_n;
+
+    transpose_block_SSE4x4(b, t_input, k, n, n, new_ldb, blocksize);
+
+    //transpose_block(b, t_input, k, n, n, new_ldb, blocksize);
+    //printf("\n blocksize = %d \n", blocksize);
+
+    float_to_bit(t_input, *t_bit_input, t_intput_size);
+    free(t_input);
+
+    return t_intput_size;
+}
+
+
 void forward_convolutional_layer(convolutional_layer l, network_state state)
 {
     int out_h = convolutional_out_height(l);
     int out_w = convolutional_out_width(l);
     int i;
 
-
     fill_cpu(l.outputs*l.batch, 0, l.output, 1);
 
-    /*
-       if(l.binary){
-       binarize_weights(l.weights, l.n, l.c*l.size*l.size, l.binary_weights);
-       binarize_weights2(l.weights, l.n, l.c*l.size*l.size, l.cweights, l.scales);
-       swap_binary(&l);
-       }
-     */
-
-    /*
-       if(l.binary){
-       int m = l.n;
-       int k = l.size*l.size*l.c;
-       int n = out_h*out_w;
-
-       char  *a = l.cweights;
-       float *b = state.workspace;
-       float *c = l.output;
-
-       for(i = 0; i < l.batch; ++i){
-       im2col_cpu(state.input, l.c, l.h, l.w, 
-       l.size, l.stride, l.pad, b);
-       gemm_bin(m,n,k,1,a,k,b,n,c,n);
-       c += n*m;
-       state.input += l.c*l.h*l.w;
-       }
-       scale_bias(l.output, l.scales, l.batch, l.n, out_h*out_w);
-       add_bias(l.output, l.biases, l.batch, l.n, out_h*out_w);
-       activate_array(l.output, m*n*l.batch, l.activation);
-       return;
-       }
-     */
-
     if(l.xnor){
-        binarize_weights(l.weights, l.n, l.c*l.size*l.size, l.binary_weights);
+        if (!l.align_bit_weights) {
+            binarize_weights(l.weights, l.n, l.c*l.size*l.size, l.binary_weights);
+            //printf("\n binarize_weights l.align_bit_weights = %p \n", l.align_bit_weights);
+        }
         swap_binary(&l);
         binarize_cpu(state.input, l.c*l.h*l.w*l.batch, l.binary_input);
         state.input = l.binary_input;
@@ -457,22 +694,134 @@
     int k = l.size*l.size*l.c;
     int n = out_h*out_w;
 
-    if (l.xnor && l.c%32 == 0 && AI2) {
-        forward_xnor_layer(l, state);
-        printf("xnor\n");
-    } else {
+    float *a = l.weights;
+    float *b = state.workspace;
+    float *c = l.output;
 
-        float *a = l.weights;
-        float *b = state.workspace;
-        float *c = l.output;
+    static int u = 0;
+    u++;
 
-        for(i = 0; i < l.batch; ++i){
-            im2col_cpu(state.input, l.c, l.h, l.w, 
-                    l.size, l.stride, l.pad, b);
-            gemm(0,0,m,n,k,1,a,k,b,n,1,c,n);
-            c += n*m;
-            state.input += l.c*l.h*l.w;
+    for(i = 0; i < l.batch; ++i){
+        //im2col_cpu(state.input, l.c, l.h, l.w, l.size, l.stride, l.pad, b);
+        im2col_cpu_custom(state.input, l.c, l.h, l.w, l.size, l.stride, l.pad, b);
+
+        //gemm(0,0,m,n,k,1,a,k,b,n,1,c,n);
+        //gemm_nn_custom(m, n, k, 1, a, k, b, n, c, n);
+        if (l.xnor) {
+            size_t output_size = l.outputs;
+            //float *count_output = calloc(output_size, sizeof(float));
+            //size_t bit_output_size = output_size / 8 + 1;
+            //char *bit_output = calloc(bit_output_size, sizeof(char));
+
+            size_t intput_size = n * k; // (out_h*out_w) X (l.size*l.size*l.c) : after im2col()
+            size_t bit_input_size = intput_size / 8 + 1;
+            //char *bit_input = calloc(bit_input_size, sizeof(char));
+
+            size_t weights_size = k * m; //l.size*l.size*l.c*l.n;
+            size_t bit_weights_size = weights_size / 8 + 1;
+            //char *bit_weights = calloc(bit_weights_size, sizeof(char));
+            //float *mean_arr = calloc(l.n, sizeof(float));
+
+            // test: float->bit->float
+            //get_mean_array(l.weights, weights_size, l.n, mean_arr);
+            //float_to_bit(l.weights, bit_weights, weights_size);
+            //memset(l.weights, 0, weights_size * sizeof(float));
+            //bit_to_float(bit_weights, l.weights, weights_size, l.n, mean_arr); // just for test float->bit->float
+
+            //float_to_bit(b, bit_input, intput_size);
+            //memset(b, 0, intput_size * sizeof(float));
+            //bit_to_float(bit_input, b, intput_size, 1, NULL); // just for test float->bit->float
+
+            // transpose B from NxK to KxN (x-axis (ldb = l.size*l.size*l.c) - should be multiple of 8 bits)
+            {
+                /*
+                size_t ldb_align = 256;// 8;
+
+                size_t new_ldb = k + (ldb_align - k%ldb_align); // (k / 8 + 1) * 8;
+                size_t t_intput_size = new_ldb * n;
+                size_t t_bit_input_size = t_intput_size / 8;// +1;
+                float *t_input = calloc(t_intput_size, sizeof(float));
+                char *t_bit_input = calloc(t_bit_input_size, sizeof(char));
+
+                //printf("\n bit_input_size = %d, n = %d, k = %d, ldb = %d \n", bit_input_size, n, k, n);
+                //printf("\n t_bit_input_size = %d, k = %d, n = %d, new_ldb = %d \n", t_bit_input_size, k, n, new_ldb);
+
+
+                //printf("\n align_weights_size = %d, k = %d, m = %d, lda = %d \n", align_weights_size, k, m, k);
+                //printf("\n align_bit_weights_size = %d, k = %d, m = %d, new_lda = %d \n", align_bit_weights_size, k, m, new_ldb);
+
+
+                // transpose and align B
+                int i, j;
+                for (i = 0; i < n; ++i) {
+                    for (j = 0; j < k; ++j) {
+                        t_input[i*new_ldb + j] = b[j*n + i];
+                    }
+                }
+                float_to_bit(t_input, t_bit_input, t_intput_size);
+
+
+
+                if (!l.align_bit_weights)
+                {
+                    size_t align_weights_size = new_ldb * m;
+                    size_t align_bit_weights_size = align_weights_size / 8;// +1;
+                    float *align_weights = calloc(align_weights_size, sizeof(float));
+                    l.align_bit_weights = calloc(align_bit_weights_size, sizeof(char));
+
+                    // align A without transpose
+                    for (i = 0; i < m; ++i) {
+                        for (j = 0; j < k; ++j) {
+                            align_weights[i*new_ldb + j] = a[i*k + j];
+                        }
+                    }
+                    float_to_bit(align_weights, l.align_bit_weights, align_weights_size);
+
+                    l.mean_arr = calloc(l.n, sizeof(float));
+                    get_mean_array(align_weights, align_weights_size, l.n, l.mean_arr);
+
+                    free(align_weights);
+                }
+                */
+                size_t ldb_align = 256; // 256 bit for AVX2
+                size_t new_ldb = k + (ldb_align - k%ldb_align);
+                char *t_bit_input = NULL;
+                size_t t_intput_size = binary_transpose_align_input(k, n, b, &t_bit_input, ldb_align);
+
+                gemm_nn_custom_bin_mean_transposed(m, n, k, 1, l.align_bit_weights, new_ldb, t_bit_input, new_ldb, c, n, l.mean_arr);
+
+                //gemm_nn_custom_bin_mean_transposed(m, n, k, 1, bit_weights, k, t_bit_input, new_ldb, c, n, mean_arr);
+
+                //free(t_input);
+                free(t_bit_input);
+
+                //free(align_bit_weights);
+            }
+
+            // for bit_input: (k * n)
+            //if (u == 8) gemm_nn_custom_bin_mean(m, n, k, 1, bit_weights, k, bit_input, n, c, n, mean_arr);  // last xnor layer
+            //else gemm_nn_custom_bin_mean(m, n, k, 1, bit_weights, k, bit_input, n, c, n, NULL);
+
+            //gemm_nn_custom_bin_mean(m, n, k, 1, bit_weights, k, bit_input, n, c, n, mean_arr);
+
+            //printf("\n u = %d \n", u);
+
+            //gemm_nn_custom(m, n, k, 1, a, k, b, n, c, n);
+
+            //int j;
+            //if (u != 8) for (j = 0; j < l.n; ++j) l.biases[j] = l.biases[j] / (mean_arr[j]*2);
+
+            //free(count_output);
+            //free(bit_input);
+            //free(bit_weights);
+            //free(mean_arr);
         }
+        else {
+            gemm(0, 0, m, n, k, 1, a, k, b, n, 1, c, n);
+            // bit-count to float
+        }
+        c += n*m;
+        state.input += l.c*l.h*l.w;
     }
 
     if(l.batch_normalize){
@@ -480,7 +829,9 @@
     }
     add_bias(l.output, l.biases, l.batch, l.n, out_h*out_w);
 
-    activate_array(l.output, m*n*l.batch, l.activation);
+    //activate_array(l.output, m*n*l.batch, l.activation);
+    activate_array_cpu_custom(l.output, m*n*l.batch, l.activation);
+
     if(l.binary || l.xnor) swap_binary(&l);
 }
 
@@ -495,6 +846,10 @@
     gradient_array(l.output, m*k*l.batch, l.activation, l.delta);
     backward_bias(l.bias_updates, l.delta, l.batch, l.n, k);
 
+    if(l.batch_normalize){
+        backward_batchnorm_layer(l, state);
+    }
+
     for(i = 0; i < l.batch; ++i){
         float *a = l.delta + i*m*k;
         float *b = state.workspace;
@@ -502,7 +857,7 @@
 
         float *im = state.input+i*l.c*l.h*l.w;
 
-        im2col_cpu(im, l.c, l.h, l.w, 
+        im2col_cpu(im, l.c, l.h, l.w,
                 l.size, l.stride, l.pad, b);
         gemm(0,1,m,n,k,1,a,k,b,k,1,c,n);
 

--
Gitblit v1.10.0