From f92b20580a21663c5db9eb8608f8cabd7adbeb10 Mon Sep 17 00:00:00 2001
From: AlexeyAB <alexeyab84@gmail.com>
Date: Mon, 13 Aug 2018 22:51:31 +0000
Subject: [PATCH] Some fixes for AVX support on CPU

---
 src/convolutional_layer.c |    7 +++++--
 src/gemm.c                |   26 +++++++++++++++-----------
 2 files changed, 20 insertions(+), 13 deletions(-)

diff --git a/src/convolutional_layer.c b/src/convolutional_layer.c
index 927eb99..e7c8e3f 100644
--- a/src/convolutional_layer.c
+++ b/src/convolutional_layer.c
@@ -621,7 +621,7 @@
     free(align_weights);
 }
 
-
+// further optimizations: im2col_bin() for XNOR, and then transpose_aling_bin()
 size_t binary_transpose_align_input(int k, int n, float *b, char **t_bit_input, size_t ldb_align)
 {
     size_t new_ldb = k + (ldb_align - k%ldb_align); // (k / 8 + 1) * 8;
@@ -690,7 +690,8 @@
         //}
         //if (l.xnor && l.size == 3 && l.stride == 1 && l.pad == 1) {}
         //else
-            im2col_cpu_custom(state.input, l.c, l.h, l.w, l.size, l.stride, l.pad, b);
+        // further optimizations: im2col_bin() for XNOR, and then transpose_aling_bin()
+        im2col_cpu_custom(state.input, l.c, l.h, l.w, l.size, l.stride, l.pad, b);
 
 
         //gemm(0,0,m,n,k,1,a,k,b,n,1,c,n);
@@ -793,6 +794,8 @@
                     //char *t_bit_input = calloc(new_ldb * n, sizeof(char));    // for im2col_cpu_custom_transpose() only
                     //float_to_bit(t_input, t_bit_input, new_ldb * n);    // for im2col_cpu_custom_transpose() only
 
+                    // 5x times faster than gemm()-float32
+                    // further optimizations: accelerate maxpool-layer with OpenMP/AVX
                     gemm_nn_custom_bin_mean_transposed(m, n, k, 1, l.align_bit_weights, new_ldb, t_bit_input, new_ldb, c, n, l.mean_arr);
 
                     //gemm_nn_custom_bin_mean_transposed(m, n, k, 1, bit_weights, k, t_bit_input, new_ldb, c, n, mean_arr);
diff --git a/src/gemm.c b/src/gemm.c
index b909814..d233e9c 100644
--- a/src/gemm.c
+++ b/src/gemm.c
@@ -674,6 +674,8 @@
         + _mm256_extract_epi64(val, 3);
 }
 
+// 5x times faster than gemm()-float32
+// further optimizations: do mean-mult only for the last layer
 void gemm_nn_custom_bin_mean_transposed(int M, int N, int K, float ALPHA_UNUSED,
     unsigned char *A, int lda,
     unsigned char *B, int ldb,
@@ -873,7 +875,7 @@
     int channels_col = channels * ksize * ksize;
 
     // optimized version
-    if (height_col == height && width_col == width && stride == 1 && pad == 1)
+    if (height_col == height && width_col == width && stride == 1 && pad == 1 && is_fma_avx())
     {
         #pragma omp parallel for
         for (c = 0; c < channels_col; ++c) {
@@ -954,24 +956,26 @@
 
 void activate_array_cpu_custom(float *x, const int n, const ACTIVATION a)
 {
-    int i;
+    int i = 0;
     if (a == LINEAR)
     {}
     else if (a == LEAKY)
     {
-        __m256i all256_sing1 = _mm256_set_epi32(0x80000000, 0x80000000, 0x80000000, 0x80000000, 0x80000000, 0x80000000, 0x80000000, 0x80000000);
-        __m256 all256_01 = _mm256_set1_ps(0.1F);
+        if (is_fma_avx()) {
+            __m256i all256_sing1 = _mm256_set_epi32(0x80000000, 0x80000000, 0x80000000, 0x80000000, 0x80000000, 0x80000000, 0x80000000, 0x80000000);
+            __m256 all256_01 = _mm256_set1_ps(0.1F);
 
-        for (i = 0; i < n-8; i += 8) {
-            //x[i] = (x[i]>0) ? x[i] : .1*x[i];
+            for (i = 0; i < n - 8; i += 8) {
+                //x[i] = (x[i]>0) ? x[i] : .1*x[i];
 
-            __m256 src256 = _mm256_loadu_ps(&x[i]);
-            __m256 mult256 = _mm256_mul_ps((src256), all256_01); // mult * 0.1
+                __m256 src256 = _mm256_loadu_ps(&x[i]);
+                __m256 mult256 = _mm256_mul_ps((src256), all256_01); // mult * 0.1
 
-            __m256i sign256 = _mm256_and_si256(_mm256_castps_si256(src256), all256_sing1); // check sign in 8 x 32-bit floats
+                __m256i sign256 = _mm256_and_si256(_mm256_castps_si256(src256), all256_sing1); // check sign in 8 x 32-bit floats
 
-            __m256 result256 = _mm256_blendv_ps(src256, mult256, _mm256_castsi256_ps(sign256)); // (sign>0) ? src : mult;
-            _mm256_storeu_ps(&x[i], result256);
+                __m256 result256 = _mm256_blendv_ps(src256, mult256, _mm256_castsi256_ps(sign256)); // (sign>0) ? src : mult;
+                _mm256_storeu_ps(&x[i], result256);
+            }
         }
 
         for (; i < n; ++i) {

--
Gitblit v1.10.0