From 4ac78c89269138b4623993f9f1d81829d8e88131 Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Tue, 20 Jan 2015 21:26:46 +0000
Subject: [PATCH] I am so done with opencl, switching to cuda

---
 src/gemm_fast.cl |   37 +++++++++++++++++--------------------
 src/gemm.c       |    6 +-----
 src/cnn.c        |    2 +-
 3 files changed, 19 insertions(+), 26 deletions(-)

diff --git a/src/cnn.c b/src/cnn.c
index be93e8c..fed69d0 100644
--- a/src/cnn.c
+++ b/src/cnn.c
@@ -210,7 +210,7 @@
     //network net = parse_network_cfg("/home/pjreddie/imagenet_backup/alexnet_1270.cfg");
     srand(time(0));
     network net = parse_network_cfg(cfgfile);
-    set_learning_network(&net, net.learning_rate*10., net.momentum, net.decay);
+    set_learning_network(&net, net.learning_rate*100., net.momentum, net.decay);
     printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
     int imgs = 1024;
     int i = 6600;
diff --git a/src/gemm.c b/src/gemm.c
index 8394991..9797b85 100644
--- a/src/gemm.c
+++ b/src/gemm.c
@@ -164,8 +164,7 @@
 
 #define TILE 64
 #define TILE_K 16
-#define WPT 8
-#define THREADS (TILE*TILE)/(WPT*WPT)
+#define THREADS 64
 
 cl_kernel get_gemm_nn_fast_kernel()
 {
@@ -175,7 +174,6 @@
         gemm_kernel = get_kernel("src/gemm_fast.cl", "gemm_nn_fast", "-D TILE=" STR(TILE)
                                                                     " -cl-nv-verbose "
                                                                     " -D TILE_K=" STR(TILE_K)
-                                                                    " -D WPT=" STR(WPT)
                                                                     " -D THREADS=" STR(THREADS));
         init = 1;
     }
@@ -464,7 +462,6 @@
 
     test_gpu_accuracy(0,0,128,128,128); 
 
-/*
     time_ongpu(0,0,64,2916,363); 
     time_ongpu_fast(0,0,64,2916,363); 
     time_ongpu(0,0,64,2916,363); 
@@ -483,7 +480,6 @@
     time_ongpu_fast(0,0,128,4096,12544); 
     time_ongpu(0,0,128,4096,4096); 
     time_ongpu_fast(0,0,128,4096,4096); 
-    */
 //    time_ongpu(1,0,2304,196,256); 
 //    time_ongpu_fast(1,0,2304,196,256); 
 //    time_ongpu(0,1,256,2304,196); 
diff --git a/src/gemm_fast.cl b/src/gemm_fast.cl
index 9a98208..2a76396 100644
--- a/src/gemm_fast.cl
+++ b/src/gemm_fast.cl
@@ -16,16 +16,15 @@
     int ctile = get_group_id(0);
     int rtile = get_group_id(1);
 
-    float Breg;
-    float Areg[WPT];
-    float acc[WPT][WPT];
+    float Areg[TILE];
+    float acc[TILE][TILE/THREADS];
 
     A += rtile*TILE*lda;
     B += ctile*TILE;
     C += rtile*TILE*ldc + ctile*TILE;
 
-    for(i = 0; i < WPT; ++i){
-        for(j = 0; j < WPT; ++j){
+    for(i = 0; i < TILE; ++i){
+        for(j = 0; j < TILE/THREADS; ++j){
             acc[i][j] = 0;
         }
     }
@@ -51,28 +50,26 @@
         barrier(CLK_LOCAL_MEM_FENCE);
 
         for(k = 0; k < TILE_K; ++k){
-            for(y = 0; y < WPT; ++y){
-                int row = (offset + (y*WPT)*THREADS)/TILE;
-                //Areg[y] = Asub[y*WPT][k];
+            #pragma unroll
+            for(y = 0; y < TILE; ++y){
+                Areg[y] = Asub[y][k];
             }
-            for(y = 0; y < WPT; ++y){
-                for(x = 0; x < WPT; ++x){
-                    int index = offset + (y*WPT + x)*THREADS;
-                    int row = index / TILE;
-                    int col = index % TILE;
-                    acc[y][x] += Asub[row][k]*Bsub[k][col];
+            for(x = 0; x < TILE; x += THREADS){
+                float Breg = Bsub[k][x+offset];
+                #pragma unroll
+                for(y = 0; y < TILE; ++y){
+                    acc[y][x/THREADS] += Breg * Areg[y];
                 }
             }
         }
         barrier(CLK_LOCAL_MEM_FENCE);
     }
 
-    for(y = 0; y < WPT; ++y){
-        for(x = 0; x < WPT; ++x){
-            int index = offset + (y*WPT + x)*THREADS;
-            int row = index / TILE;
-            int col = index % TILE;
-            C[row*ldc+col] = ALPHA*acc[y][x] + BETA*C[row*ldc+col];
+    for(i = 0; i < TILE; ++i){
+        for(j = 0; j < TILE/THREADS; ++j){
+            int col = j*THREADS + offset;
+            int row = i;
+            C[row*ldc+col] = ALPHA*acc[i][j] + BETA*C[row*ldc+col];
         }
     }
 }

--
Gitblit v1.10.0