From f7a17f82eb43de864a4f980f235055da9685eef8 Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Wed, 29 Jan 2014 00:28:42 +0000
Subject: [PATCH] Convolutional layers working w/ matrices
---
src/matrix.c | 20 ++++++++++----------
1 files changed, 10 insertions(+), 10 deletions(-)
diff --git a/src/matrix.c b/src/matrix.c
index 68e6f8d..96bd332 100644
--- a/src/matrix.c
+++ b/src/matrix.c
@@ -13,7 +13,7 @@
free(m.vals);
}
-double matrix_accuracy(matrix truth, matrix guess)
+float matrix_accuracy(matrix truth, matrix guess)
{
int k = truth.cols;
int i;
@@ -22,7 +22,7 @@
int class = max_index(guess.vals[i], k);
if(truth.vals[i][class]) ++count;
}
- return (double)count/truth.rows;
+ return (float)count/truth.rows;
}
void matrix_add_matrix(matrix from, matrix to)
@@ -42,9 +42,9 @@
matrix m;
m.rows = rows;
m.cols = cols;
- m.vals = calloc(m.rows, sizeof(double *));
+ m.vals = calloc(m.rows, sizeof(float *));
for(i = 0; i < m.rows; ++i){
- m.vals[i] = calloc(m.cols, sizeof(double));
+ m.vals[i] = calloc(m.cols, sizeof(float));
}
return m;
}
@@ -55,7 +55,7 @@
matrix h;
h.rows = n;
h.cols = m->cols;
- h.vals = calloc(h.rows, sizeof(double *));
+ h.vals = calloc(h.rows, sizeof(float *));
for(i = 0; i < n; ++i){
int index = rand()%m->rows;
h.vals[i] = m->vals[index];
@@ -64,9 +64,9 @@
return h;
}
-double *pop_column(matrix *m, int c)
+float *pop_column(matrix *m, int c)
{
- double *col = calloc(m->rows, sizeof(double));
+ float *col = calloc(m->rows, sizeof(float));
int i, j;
for(i = 0; i < m->rows; ++i){
col[i] = m->vals[i][c];
@@ -90,18 +90,18 @@
int n = 0;
int size = 1024;
- m.vals = calloc(size, sizeof(double*));
+ m.vals = calloc(size, sizeof(float*));
while((line = fgetl(fp))){
if(m.cols == -1) m.cols = count_fields(line);
if(n == size){
size *= 2;
- m.vals = realloc(m.vals, size*sizeof(double*));
+ m.vals = realloc(m.vals, size*sizeof(float*));
}
m.vals[n] = parse_fields(line, m.cols);
free(line);
++n;
}
- m.vals = realloc(m.vals, n*sizeof(double*));
+ m.vals = realloc(m.vals, n*sizeof(float*));
m.rows = n;
return m;
}
--
Gitblit v1.10.0