From 6e1d5b45de988bb795c4c505f22f2170a78b7746 Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Tue, 20 Jan 2015 06:06:18 +0000
Subject: [PATCH] fast sort of working

---
 src/gemm.c |  261 +++++++++++++++++++++++++++++++++++----------------
 1 files changed, 177 insertions(+), 84 deletions(-)

diff --git a/src/gemm.c b/src/gemm.c
index 63c2950..8394991 100644
--- a/src/gemm.c
+++ b/src/gemm.c
@@ -104,7 +104,10 @@
 
 #include "opencl.h"
 #include <math.h>
+
+#ifdef CLBLAS
 #include <clBLAS.h>
+#endif
 
 #define STR_HELPER(x) #x
 #define STR(x) STR_HELPER(x)
@@ -131,7 +134,7 @@
     static int init = 0;
     static cl_kernel gemm_kernel;
     if(!init){
-        gemm_kernel = get_kernel("src/gemm_new.cl", "gemm_nt", "-D BLOCK=" STR(BLOCK) );
+        gemm_kernel = get_kernel("src/gemm.cl", "gemm_nt", "-D BLOCK=" STR(BLOCK) );
         init = 1;
     }
     return gemm_kernel;
@@ -142,7 +145,7 @@
     static int init = 0;
     static cl_kernel gemm_kernel;
     if(!init){
-        gemm_kernel = get_kernel("src/gemm_new.cl", "gemm_tn", "-D BLOCK=" STR(BLOCK) );
+        gemm_kernel = get_kernel("src/gemm.cl", "gemm_tn", "-D BLOCK=" STR(BLOCK) );
         init = 1;
     }
     return gemm_kernel;
@@ -153,22 +156,31 @@
     static int init = 0;
     static cl_kernel gemm_kernel;
     if(!init){
-        gemm_kernel = get_kernel("src/gemm_new.cl", "gemm_nn", "-D BLOCK=" STR(BLOCK) );
+        gemm_kernel = get_kernel("src/gemm.cl", "gemm_nn", "-D BLOCK=" STR(BLOCK) );
         init = 1;
     }
     return gemm_kernel;
 }
 
-void gemm_ongpu_new(int TA, int TB, int M, int N, int K, float ALPHA, 
-        cl_mem A_gpu, int lda, 
-        cl_mem B_gpu, int ldb,
-        float BETA,
-        cl_mem C_gpu, int ldc);
-void gemm_ongpu_old(int TA, int TB, int M, int N, int K, float ALPHA, 
-        cl_mem A_gpu, int lda, 
-        cl_mem B_gpu, int ldb,
-        float BETA,
-        cl_mem C_gpu, int ldc);
+#define TILE 64
+#define TILE_K 16
+#define WPT 8
+#define THREADS (TILE*TILE)/(WPT*WPT)
+
+cl_kernel get_gemm_nn_fast_kernel()
+{
+    static int init = 0;
+    static cl_kernel gemm_kernel;
+    if(!init){
+        gemm_kernel = get_kernel("src/gemm_fast.cl", "gemm_nn_fast", "-D TILE=" STR(TILE)
+                                                                    " -cl-nv-verbose "
+                                                                    " -D TILE_K=" STR(TILE_K)
+                                                                    " -D WPT=" STR(WPT)
+                                                                    " -D THREADS=" STR(THREADS));
+        init = 1;
+    }
+    return gemm_kernel;
+}
 
 void gemm_ongpu(int TA, int TB, int M, int N, int K, float ALPHA, 
         cl_mem A_gpu, int lda, 
@@ -176,24 +188,61 @@
         float BETA,
         cl_mem C_gpu, int ldc)
 {
-/*
-    cl_setup();
-    cl_command_queue queue = cl.queue;
-    cl_event event;
-    cl.error = clblasSgemm(clblasRowMajor, TA?clblasTrans:clblasNoTrans, TB?clblasTrans:clblasNoTrans,M, N, K,ALPHA, A_gpu, 0, lda,B_gpu, 0, ldb,BETA, C_gpu, 0, ldc,1, &queue, 0, NULL, &event);
-
-*/
-    gemm_ongpu_new(TA, TB, M, N, K, ALPHA, A_gpu, lda, B_gpu, ldb, BETA, C_gpu, ldc);
+    gemm_ongpu_offset(TA, TB, M, N, K, ALPHA, A_gpu, 0, lda, B_gpu, 0, ldb, BETA, C_gpu, 0, ldc);
 }
 
-void gemm_ongpu_new(int TA, int TB, int M, int N, int K, float ALPHA, 
+void gemm_ongpu_fast(int TA, int TB, int M, int N, int K, float ALPHA, 
         cl_mem A_gpu, int lda, 
         cl_mem B_gpu, int ldb,
         float BETA,
         cl_mem C_gpu, int ldc)
 {
+    int a_off = 0;
+    int b_off = 0;
+    int c_off = 0;
     //printf("gpu: %d %d %d %d %d\n",TA, TB, M, N, K);
-    cl_setup();
+    cl_kernel      gemm_kernel = get_gemm_nn_fast_kernel();
+    cl_command_queue queue = cl.queue;
+
+    cl_uint i = 0;
+    cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(TA), (void*) &TA);
+    cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(TB), (void*) &TB);
+    cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(M), (void*) &M);
+    cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(N), (void*) &N);
+    cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(K), (void*) &K);
+    cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(ALPHA), (void*) &ALPHA);
+    cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(A_gpu), (void*) &A_gpu);
+    cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(a_off), (void*) &a_off);
+    cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(lda), (void*) &lda);
+    cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(B_gpu), (void*) &B_gpu);
+    cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(b_off), (void*) &b_off);
+    cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(ldb), (void*) &ldb);
+    cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(BETA), (void*) &BETA);
+    cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(C_gpu), (void*) &C_gpu);
+    cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(c_off), (void*) &c_off);
+    cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(ldc), (void*) &ldc);
+    check_error(cl);
+
+    const size_t global_size[] = {THREADS*((N-1)/TILE + 1), (M-1)/TILE + 1};
+    const size_t local_size[] = {THREADS, 1};
+
+    cl.error = clEnqueueNDRangeKernel(queue, gemm_kernel, 2, 0, global_size, local_size, 0, 0, 0);
+    check_error(cl);
+}
+
+void gemm_ongpu_offset(int TA, int TB, int M, int N, int K, float ALPHA, 
+        cl_mem A_gpu, int a_off, int lda, 
+        cl_mem B_gpu, int b_off, int ldb,
+        float BETA,
+        cl_mem C_gpu, int c_off, int ldc)
+{
+#ifdef CLBLAS
+    cl_command_queue queue = cl.queue;
+    cl_event event;
+    cl.error = clblasSgemm(clblasRowMajor, TA?clblasTrans:clblasNoTrans, TB?clblasTrans:clblasNoTrans,M, N, K,ALPHA, A_gpu, a_off, lda,B_gpu, b_off, ldb,BETA, C_gpu, c_off, ldc,1, &queue, 0, NULL, &event);
+    check_error(cl);
+#else
+    //printf("gpu: %d %d %d %d %d\n",TA, TB, M, N, K);
     cl_kernel      gemm_kernel = get_gemm_kernel();
     if(!TA && !TB) gemm_kernel = get_gemm_nn_kernel();
     if(!TA && TB)  gemm_kernel = get_gemm_nt_kernel();
@@ -208,63 +257,31 @@
     cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(K), (void*) &K);
     cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(ALPHA), (void*) &ALPHA);
     cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(A_gpu), (void*) &A_gpu);
+    cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(a_off), (void*) &a_off);
     cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(lda), (void*) &lda);
     cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(B_gpu), (void*) &B_gpu);
+    cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(b_off), (void*) &b_off);
     cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(ldb), (void*) &ldb);
     cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(BETA), (void*) &BETA);
     cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(C_gpu), (void*) &C_gpu);
+    cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(c_off), (void*) &c_off);
     cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(ldc), (void*) &ldc);
     check_error(cl);
 
     const size_t global_size[] = {ceil((float)N/BLOCK)*BLOCK, ceil((float)M/BLOCK)*BLOCK};
     const size_t local_size[] = {BLOCK, BLOCK};
 
-    clEnqueueNDRangeKernel(queue, gemm_kernel, 2, 0, global_size, local_size, 0, 0, 0);
+    cl.error = clEnqueueNDRangeKernel(queue, gemm_kernel, 2, 0, global_size, local_size, 0, 0, 0);
     check_error(cl);
+#endif
 }
 
-void gemm_ongpu_old(int TA, int TB, int M, int N, int K, float ALPHA, 
-        cl_mem A_gpu, int lda, 
-        cl_mem B_gpu, int ldb,
-        float BETA,
-        cl_mem C_gpu, int ldc)
-{
-    //printf("gpu: %d %d %d %d %d\n",TA, TB, M, N, K);
-    cl_setup();
-    cl_kernel gemm_kernel = get_gemm_kernel();
-    cl_command_queue queue = cl.queue;
-
-    cl_uint i = 0;
-    cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(TA), (void*) &TA);
-    cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(TB), (void*) &TB);
-    cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(M), (void*) &M);
-    cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(N), (void*) &N);
-    cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(K), (void*) &K);
-    cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(ALPHA), (void*) &ALPHA);
-    cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(A_gpu), (void*) &A_gpu);
-    cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(lda), (void*) &lda);
-    cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(B_gpu), (void*) &B_gpu);
-    cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(ldb), (void*) &ldb);
-    cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(BETA), (void*) &BETA);
-    cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(C_gpu), (void*) &C_gpu);
-    cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(ldc), (void*) &ldc);
-    check_error(cl);
-
-    const size_t global_size[] = {ceil((float)N/BLOCK)*BLOCK, ceil((float)M/BLOCK)*BLOCK};
-    const size_t local_size[] = {BLOCK, BLOCK};
-
-    clEnqueueNDRangeKernel(queue, gemm_kernel, 2, 0, global_size, local_size, 0, 0, 0);
-    check_error(cl);
-}
-
-
 void gemm_gpu(int TA, int TB, int M, int N, int K, float ALPHA, 
         float *A, int lda, 
         float *B, int ldb,
         float BETA,
         float *C, int ldc)
 {
-    cl_setup();
     cl_context context = cl.context;
     cl_command_queue queue = cl.queue;
 
@@ -286,7 +303,9 @@
             size, C, &cl.error);
     check_error(cl);
 
-    gemm_ongpu(TA, TB, M, N, K, ALPHA, A_gpu, lda, B_gpu, ldb, BETA, C_gpu, ldc);
+    // TODO
+    //gemm_ongpu(TA, TB, M, N, K, ALPHA, A_gpu, lda, B_gpu, ldb, BETA, C_gpu, ldc);
+    gemm_ongpu_fast(TA, TB, M, N, K, ALPHA, A_gpu, lda, B_gpu, ldb, BETA, C_gpu, ldc);
 
     clEnqueueReadBuffer(queue, C_gpu, CL_TRUE, 0, size, C, 0, 0, 0);
     check_error(cl);
@@ -327,7 +346,7 @@
 
 void time_ongpu(int TA, int TB, int m, int k, int n)
 {
-    int iter = 128;
+    int iter = 10;
     float *a = random_matrix(m,k);
     float *b = random_matrix(k,n);
 
@@ -345,7 +364,7 @@
     for(i = 0; i<iter; ++i){
         gemm_ongpu(TA,TB,m,n,k,1,a_cl,lda,b_cl,ldb,1,c_cl,n);
     }
-    double flop = m*n*(2.*k+3.)*iter;
+    double flop = ((double)m)*n*(2.*k + 2.)*iter;
     double gflop = flop/pow(10., 9);
     end = clock();
     double seconds = sec(end-start);
@@ -358,6 +377,39 @@
     free(c);
 }
 
+void time_ongpu_fast(int TA, int TB, int m, int k, int n)
+{
+    int iter = 10;
+    float *a = random_matrix(m,k);
+    float *b = random_matrix(k,n);
+
+    int lda = (!TA)?k:m;
+    int ldb = (!TB)?n:k;
+
+    float *c = random_matrix(m,n);
+
+    cl_mem a_cl = cl_make_array(a, m*k);
+    cl_mem b_cl = cl_make_array(b, k*n);
+    cl_mem c_cl = cl_make_array(c, m*n);
+
+    int i;
+    clock_t start = clock(), end;
+    for(i = 0; i<iter; ++i){
+        gemm_ongpu_fast(TA,TB,m,n,k,1,a_cl,lda,b_cl,ldb,1,c_cl,n);
+    }
+    double flop = ((double)m)*n*(2.*k + 2.)*iter;
+    double gflop = flop/pow(10., 9);
+    end = clock();
+    double seconds = sec(end-start);
+    printf("Fast   Multiplication %dx%d * %dx%d, TA=%d, TB=%d: %lf s, %lf GFLOPS\n",m,k,k,n, TA, TB, seconds, gflop/seconds);
+    clReleaseMemObject(a_cl);
+    clReleaseMemObject(b_cl);
+    clReleaseMemObject(c_cl);
+    free(a);
+    free(b);
+    free(c);
+}
+
 void test_gpu_accuracy(int TA, int TB, int m, int k, int n)
 {
     srand(0);
@@ -377,8 +429,10 @@
     int i;
     //pm(m,k,b);
     gemm_gpu(TA,TB,m,n,k,1,a,lda,b,ldb,1,c_gpu,n);
+    //printf("GPU\n");
     //pm(m, n, c_gpu);
     gemm_cpu(TA,TB,m,n,k,1,a,lda,b,ldb,1,c,n);
+    //printf("\n\nCPU\n");
     //pm(m, n, c);
     double sse = 0;
     for(i = 0; i < m*n; ++i) {
@@ -395,32 +449,71 @@
 void test_gpu_blas()
 {
     /*
-    test_gpu_accuracy(0,0,10,576,75); 
+       test_gpu_accuracy(0,0,10,576,75); 
 
-    test_gpu_accuracy(0,0,17,10,10); 
-    test_gpu_accuracy(1,0,17,10,10); 
-    test_gpu_accuracy(0,1,17,10,10); 
-    test_gpu_accuracy(1,1,17,10,10); 
+       test_gpu_accuracy(0,0,17,10,10); 
+       test_gpu_accuracy(1,0,17,10,10); 
+       test_gpu_accuracy(0,1,17,10,10); 
+       test_gpu_accuracy(1,1,17,10,10); 
 
-    test_gpu_accuracy(0,0,1000,10,100); 
-    test_gpu_accuracy(1,0,1000,10,100); 
-    test_gpu_accuracy(0,1,1000,10,100); 
-    test_gpu_accuracy(1,1,1000,10,100); 
+       test_gpu_accuracy(0,0,1000,10,100); 
+       test_gpu_accuracy(1,0,1000,10,100); 
+       test_gpu_accuracy(0,1,1000,10,100); 
+       test_gpu_accuracy(1,1,1000,10,100); 
+     */
+
+    test_gpu_accuracy(0,0,128,128,128); 
+
+/*
+    time_ongpu(0,0,64,2916,363); 
+    time_ongpu_fast(0,0,64,2916,363); 
+    time_ongpu(0,0,64,2916,363); 
+    time_ongpu_fast(0,0,64,2916,363); 
+    time_ongpu(0,0,64,2916,363); 
+    time_ongpu_fast(0,0,64,2916,363); 
+    time_ongpu(0,0,192,729,1600); 
+    time_ongpu_fast(0,0,192,729,1600); 
+    time_ongpu(0,0,384,196,1728); 
+    time_ongpu_fast(0,0,384,196,1728); 
+    time_ongpu(0,0,256,196,3456); 
+    time_ongpu_fast(0,0,256,196,3456); 
+    time_ongpu(0,0,256,196,2304); 
+    time_ongpu_fast(0,0,256,196,2304); 
+    time_ongpu(0,0,128,4096,12544); 
+    time_ongpu_fast(0,0,128,4096,12544); 
+    time_ongpu(0,0,128,4096,4096); 
+    time_ongpu_fast(0,0,128,4096,4096); 
     */
-    test_gpu_accuracy(0,0,131,4093,1199); 
-    test_gpu_accuracy(0,1,131,4093,1199); 
-    test_gpu_accuracy(1,0,131,4093,1199); 
-    test_gpu_accuracy(1,1,131,4093,1199); 
+//    time_ongpu(1,0,2304,196,256); 
+//    time_ongpu_fast(1,0,2304,196,256); 
+//    time_ongpu(0,1,256,2304,196); 
+//    time_ongpu_fast(0,1,256,2304,196); 
 
-    time_ongpu(0,0,1024,1024,1024); 
-    time_ongpu(0,1,1024,1024,1024); 
-    time_ongpu(1,0,1024,1024,1024); 
-    time_ongpu(1,1,1024,1024,1024); 
+    time_ongpu(0,0,2048,2048,2048); 
+    time_ongpu_fast(0,0,2048,2048,2048); 
+    time_ongpu(0,0,2048,2048,2048); 
+    time_ongpu_fast(0,0,2048,2048,2048); 
+    time_ongpu(0,0,2048,2048,2048); 
+    time_ongpu_fast(0,0,2048,2048,2048); 
 
-    time_ongpu(0,0,128,4096,1200); 
-    time_ongpu(0,1,128,4096,1200); 
-    time_ongpu(1,0,128,4096,1200); 
-    time_ongpu(1,1,128,4096,1200); 
+    /*
+       test_gpu_accuracy(0,0,131,4093,1199); 
+       test_gpu_accuracy(0,1,131,4093,1199); 
+       test_gpu_accuracy(1,0,131,4093,1199); 
+       test_gpu_accuracy(1,1,131,4093,1199); 
+     */
+    /*
+
+       time_ongpu(0,0,1024,1024,1024); 
+       time_ongpu(0,1,1024,1024,1024); 
+       time_ongpu(1,0,1024,1024,1024); 
+       time_ongpu(1,1,1024,1024,1024); 
+
+       time_ongpu(0,0,128,4096,1200); 
+       time_ongpu(0,1,128,4096,1200); 
+       time_ongpu(1,0,128,4096,1200); 
+       time_ongpu(1,1,128,4096,1200); 
+     */
 
     /*
        time_gpu_random_matrix(0,0,1000,1000,100); 

--
Gitblit v1.10.0