From edbccdfcaf46f11e631afe98796f3e6e170da5d0 Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Sun, 26 Oct 2014 05:04:34 +0000
Subject: [PATCH] Maybe something changed?

---
 src/gemm.c |  196 ++++++++++++++++++++++++++++++++++++++++++++----
 1 files changed, 177 insertions(+), 19 deletions(-)

diff --git a/src/gemm.c b/src/gemm.c
index 1a7bcdd..63c2950 100644
--- a/src/gemm.c
+++ b/src/gemm.c
@@ -1,4 +1,5 @@
 #include "mini_blas.h"
+#include "utils.h"
 
 void gemm(int TA, int TB, int M, int N, int K, float ALPHA, 
         float *A, int lda, 
@@ -6,11 +7,7 @@
         float BETA,
         float *C, int ldc)
 {
-#ifdef GPU
-    gemm_gpu( TA,  TB,  M, N, K, ALPHA,A,lda, B, ldb,BETA,C,ldc);
-#else
     gemm_cpu( TA,  TB,  M, N, K, ALPHA,A,lda, B, ldb,BETA,C,ldc);
-#endif
 }
 
 void gemm_nn(int M, int N, int K, float ALPHA, 
@@ -39,7 +36,7 @@
         for(j = 0; j < N; ++j){
             register float sum = 0;
             for(k = 0; k < K; ++k){
-                sum += ALPHA*A[i*lda+k]*B[k+j*ldb];
+                sum += ALPHA*A[i*lda+k]*B[j*ldb + k];
             }
             C[i*ldc+j] += sum;
         }
@@ -61,6 +58,7 @@
         }
     }
 }
+
 void gemm_tt(int M, int N, int K, float ALPHA, 
         float *A, int lda, 
         float *B, int ldb,
@@ -69,9 +67,11 @@
     int i,j,k;
     for(i = 0; i < M; ++i){
         for(j = 0; j < N; ++j){
+            register float sum = 0;
             for(k = 0; k < K; ++k){
-                C[i*ldc+j] += ALPHA*A[i+k*lda]*B[k+j*ldb];
+                sum += ALPHA*A[i+k*lda]*B[k+j*ldb];
             }
+            C[i*ldc+j] += sum;
         }
     }
 }
@@ -83,6 +83,7 @@
         float BETA,
         float *C, int ldc)
 {
+    //printf("cpu: %d %d %d %d %d %f %d %d %f %d\n",TA, TB, M, N, K, ALPHA, lda, ldb, BETA, ldc);
     int i, j;
     for(i = 0; i < M; ++i){
         for(j = 0; j < N; ++j){
@@ -103,11 +104,16 @@
 
 #include "opencl.h"
 #include <math.h>
+#include <clBLAS.h>
 
 #define STR_HELPER(x) #x
 #define STR(x) STR_HELPER(x)
 
-#define BLOCK 8
+#ifdef __APPLE__
+#define BLOCK 1
+#else
+#define BLOCK 16
+#endif
 
 cl_kernel get_gemm_kernel()
 {
@@ -120,12 +126,110 @@
     return gemm_kernel;
 }
 
+cl_kernel get_gemm_nt_kernel()
+{
+    static int init = 0;
+    static cl_kernel gemm_kernel;
+    if(!init){
+        gemm_kernel = get_kernel("src/gemm_new.cl", "gemm_nt", "-D BLOCK=" STR(BLOCK) );
+        init = 1;
+    }
+    return gemm_kernel;
+}
+
+cl_kernel get_gemm_tn_kernel()
+{
+    static int init = 0;
+    static cl_kernel gemm_kernel;
+    if(!init){
+        gemm_kernel = get_kernel("src/gemm_new.cl", "gemm_tn", "-D BLOCK=" STR(BLOCK) );
+        init = 1;
+    }
+    return gemm_kernel;
+}
+
+cl_kernel get_gemm_nn_kernel()
+{
+    static int init = 0;
+    static cl_kernel gemm_kernel;
+    if(!init){
+        gemm_kernel = get_kernel("src/gemm_new.cl", "gemm_nn", "-D BLOCK=" STR(BLOCK) );
+        init = 1;
+    }
+    return gemm_kernel;
+}
+
+void gemm_ongpu_new(int TA, int TB, int M, int N, int K, float ALPHA, 
+        cl_mem A_gpu, int lda, 
+        cl_mem B_gpu, int ldb,
+        float BETA,
+        cl_mem C_gpu, int ldc);
+void gemm_ongpu_old(int TA, int TB, int M, int N, int K, float ALPHA, 
+        cl_mem A_gpu, int lda, 
+        cl_mem B_gpu, int ldb,
+        float BETA,
+        cl_mem C_gpu, int ldc);
+
 void gemm_ongpu(int TA, int TB, int M, int N, int K, float ALPHA, 
         cl_mem A_gpu, int lda, 
         cl_mem B_gpu, int ldb,
         float BETA,
         cl_mem C_gpu, int ldc)
 {
+/*
+    cl_setup();
+    cl_command_queue queue = cl.queue;
+    cl_event event;
+    cl.error = clblasSgemm(clblasRowMajor, TA?clblasTrans:clblasNoTrans, TB?clblasTrans:clblasNoTrans,M, N, K,ALPHA, A_gpu, 0, lda,B_gpu, 0, ldb,BETA, C_gpu, 0, ldc,1, &queue, 0, NULL, &event);
+
+*/
+    gemm_ongpu_new(TA, TB, M, N, K, ALPHA, A_gpu, lda, B_gpu, ldb, BETA, C_gpu, ldc);
+}
+
+void gemm_ongpu_new(int TA, int TB, int M, int N, int K, float ALPHA, 
+        cl_mem A_gpu, int lda, 
+        cl_mem B_gpu, int ldb,
+        float BETA,
+        cl_mem C_gpu, int ldc)
+{
+    //printf("gpu: %d %d %d %d %d\n",TA, TB, M, N, K);
+    cl_setup();
+    cl_kernel      gemm_kernel = get_gemm_kernel();
+    if(!TA && !TB) gemm_kernel = get_gemm_nn_kernel();
+    if(!TA && TB)  gemm_kernel = get_gemm_nt_kernel();
+    if(TA && !TB)  gemm_kernel = get_gemm_tn_kernel();
+    cl_command_queue queue = cl.queue;
+
+    cl_uint i = 0;
+    cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(TA), (void*) &TA);
+    cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(TB), (void*) &TB);
+    cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(M), (void*) &M);
+    cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(N), (void*) &N);
+    cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(K), (void*) &K);
+    cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(ALPHA), (void*) &ALPHA);
+    cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(A_gpu), (void*) &A_gpu);
+    cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(lda), (void*) &lda);
+    cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(B_gpu), (void*) &B_gpu);
+    cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(ldb), (void*) &ldb);
+    cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(BETA), (void*) &BETA);
+    cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(C_gpu), (void*) &C_gpu);
+    cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(ldc), (void*) &ldc);
+    check_error(cl);
+
+    const size_t global_size[] = {ceil((float)N/BLOCK)*BLOCK, ceil((float)M/BLOCK)*BLOCK};
+    const size_t local_size[] = {BLOCK, BLOCK};
+
+    clEnqueueNDRangeKernel(queue, gemm_kernel, 2, 0, global_size, local_size, 0, 0, 0);
+    check_error(cl);
+}
+
+void gemm_ongpu_old(int TA, int TB, int M, int N, int K, float ALPHA, 
+        cl_mem A_gpu, int lda, 
+        cl_mem B_gpu, int ldb,
+        float BETA,
+        cl_mem C_gpu, int ldc)
+{
+    //printf("gpu: %d %d %d %d %d\n",TA, TB, M, N, K);
     cl_setup();
     cl_kernel gemm_kernel = get_gemm_kernel();
     cl_command_queue queue = cl.queue;
@@ -146,7 +250,7 @@
     cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(ldc), (void*) &ldc);
     check_error(cl);
 
-    const size_t global_size[] = {ceil((float)M/BLOCK)*BLOCK, ceil((float)N/BLOCK)*BLOCK};
+    const size_t global_size[] = {ceil((float)N/BLOCK)*BLOCK, ceil((float)M/BLOCK)*BLOCK};
     const size_t local_size[] = {BLOCK, BLOCK};
 
     clEnqueueNDRangeKernel(queue, gemm_kernel, 2, 0, global_size, local_size, 0, 0, 0);
@@ -211,11 +315,44 @@
     float *c = random_matrix(m,n);
     int i;
     clock_t start = clock(), end;
-    for(i = 0; i<1000; ++i){
+    for(i = 0; i<32; ++i){
         gemm_gpu(TA,TB,m,n,k,1,a,lda,b,ldb,1,c,n);
     }
     end = clock();
-    printf("Matrix Multiplication %dx%d * %dx%d, TA=%d, TB=%d: %lf ms\n",m,k,k,n, TA, TB, (float)(end-start)/CLOCKS_PER_SEC);
+    printf("Matrix Multiplication %dx%d * %dx%d, TA=%d, TB=%d: %lf s\n",m,k,k,n, TA, TB, (float)(end-start)/CLOCKS_PER_SEC);
+    free(a);
+    free(b);
+    free(c);
+}
+
+void time_ongpu(int TA, int TB, int m, int k, int n)
+{
+    int iter = 128;
+    float *a = random_matrix(m,k);
+    float *b = random_matrix(k,n);
+
+    int lda = (!TA)?k:m;
+    int ldb = (!TB)?n:k;
+
+    float *c = random_matrix(m,n);
+
+    cl_mem a_cl = cl_make_array(a, m*k);
+    cl_mem b_cl = cl_make_array(b, k*n);
+    cl_mem c_cl = cl_make_array(c, m*n);
+
+    int i;
+    clock_t start = clock(), end;
+    for(i = 0; i<iter; ++i){
+        gemm_ongpu(TA,TB,m,n,k,1,a_cl,lda,b_cl,ldb,1,c_cl,n);
+    }
+    double flop = m*n*(2.*k+3.)*iter;
+    double gflop = flop/pow(10., 9);
+    end = clock();
+    double seconds = sec(end-start);
+    printf("Matrix Multiplication %dx%d * %dx%d, TA=%d, TB=%d: %lf s, %lf GFLOPS\n",m,k,k,n, TA, TB, seconds, gflop/seconds);
+    clReleaseMemObject(a_cl);
+    clReleaseMemObject(b_cl);
+    clReleaseMemObject(c_cl);
     free(a);
     free(b);
     free(c);
@@ -248,14 +385,18 @@
         //printf("%f %f\n", c[i], c_gpu[i]);
         sse += pow(c[i]-c_gpu[i], 2);
     }
-    printf("Matrix Multiplication %dx%d * %dx%d, TA=%d, TB=%d: %g MSE\n",m,k,k,n, TA, TB, sse/(m*n));
+    printf("Matrix Multiplication %dx%d * %dx%d, TA=%d, TB=%d: %g SSE\n",m,k,k,n, TA, TB, sse/(m*n));
     free(a);
     free(b);
     free(c);
+    free(c_gpu);
 }
 
 void test_gpu_blas()
 {
+    /*
+    test_gpu_accuracy(0,0,10,576,75); 
+
     test_gpu_accuracy(0,0,17,10,10); 
     test_gpu_accuracy(1,0,17,10,10); 
     test_gpu_accuracy(0,1,17,10,10); 
@@ -265,18 +406,35 @@
     test_gpu_accuracy(1,0,1000,10,100); 
     test_gpu_accuracy(0,1,1000,10,100); 
     test_gpu_accuracy(1,1,1000,10,100); 
+    */
+    test_gpu_accuracy(0,0,131,4093,1199); 
+    test_gpu_accuracy(0,1,131,4093,1199); 
+    test_gpu_accuracy(1,0,131,4093,1199); 
+    test_gpu_accuracy(1,1,131,4093,1199); 
 
-    time_gpu_random_matrix(0,0,1000,1000,100); 
-    time_random_matrix(0,0,1000,1000,100); 
+    time_ongpu(0,0,1024,1024,1024); 
+    time_ongpu(0,1,1024,1024,1024); 
+    time_ongpu(1,0,1024,1024,1024); 
+    time_ongpu(1,1,1024,1024,1024); 
 
-    time_gpu_random_matrix(0,1,1000,1000,100); 
-    time_random_matrix(0,1,1000,1000,100); 
+    time_ongpu(0,0,128,4096,1200); 
+    time_ongpu(0,1,128,4096,1200); 
+    time_ongpu(1,0,128,4096,1200); 
+    time_ongpu(1,1,128,4096,1200); 
 
-    time_gpu_random_matrix(1,0,1000,1000,100); 
-    time_random_matrix(1,0,1000,1000,100); 
+    /*
+       time_gpu_random_matrix(0,0,1000,1000,100); 
+       time_random_matrix(0,0,1000,1000,100); 
 
-    time_gpu_random_matrix(1,1,1000,1000,100); 
-    time_random_matrix(1,1,1000,1000,100); 
+       time_gpu_random_matrix(0,1,1000,1000,100); 
+       time_random_matrix(0,1,1000,1000,100); 
+
+       time_gpu_random_matrix(1,0,1000,1000,100); 
+       time_random_matrix(1,0,1000,1000,100); 
+
+       time_gpu_random_matrix(1,1,1000,1000,100); 
+       time_random_matrix(1,1,1000,1000,100); 
+     */
 
 }
 #endif

--
Gitblit v1.10.0