From f92b20580a21663c5db9eb8608f8cabd7adbeb10 Mon Sep 17 00:00:00 2001
From: AlexeyAB <alexeyab84@gmail.com>
Date: Mon, 13 Aug 2018 22:51:31 +0000
Subject: [PATCH] Some fixes for AVX support on CPU
---
src/gemm.c | 26 +++++++++++++++-----------
1 files changed, 15 insertions(+), 11 deletions(-)
diff --git a/src/gemm.c b/src/gemm.c
index b909814..d233e9c 100644
--- a/src/gemm.c
+++ b/src/gemm.c
@@ -674,6 +674,8 @@
+ _mm256_extract_epi64(val, 3);
}
+// 5x times faster than gemm()-float32
+// further optimizations: do mean-mult only for the last layer
void gemm_nn_custom_bin_mean_transposed(int M, int N, int K, float ALPHA_UNUSED,
unsigned char *A, int lda,
unsigned char *B, int ldb,
@@ -873,7 +875,7 @@
int channels_col = channels * ksize * ksize;
// optimized version
- if (height_col == height && width_col == width && stride == 1 && pad == 1)
+ if (height_col == height && width_col == width && stride == 1 && pad == 1 && is_fma_avx())
{
#pragma omp parallel for
for (c = 0; c < channels_col; ++c) {
@@ -954,24 +956,26 @@
void activate_array_cpu_custom(float *x, const int n, const ACTIVATION a)
{
- int i;
+ int i = 0;
if (a == LINEAR)
{}
else if (a == LEAKY)
{
- __m256i all256_sing1 = _mm256_set_epi32(0x80000000, 0x80000000, 0x80000000, 0x80000000, 0x80000000, 0x80000000, 0x80000000, 0x80000000);
- __m256 all256_01 = _mm256_set1_ps(0.1F);
+ if (is_fma_avx()) {
+ __m256i all256_sing1 = _mm256_set_epi32(0x80000000, 0x80000000, 0x80000000, 0x80000000, 0x80000000, 0x80000000, 0x80000000, 0x80000000);
+ __m256 all256_01 = _mm256_set1_ps(0.1F);
- for (i = 0; i < n-8; i += 8) {
- //x[i] = (x[i]>0) ? x[i] : .1*x[i];
+ for (i = 0; i < n - 8; i += 8) {
+ //x[i] = (x[i]>0) ? x[i] : .1*x[i];
- __m256 src256 = _mm256_loadu_ps(&x[i]);
- __m256 mult256 = _mm256_mul_ps((src256), all256_01); // mult * 0.1
+ __m256 src256 = _mm256_loadu_ps(&x[i]);
+ __m256 mult256 = _mm256_mul_ps((src256), all256_01); // mult * 0.1
- __m256i sign256 = _mm256_and_si256(_mm256_castps_si256(src256), all256_sing1); // check sign in 8 x 32-bit floats
+ __m256i sign256 = _mm256_and_si256(_mm256_castps_si256(src256), all256_sing1); // check sign in 8 x 32-bit floats
- __m256 result256 = _mm256_blendv_ps(src256, mult256, _mm256_castsi256_ps(sign256)); // (sign>0) ? src : mult;
- _mm256_storeu_ps(&x[i], result256);
+ __m256 result256 = _mm256_blendv_ps(src256, mult256, _mm256_castsi256_ps(sign256)); // (sign>0) ? src : mult;
+ _mm256_storeu_ps(&x[i], result256);
+ }
}
for (; i < n; ++i) {
--
Gitblit v1.10.0