Joseph Redmon
2015-01-20 4ac78c89269138b4623993f9f1d81829d8e88131
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
 
__kernel void gemm_nn_fast(int TA, int TB, int M, int N, int K, float ALPHA, 
                    __global float *A, int a_off, int lda, 
                    __global float *B, int b_off, int ldb,
                    float BETA,
                    __global float *C, int c_off, int ldc)
{
    int i, j, k, x, y;
    A += a_off;
    B += b_off;
    C += c_off;
 
    __local float Asub[TILE]  [TILE_K];
    __local float Bsub[TILE_K][TILE];
 
    int ctile = get_group_id(0);
    int rtile = get_group_id(1);
 
    float Areg[TILE];
    float acc[TILE][TILE/THREADS];
 
    A += rtile*TILE*lda;
    B += ctile*TILE;
    C += rtile*TILE*ldc + ctile*TILE;
 
    for(i = 0; i < TILE; ++i){
        for(j = 0; j < TILE/THREADS; ++j){
            acc[i][j] = 0;
        }
    }
 
    int offset = get_local_id(0);
 
    for(i = 0; i < K; i += TILE_K){
        for(j = 0; j < TILE*TILE_K; j += THREADS){
            int index = j+offset;
 
            int row = index / TILE_K;
            int col = index % TILE_K;
            Asub[row][col] = A[row*lda + col];
 
            row = index / TILE;
            col = index % TILE;
            Bsub[row][col] = B[row*ldb + col];
        }
 
        A += TILE_K;
        B += TILE_K*ldb;
 
        barrier(CLK_LOCAL_MEM_FENCE);
 
        for(k = 0; k < TILE_K; ++k){
            #pragma unroll
            for(y = 0; y < TILE; ++y){
                Areg[y] = Asub[y][k];
            }
            for(x = 0; x < TILE; x += THREADS){
                float Breg = Bsub[k][x+offset];
                #pragma unroll
                for(y = 0; y < TILE; ++y){
                    acc[y][x/THREADS] += Breg * Areg[y];
                }
            }
        }
        barrier(CLK_LOCAL_MEM_FENCE);
    }
 
    for(i = 0; i < TILE; ++i){
        for(j = 0; j < TILE/THREADS; ++j){
            int col = j*THREADS + offset;
            int row = i;
            C[row*ldc+col] = ALPHA*acc[i][j] + BETA*C[row*ldc+col];
        }
    }
}