From e36182cd8c5dd5c6d0aa1f77cf5cdca87e8bb1f0 Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Fri, 21 Nov 2014 23:35:19 +0000
Subject: [PATCH] cleaned up data parsing a lot. probably nothing broken?

---
 src/matrix.c |   32 +++++++++++++++++++++++---------
 1 files changed, 23 insertions(+), 9 deletions(-)

diff --git a/src/matrix.c b/src/matrix.c
index 5627b87..96bd332 100644
--- a/src/matrix.c
+++ b/src/matrix.c
@@ -13,6 +13,18 @@
     free(m.vals);
 }
 
+float matrix_accuracy(matrix truth, matrix guess)
+{
+    int k = truth.cols;
+    int i;
+    int count = 0;
+    for(i = 0; i < truth.rows; ++i){
+        int class = max_index(guess.vals[i], k);
+        if(truth.vals[i][class]) ++count;
+    }
+    return (float)count/truth.rows;
+}
+
 void matrix_add_matrix(matrix from, matrix to)
 {
     assert(from.rows == to.rows && from.cols == to.cols);
@@ -26,12 +38,14 @@
 
 matrix make_matrix(int rows, int cols)
 {
+    int i;
     matrix m;
     m.rows = rows;
     m.cols = cols;
-    m.vals = calloc(m.rows, sizeof(double *));
-    int i;
-    for(i = 0; i < m.rows; ++i) m.vals[i] = calloc(m.cols, sizeof(double));
+    m.vals = calloc(m.rows, sizeof(float *));
+    for(i = 0; i < m.rows; ++i){
+        m.vals[i] = calloc(m.cols, sizeof(float));
+    }
     return m;
 }
 
@@ -41,7 +55,7 @@
     matrix h;
     h.rows = n;
     h.cols = m->cols;
-    h.vals = calloc(h.rows, sizeof(double *));
+    h.vals = calloc(h.rows, sizeof(float *));
     for(i = 0; i < n; ++i){
         int index = rand()%m->rows;
         h.vals[i] = m->vals[index];
@@ -50,9 +64,9 @@
     return h;
 }
 
-double *pop_column(matrix *m, int c)
+float *pop_column(matrix *m, int c)
 {
-    double *col = calloc(m->rows, sizeof(double));
+    float *col = calloc(m->rows, sizeof(float));
     int i, j;
     for(i = 0; i < m->rows; ++i){
         col[i] = m->vals[i][c];
@@ -76,18 +90,18 @@
 
 	int n = 0;
 	int size = 1024;
-	m.vals = calloc(size, sizeof(double*));
+	m.vals = calloc(size, sizeof(float*));
 	while((line = fgetl(fp))){
         if(m.cols == -1) m.cols = count_fields(line);
 		if(n == size){
 			size *= 2;
-			m.vals = realloc(m.vals, size*sizeof(double*));
+			m.vals = realloc(m.vals, size*sizeof(float*));
 		}
 		m.vals[n] = parse_fields(line, m.cols);
 		free(line);
 		++n;
 	}
-	m.vals = realloc(m.vals, n*sizeof(double*));
+	m.vals = realloc(m.vals, n*sizeof(float*));
     m.rows = n;
 	return m;
 }

--
Gitblit v1.10.0