AlexeyAB
2017-10-21 ae74d0ef31485f84e1856b4733135d2753dbb033
src/gemm.c
@@ -5,6 +5,28 @@
#include <stdio.h>
#include <math.h>
void gemm_bin(int M, int N, int K, float ALPHA,
        char  *A, int lda,
        float *B, int ldb,
        float *C, int ldc)
{
    int i,j,k;
    for(i = 0; i < M; ++i){
        for(k = 0; k < K; ++k){
            char A_PART = A[i*lda+k];
            if(A_PART){
                for(j = 0; j < N; ++j){
                    C[i*ldc+j] += B[k*ldb+j];
                }
            } else {
                for(j = 0; j < N; ++j){
                    C[i*ldc+j] -= B[k*ldb+j];
                }
            }
        }
    }
}
float *random_matrix(int rows, int cols)
{
    int i;
@@ -129,14 +151,19 @@
            C[i*ldc + j] *= BETA;
        }
    }
    if(!TA && !TB)
        gemm_nn(M, N, K, ALPHA,A,lda, B, ldb,C,ldc);
    else if(TA && !TB)
        gemm_tn(M, N, K, ALPHA,A,lda, B, ldb,C,ldc);
    else if(!TA && TB)
        gemm_nt(M, N, K, ALPHA,A,lda, B, ldb,C,ldc);
    else
        gemm_tt(M, N, K, ALPHA,A,lda, B, ldb,C,ldc);
   int t;
   #pragma omp parallel for
   for (t = 0; t < M; ++t) {
      if (!TA && !TB)
         gemm_nn(1, N, K, ALPHA, A + t*lda, lda, B, ldb, C + t*ldc, ldc);
      else if (TA && !TB)
         gemm_tn(1, N, K, ALPHA, A + t, lda, B, ldb, C + t*ldc, ldc);
      else if (!TA && TB)
         gemm_nt(1, N, K, ALPHA, A + t*lda, lda, B, ldb, C + t*ldc, ldc);
      else
         gemm_tt(1, N, K, ALPHA, A + t, lda, B, ldb, C + t*ldc, ldc);
   }
}
#ifdef GPU
@@ -151,7 +178,7 @@
{
    cublasHandle_t handle = blas_handle();
    cudaError_t status = cublasSgemm(handle, (TB ? CUBLAS_OP_T : CUBLAS_OP_N), 
                        (TA ? CUBLAS_OP_T : CUBLAS_OP_N), N, M, K, &ALPHA, B_gpu, ldb, A_gpu, lda, &BETA, C_gpu, ldc);
            (TA ? CUBLAS_OP_T : CUBLAS_OP_N), N, M, K, &ALPHA, B_gpu, ldb, A_gpu, lda, &BETA, C_gpu, ldc);
    check_error(status);
}
@@ -276,6 +303,7 @@
int test_gpu_blas()
{
    /*
       test_gpu_accuracy(0,0,10,576,75); 
       test_gpu_accuracy(0,0,17,10,10); 
@@ -288,19 +316,28 @@
       test_gpu_accuracy(0,1,1000,10,100); 
       test_gpu_accuracy(1,1,1000,10,100); 
    test_gpu_accuracy(0,0,10,10,10);
       test_gpu_accuracy(0,0,10,10,10);
    time_ongpu(0,0,64,2916,363);
    time_ongpu(0,0,64,2916,363);
    time_ongpu(0,0,64,2916,363);
    time_ongpu(0,0,192,729,1600);
    time_ongpu(0,0,384,196,1728);
    time_ongpu(0,0,256,196,3456);
    time_ongpu(0,0,256,196,2304);
    time_ongpu(0,0,128,4096,12544);
    time_ongpu(0,0,128,4096,4096);
       time_ongpu(0,0,64,2916,363);
       time_ongpu(0,0,64,2916,363);
       time_ongpu(0,0,64,2916,363);
       time_ongpu(0,0,192,729,1600);
       time_ongpu(0,0,384,196,1728);
       time_ongpu(0,0,256,196,3456);
       time_ongpu(0,0,256,196,2304);
       time_ongpu(0,0,128,4096,12544);
       time_ongpu(0,0,128,4096,4096);
     */
    time_ongpu(0,0,64,75,12544);
    time_ongpu(0,0,64,75,12544);
    time_ongpu(0,0,64,75,12544);
    time_ongpu(0,0,64,576,12544);
    time_ongpu(0,0,256,2304,784);
    time_ongpu(1,1,2304,256,784);
    time_ongpu(0,0,512,4608,196);
    time_ongpu(1,1,4608,512,196);
return 0;
    return 0;
}
#endif