From e205c1e7aeb47e3dffd35d1b5ce7841d24b9aff4 Mon Sep 17 00:00:00 2001
From: AlexeyAB <alexeyab84@gmail.com>
Date: Thu, 04 Jan 2018 23:23:10 +0000
Subject: [PATCH] Tracking fixed

---
 src/gemm.c |   22 ++++++++++++++--------
 1 files changed, 14 insertions(+), 8 deletions(-)

diff --git a/src/gemm.c b/src/gemm.c
index 3003be0..c3154ec 100644
--- a/src/gemm.c
+++ b/src/gemm.c
@@ -151,14 +151,19 @@
             C[i*ldc + j] *= BETA;
         }
     }
-    if(!TA && !TB)
-        gemm_nn(M, N, K, ALPHA,A,lda, B, ldb,C,ldc);
-    else if(TA && !TB)
-        gemm_tn(M, N, K, ALPHA,A,lda, B, ldb,C,ldc);
-    else if(!TA && TB)
-        gemm_nt(M, N, K, ALPHA,A,lda, B, ldb,C,ldc);
-    else
-        gemm_tt(M, N, K, ALPHA,A,lda, B, ldb,C,ldc);
+
+	int t;
+	#pragma omp parallel for
+	for (t = 0; t < M; ++t) {
+		if (!TA && !TB)
+			gemm_nn(1, N, K, ALPHA, A + t*lda, lda, B, ldb, C + t*ldc, ldc);
+		else if (TA && !TB)
+			gemm_tn(1, N, K, ALPHA, A + t, lda, B, ldb, C + t*ldc, ldc);
+		else if (!TA && TB)
+			gemm_nt(1, N, K, ALPHA, A + t*lda, lda, B, ldb, C + t*ldc, ldc);
+		else
+			gemm_tt(1, N, K, ALPHA, A + t, lda, B, ldb, C + t*ldc, ldc);
+	}
 }
 
 #ifdef GPU
@@ -172,6 +177,7 @@
         float *C_gpu, int ldc)
 {
     cublasHandle_t handle = blas_handle();
+	cudaError_t stream_status = cublasSetStream(handle, get_cuda_stream());
     cudaError_t status = cublasSgemm(handle, (TB ? CUBLAS_OP_T : CUBLAS_OP_N), 
             (TA ? CUBLAS_OP_T : CUBLAS_OP_N), N, M, K, &ALPHA, B_gpu, ldb, A_gpu, lda, &BETA, C_gpu, ldc);
     check_error(status);

--
Gitblit v1.10.0