From b2b7137b6f185ce2f01664d782a09b08d50d5a07 Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Tue, 28 Jan 2014 07:16:56 +0000
Subject: [PATCH] About to do something stupid...

---
 src/mini_blas.c |   38 +++++++++++++++++++++++++++++++++-----
 1 files changed, 33 insertions(+), 5 deletions(-)

diff --git a/src/mini_blas.c b/src/mini_blas.c
index b15ba8e..3af36e5 100644
--- a/src/mini_blas.c
+++ b/src/mini_blas.c
@@ -1,16 +1,44 @@
 
+#include <stdlib.h>
+#include <math.h>
+
+void pm(int M, int N, double *A) // Print A to stdout as M rows of N values; element (i,j) is read from A[i*N+j].
+{
+    int i,j; // i indexes rows, j indexes columns
+    for(i =0 ; i < M; ++i){
+        for(j = 0; j < N; ++j){
+            printf("%10.6f, ", A[i*N+j]); // NOTE(review): printf needs <stdio.h>; this hunk adds only <stdlib.h> and <math.h> -- confirm it is included
+        }
+        printf("\n"); // terminate the current row
+    }
+    printf("\n"); // blank line after the whole matrix
+}
+
 void gemm(int TA, int TB, int M, int N, int K, double ALPHA, 
                     double *A, int lda, 
                     double *B, int ldb,
                     double BETA,
                     double *C, int ldc)
 {
-    // Assume TA = TB = 0, beta = 1 LULZ
+    // Assume TA = 0, beta = 1 LULZ
     int i,j,k;
-    for(i = 0; i < M; ++i){
-        for(k = 0; k < K; ++k){
+    if(TB && !TA){
+        for(i = 0; i < M; ++i){
             for(j = 0; j < N; ++j){
-                C[i*ldc+j] += ALPHA*A[i*lda+k]*B[k*ldb+j];
+                register double sum = 0;
+                for(k = 0; k < K; ++k){
+                    sum += ALPHA*A[i*lda+k]*B[k+j*ldb];
+                }
+                C[i*ldc+j] += sum;
+            }
+        }
+    }else{
+        for(i = 0; i < M; ++i){
+            for(k = 0; k < K; ++k){
+                register double A_PART = ALPHA*A[i*lda+k];
+                for(j = 0; j < N; ++j){
+                    C[i*ldc+j] += A_PART*B[k*ldb+j];
+                }
             }
         }
     }
@@ -59,7 +87,7 @@
 void im2col_cpu(double* data_im, const int channels,
         const int height, const int width, const int ksize, const int stride,
         double* data_col) 
- {
+{
     int c,h,w;
     int height_col = (height - ksize) / stride + 1;
     int width_col = (width - ksize) / stride + 1;

--
Gitblit v1.10.0