AlexeyAB
2018-03-28 6d56c38e8bcb9041335b03f27c192c24dfaedb1c
src/gemm.c
@@ -151,14 +151,19 @@
            C[i*ldc + j] *= BETA;
        }
    }
    if(!TA && !TB)
        gemm_nn(M, N, K, ALPHA,A,lda, B, ldb,C,ldc);
    else if(TA && !TB)
        gemm_tn(M, N, K, ALPHA,A,lda, B, ldb,C,ldc);
    else if(!TA && TB)
        gemm_nt(M, N, K, ALPHA,A,lda, B, ldb,C,ldc);
    else
        gemm_tt(M, N, K, ALPHA,A,lda, B, ldb,C,ldc);
   int t;
   #pragma omp parallel for
   for (t = 0; t < M; ++t) {
      if (!TA && !TB)
         gemm_nn(1, N, K, ALPHA, A + t*lda, lda, B, ldb, C + t*ldc, ldc);
      else if (TA && !TB)
         gemm_tn(1, N, K, ALPHA, A + t, lda, B, ldb, C + t*ldc, ldc);
      else if (!TA && TB)
         gemm_nt(1, N, K, ALPHA, A + t*lda, lda, B, ldb, C + t*ldc, ldc);
      else
         gemm_tt(1, N, K, ALPHA, A + t, lda, B, ldb, C + t*ldc, ldc);
   }
}
#ifdef GPU
@@ -172,6 +177,7 @@
        float *C_gpu, int ldc)
{
    cublasHandle_t handle = blas_handle();
   cudaError_t stream_status = cublasSetStream(handle, get_cuda_stream());
    cudaError_t status = cublasSgemm(handle, (TB ? CUBLAS_OP_T : CUBLAS_OP_N), 
            (TA ? CUBLAS_OP_T : CUBLAS_OP_N), N, M, K, &ALPHA, B_gpu, ldb, A_gpu, lda, &BETA, C_gpu, ldc);
    check_error(status);