From f7a17f82eb43de864a4f980f235055da9685eef8 Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Wed, 29 Jan 2014 00:28:42 +0000
Subject: [PATCH] Convolutional layers working w/ matrices

---
 src/mini_blas.c           |   93 +++++
 Makefile                  |    2 
 src/convolutional_layer.h |   24 
 src/image.c               |   68 ++--
 src/mini_blas.h           |   24 +
 src/tests.c               |  129 +++-----
 src/image.h               |   20 
 src/maxpool_layer.c       |   28 +-
 src/utils.c               |   44 +-
 nist_basic.cfg            |    2 
 src/network.c             |   70 ++--
 src/matrix.c              |   20 
 src/maxpool_layer.h       |    8 
 src/matrix.h              |    6 
 src/softmax_layer.h       |    8 
 src/utils.h               |   20 
 src/network.h             |   24 
 src/connected_layer.c     |   58 ++--
 src/connected_layer.h     |   24 
 src/data.c                |   20 
 src/softmax_layer.c       |   28 +
 src/option_list.h         |    2 
 src/convolutional_layer.c |   90 ++++--
 src/option_list.c         |    2 
 src/activations.h         |    4 
 src/activations.c         |    4 
 26 files changed, 459 insertions(+), 363 deletions(-)

diff --git a/Makefile b/Makefile
index dc08b46..fda7d88 100644
--- a/Makefile
+++ b/Makefile
@@ -1,6 +1,6 @@
 CC=gcc
 COMMON=-Wall `pkg-config --cflags opencv`
-CFLAGS= $(COMMON) -Ofast -ffast-math -flto
+CFLAGS= $(COMMON) -O3 -ffast-math -flto
 UNAME = $(shell uname)
 ifeq ($(UNAME), Darwin)
 COMMON += -isystem /usr/local/Cellar/opencv/2.4.6.1/include/opencv -isystem /usr/local/Cellar/opencv/2.4.6.1/include
diff --git a/nist_basic.cfg b/nist_basic.cfg
index f5ea0a3..7142735 100644
--- a/nist_basic.cfg
+++ b/nist_basic.cfg
@@ -3,7 +3,7 @@
 height=28
 channels=1
 filters=20
-size=5
+size=11
 stride=1
 activation=linear
 
diff --git a/src/activations.c b/src/activations.c
index b8bb79d..cc923d0 100644
--- a/src/activations.c
+++ b/src/activations.c
@@ -15,7 +15,7 @@
     return RELU;
 }
 
-double activate(double x, ACTIVATION a){
+float activate(float x, ACTIVATION a){
     switch(a){
         case LINEAR:
             return x;
@@ -30,7 +30,7 @@
     }
     return 0;
 }
-double gradient(double x, ACTIVATION a){
+float gradient(float x, ACTIVATION a){
     switch(a){
         case LINEAR:
             return 1;
diff --git a/src/activations.h b/src/activations.h
index 889453f..fb2c54f 100644
--- a/src/activations.h
+++ b/src/activations.h
@@ -7,8 +7,8 @@
 
 ACTIVATION get_activation(char *s);
 
-double activate(double x, ACTIVATION a);
-double gradient(double x, ACTIVATION a);
+float activate(float x, ACTIVATION a);
+float gradient(float x, ACTIVATION a);
 
 #endif
 
diff --git a/src/connected_layer.c b/src/connected_layer.c
index 6871b2e..5f6631c 100644
--- a/src/connected_layer.c
+++ b/src/connected_layer.c
@@ -15,19 +15,19 @@
     layer->inputs = inputs;
     layer->outputs = outputs;
 
-    layer->output = calloc(outputs, sizeof(double*));
-    layer->delta = calloc(outputs, sizeof(double*));
+    layer->output = calloc(outputs, sizeof(float*));
+    layer->delta = calloc(outputs, sizeof(float*));
 
-    layer->weight_updates = calloc(inputs*outputs, sizeof(double));
-    layer->weight_momentum = calloc(inputs*outputs, sizeof(double));
-    layer->weights = calloc(inputs*outputs, sizeof(double));
-    double scale = 2./inputs;
+    layer->weight_updates = calloc(inputs*outputs, sizeof(float));
+    layer->weight_momentum = calloc(inputs*outputs, sizeof(float));
+    layer->weights = calloc(inputs*outputs, sizeof(float));
+    float scale = 2./inputs;
     for(i = 0; i < inputs*outputs; ++i)
         layer->weights[i] = rand_normal()*scale;
 
-    layer->bias_updates = calloc(outputs, sizeof(double));
-    layer->bias_momentum = calloc(outputs, sizeof(double));
-    layer->biases = calloc(outputs, sizeof(double));
+    layer->bias_updates = calloc(outputs, sizeof(float));
+    layer->bias_momentum = calloc(outputs, sizeof(float));
+    layer->biases = calloc(outputs, sizeof(float));
     for(i = 0; i < outputs; ++i)
         //layer->biases[i] = rand_normal()*scale + scale;
         layer->biases[i] = 0;
@@ -36,7 +36,7 @@
     return layer;
 }
 
-void update_connected_layer(connected_layer layer, double step, double momentum, double decay)
+void update_connected_layer(connected_layer layer, float step, float momentum, float decay)
 {
     int i;
     for(i = 0; i < layer.outputs; ++i){
@@ -47,27 +47,27 @@
         layer.weight_momentum[i] = step*(layer.weight_updates[i] - decay*layer.weights[i]) + momentum*layer.weight_momentum[i];
         layer.weights[i] += layer.weight_momentum[i];
     }
-    memset(layer.bias_updates, 0, layer.outputs*sizeof(double));
-    memset(layer.weight_updates, 0, layer.outputs*layer.inputs*sizeof(double));
+    memset(layer.bias_updates, 0, layer.outputs*sizeof(float));
+    memset(layer.weight_updates, 0, layer.outputs*layer.inputs*sizeof(float));
 }
 
-void forward_connected_layer(connected_layer layer, double *input)
+void forward_connected_layer(connected_layer layer, float *input)
 {
     int i;
-    memcpy(layer.output, layer.biases, layer.outputs*sizeof(double));
+    memcpy(layer.output, layer.biases, layer.outputs*sizeof(float));
     int m = 1;
     int k = layer.inputs;
     int n = layer.outputs;
-    double *a = input;
-    double *b = layer.weights;
-    double *c = layer.output;
+    float *a = input;
+    float *b = layer.weights;
+    float *c = layer.output;
     gemm(0,0,m,n,k,1,a,k,b,n,1,c,n);
     for(i = 0; i < layer.outputs; ++i){
         layer.output[i] = activate(layer.output[i], layer.activation);
     }
 }
 
-void learn_connected_layer(connected_layer layer, double *input)
+void learn_connected_layer(connected_layer layer, float *input)
 {
     int i;
     for(i = 0; i < layer.outputs; ++i){
@@ -77,28 +77,28 @@
     int m = layer.inputs;
     int k = 1;
     int n = layer.outputs;
-    double *a = input;
-    double *b = layer.delta;
-    double *c = layer.weight_updates;
+    float *a = input;
+    float *b = layer.delta;
+    float *c = layer.weight_updates;
     gemm(0,0,m,n,k,1,a,k,b,n,1,c,n);
 }
 
-void backward_connected_layer(connected_layer layer, double *input, double *delta)
+void backward_connected_layer(connected_layer layer, float *input, float *delta)
 {
-    memset(delta, 0, layer.inputs*sizeof(double));
+    memset(delta, 0, layer.inputs*sizeof(float));
 
     int m = layer.inputs;
     int k = layer.outputs;
     int n = 1;
 
-    double *a = layer.weights;
-    double *b = layer.delta;
-    double *c = delta;
+    float *a = layer.weights;
+    float *b = layer.delta;
+    float *c = delta;
 
     gemm(0,0,m,n,k,1,a,k,b,n,1,c,n);
 }
 /*
-   void forward_connected_layer(connected_layer layer, double *input)
+   void forward_connected_layer(connected_layer layer, float *input)
    {
    int i, j;
    for(i = 0; i < layer.outputs; ++i){
@@ -109,7 +109,7 @@
    layer.output[i] = activate(layer.output[i], layer.activation);
    }
    }
-   void learn_connected_layer(connected_layer layer, double *input)
+   void learn_connected_layer(connected_layer layer, float *input)
    {
    int i, j;
    for(i = 0; i < layer.outputs; ++i){
@@ -120,7 +120,7 @@
    }
    }
    }
-   void backward_connected_layer(connected_layer layer, double *input, double *delta)
+   void backward_connected_layer(connected_layer layer, float *input, float *delta)
    {
    int i, j;
 
diff --git a/src/connected_layer.h b/src/connected_layer.h
index 05fb261..ce0181d 100644
--- a/src/connected_layer.h
+++ b/src/connected_layer.h
@@ -6,17 +6,17 @@
 typedef struct{
     int inputs;
     int outputs;
-    double *weights;
-    double *biases;
+    float *weights;
+    float *biases;
 
-    double *weight_updates;
-    double *bias_updates;
+    float *weight_updates;
+    float *bias_updates;
 
-    double *weight_momentum;
-    double *bias_momentum;
+    float *weight_momentum;
+    float *bias_momentum;
 
-    double *output;
-    double *delta;
+    float *output;
+    float *delta;
 
     ACTIVATION activation;
 
@@ -24,10 +24,10 @@
 
 connected_layer *make_connected_layer(int inputs, int outputs, ACTIVATION activation);
 
-void forward_connected_layer(connected_layer layer, double *input);
-void backward_connected_layer(connected_layer layer, double *input, double *delta);
-void learn_connected_layer(connected_layer layer, double *input);
-void update_connected_layer(connected_layer layer, double step, double momentum, double decay);
+void forward_connected_layer(connected_layer layer, float *input);
+void backward_connected_layer(connected_layer layer, float *input, float *delta);
+void learn_connected_layer(connected_layer layer, float *input);
+void update_connected_layer(connected_layer layer, float step, float momentum, float decay);
 
 
 #endif
diff --git a/src/convolutional_layer.c b/src/convolutional_layer.c
index 53eb7bf..cdfe9e1 100644
--- a/src/convolutional_layer.c
+++ b/src/convolutional_layer.c
@@ -9,7 +9,7 @@
     h = layer.out_h;
     w = layer.out_w;
     c = layer.n;
-    return double_to_image(h,w,c,layer.output);
+    return float_to_image(h,w,c,layer.output);
 }
 
 image get_convolutional_delta(convolutional_layer layer)
@@ -18,7 +18,7 @@
     h = layer.out_h;
     w = layer.out_w;
     c = layer.n;
-    return double_to_image(h,w,c,layer.delta);
+    return float_to_image(h,w,c,layer.delta);
 }
 
 convolutional_layer *make_convolutional_layer(int h, int w, int c, int n, int size, int stride, ACTIVATION activation)
@@ -34,14 +34,14 @@
     layer->stride = stride;
     layer->size = size;
 
-    layer->filters = calloc(c*n*size*size, sizeof(double));
-    layer->filter_updates = calloc(c*n*size*size, sizeof(double));
-    layer->filter_momentum = calloc(c*n*size*size, sizeof(double));
+    layer->filters = calloc(c*n*size*size, sizeof(float));
+    layer->filter_updates = calloc(c*n*size*size, sizeof(float));
+    layer->filter_momentum = calloc(c*n*size*size, sizeof(float));
 
-    layer->biases = calloc(n, sizeof(double));
-    layer->bias_updates = calloc(n, sizeof(double));
-    layer->bias_momentum = calloc(n, sizeof(double));
-    double scale = 2./(size*size);
+    layer->biases = calloc(n, sizeof(float));
+    layer->bias_updates = calloc(n, sizeof(float));
+    layer->bias_momentum = calloc(n, sizeof(float));
+    float scale = 2./(size*size);
     for(i = 0; i < c*n*size*size; ++i) layer->filters[i] = rand_normal()*scale;
     for(i = 0; i < n; ++i){
         //layer->biases[i] = rand_normal()*scale + scale;
@@ -50,9 +50,9 @@
     out_h = (h-size)/stride + 1;
     out_w = (w-size)/stride + 1;
 
-    layer->col_image = calloc(out_h*out_w*size*size*c, sizeof(double));
-    layer->output = calloc(out_h * out_w * n, sizeof(double));
-    layer->delta  = calloc(out_h * out_w * n, sizeof(double));
+    layer->col_image = calloc(out_h*out_w*size*size*c, sizeof(float));
+    layer->output = calloc(out_h * out_w * n, sizeof(float));
+    layer->delta  = calloc(out_h * out_w * n, sizeof(float));
     layer->activation = activation;
     layer->out_h = out_h;
     layer->out_w = out_w;
@@ -63,18 +63,18 @@
     return layer;
 }
 
-void forward_convolutional_layer(const convolutional_layer layer, double *in)
+void forward_convolutional_layer(const convolutional_layer layer, float *in)
 {
     int m = layer.n;
     int k = layer.size*layer.size*layer.c;
     int n = ((layer.h-layer.size)/layer.stride + 1)*
             ((layer.w-layer.size)/layer.stride + 1);
 
-    memset(layer.output, 0, m*n*sizeof(double));
+    memset(layer.output, 0, m*n*sizeof(float));
 
-    double *a = layer.filters;
-    double *b = layer.col_image;
-    double *c = layer.output;
+    float *a = layer.filters;
+    float *b = layer.col_image;
+    float *c = layer.output;
 
     im2col_cpu(in,  layer.c,  layer.h,  layer.w,  layer.size,  layer.stride, b);
     gemm(0,0,m,n,k,1,a,k,b,n,1,c,n);
@@ -94,7 +94,7 @@
     int i,j;
     int size = layer.out_h*layer.out_w;
     for(i = 0; i < layer.n; ++i){
-        double sum = 0;
+        float sum = 0;
         for(j = 0; j < size; ++j){
             sum += layer.delta[j+i*size];
         }
@@ -111,14 +111,33 @@
     int k = ((layer.h-layer.size)/layer.stride + 1)*
             ((layer.w-layer.size)/layer.stride + 1);
 
-    double *a = layer.delta;
-    double *b = layer.col_image;
-    double *c = layer.filter_updates;
+    float *a = layer.delta;
+    float *b = layer.col_image;
+    float *c = layer.filter_updates;
 
     gemm(0,1,m,n,k,1,a,k,b,k,1,c,n);
 }
 
-void update_convolutional_layer(convolutional_layer layer, double step, double momentum, double decay)
+void backward_convolutional_layer(convolutional_layer layer, float *delta)
+{
+    int m = layer.size*layer.size*layer.c;
+    int k = layer.n;
+    int n = ((layer.h-layer.size)/layer.stride + 1)*
+            ((layer.w-layer.size)/layer.stride + 1);
+
+    float *a = layer.filters;
+    float *b = layer.delta;
+    float *c = layer.col_image;
+
+
+    memset(c, 0, m*n*sizeof(float));
+    gemm(1,0,m,n,k,1,a,m,b,n,1,c,n);
+
+    memset(delta, 0, layer.h*layer.w*layer.c*sizeof(float));
+    col2im_cpu(c,  layer.c,  layer.h,  layer.w,  layer.size,  layer.stride, delta);
+}
+
+void update_convolutional_layer(convolutional_layer layer, float step, float momentum, float decay)
 {
     int i;
     int size = layer.size*layer.size*layer.c*layer.n;
@@ -133,9 +152,9 @@
 }
 /*
 
-void backward_convolutional_layer2(convolutional_layer layer, double *input, double *delta)
+void backward_convolutional_layer2(convolutional_layer layer, float *input, float *delta)
 {
-    image in_delta = double_to_image(layer.h, layer.w, layer.c, delta);
+    image in_delta = float_to_image(layer.h, layer.w, layer.c, delta);
     image out_delta = get_convolutional_delta(layer);
     int i,j;
     for(i = 0; i < layer.n; ++i){
@@ -156,10 +175,10 @@
 }
 
 
-void learn_convolutional_layer(convolutional_layer layer, double *input)
+void learn_convolutional_layer(convolutional_layer layer, float *input)
 {
     int i;
-    image in_image = double_to_image(layer.h, layer.w, layer.c, input);
+    image in_image = float_to_image(layer.h, layer.w, layer.c, input);
     image out_delta = get_convolutional_delta(layer);
     gradient_delta_convolutional_layer(layer);
     for(i = 0; i < layer.n; ++i){
@@ -168,7 +187,7 @@
     }
 }
 
-void update_convolutional_layer(convolutional_layer layer, double step, double momentum, double decay)
+void update_convolutional_layer(convolutional_layer layer, float step, float momentum, float decay)
 {
     int i,j;
     for(i = 0; i < layer.n; ++i){
@@ -190,21 +209,28 @@
 void test_convolutional_layer()
 {
     convolutional_layer l = *make_convolutional_layer(4,4,1,1,3,1,LINEAR);
-    double input[] =    {1,2,3,4,
+    float input[] =    {1,2,3,4,
                         5,6,7,8,
                         9,10,11,12,
                         13,14,15,16};
-    double filter[] =   {.5, 0, .3,
+    float filter[] =   {.5, 0, .3,
                         0  , 1,  0,
                         .2 , 0,  1};
-    double delta[] =    {1, 2,
+    float delta[] =    {1, 2,
                         3,  4};
+    float in_delta[] = {.5,1,.3,.6,
+                        5,6,7,8,
+                        9,10,11,12,
+                        13,14,15,16};
     l.filters = filter;
     forward_convolutional_layer(l, input);
     l.delta = delta;
     learn_convolutional_layer(l);
-    image filter_updates = double_to_image(3,3,1,l.filter_updates);
+    image filter_updates = float_to_image(3,3,1,l.filter_updates);
     print_image(filter_updates);
+    printf("Delta:\n");
+    backward_convolutional_layer(l, in_delta);
+    pm(4,4,in_delta);
 }
 
 image get_convolutional_filter(convolutional_layer layer, int i)
@@ -212,7 +238,7 @@
     int h = layer.size;
     int w = layer.size;
     int c = layer.c;
-    return double_to_image(h,w,c,layer.filters+i*h*w*c);
+    return float_to_image(h,w,c,layer.filters+i*h*w*c);
 }
 
 void visualize_convolutional_layer(convolutional_layer layer, char *window)
diff --git a/src/convolutional_layer.h b/src/convolutional_layer.h
index e2e6cdc..c4de24e 100644
--- a/src/convolutional_layer.h
+++ b/src/convolutional_layer.h
@@ -10,28 +10,28 @@
     int n;
     int size;
     int stride;
-    double *filters;
-    double *filter_updates;
-    double *filter_momentum;
+    float *filters;
+    float *filter_updates;
+    float *filter_momentum;
 
-    double *biases;
-    double *bias_updates;
-    double *bias_momentum;
+    float *biases;
+    float *bias_updates;
+    float *bias_momentum;
 
-    double *col_image;
-    double *delta;
-    double *output;
+    float *col_image;
+    float *delta;
+    float *output;
 
     ACTIVATION activation;
 } convolutional_layer;
 
 convolutional_layer *make_convolutional_layer(int h, int w, int c, int n, int size, int stride, ACTIVATION activation);
-void forward_convolutional_layer(const convolutional_layer layer, double *in);
+void forward_convolutional_layer(const convolutional_layer layer, float *in);
 void learn_convolutional_layer(convolutional_layer layer);
-void update_convolutional_layer(convolutional_layer layer, double step, double momentum, double decay);
+void update_convolutional_layer(convolutional_layer layer, float step, float momentum, float decay);
 void visualize_convolutional_layer(convolutional_layer layer, char *window);
 
-//void backward_convolutional_layer(convolutional_layer layer, double *input, double *delta);
+void backward_convolutional_layer(convolutional_layer layer, float *delta);
 
 //void backpropagate_convolutional_layer_convolve(image input, convolutional_layer layer);
 //void visualize_convolutional_filters(convolutional_layer layer, char *window);
diff --git a/src/data.c b/src/data.c
index 0b396d7..2c5932b 100644
--- a/src/data.c
+++ b/src/data.c
@@ -19,10 +19,10 @@
     return lines;
 }
 
-void fill_truth(char *path, char **labels, int k, double *truth)
+void fill_truth(char *path, char **labels, int k, float *truth)
 {
     int i;
-    memset(truth, 0, k*sizeof(double));
+    memset(truth, 0, k*sizeof(float));
     for(i = 0; i < k; ++i){
         if(strstr(path, labels[i])){
             truth[i] = 1;
@@ -36,7 +36,7 @@
     data d;
     d.shallow = 0;
     d.X.rows = n;
-    d.X.vals = calloc(d.X.rows, sizeof(double*));
+    d.X.vals = calloc(d.X.rows, sizeof(float*));
     d.y = make_matrix(n, k);
 
     for(i = 0; i < n; ++i){
@@ -106,8 +106,8 @@
     data d;
     d.shallow = 0;
     matrix X = csv_to_matrix(filename);
-    double *truth_1d = pop_column(&X, target);
-    double **truth = one_hot_encode(truth_1d, X.rows, k);
+    float *truth_1d = pop_column(&X, target);
+    float **truth = one_hot_encode(truth_1d, X.rows, k);
     matrix y;
     y.rows = X.rows;
     y.cols = k;
@@ -123,7 +123,7 @@
     int i;
     for(i = d.X.rows-1; i > 0; --i){
         int index = rand()%i;
-        double *swap = d.X.vals[index];
+        float *swap = d.X.vals[index];
         d.X.vals[index] = d.X.vals[i];
         d.X.vals[i] = swap;
 
@@ -156,10 +156,10 @@
     train.X.cols = test.X.cols = d.X.cols;
     train.y.cols = test.y.cols = d.y.cols;
 
-    train.X.vals = calloc(train.X.rows, sizeof(double*));
-    test.X.vals = calloc(test.X.rows, sizeof(double*));
-    train.y.vals = calloc(train.y.rows, sizeof(double*));
-    test.y.vals = calloc(test.y.rows, sizeof(double*));
+    train.X.vals = calloc(train.X.rows, sizeof(float*));
+    test.X.vals = calloc(test.X.rows, sizeof(float*));
+    train.y.vals = calloc(train.y.rows, sizeof(float*));
+    test.y.vals = calloc(test.y.rows, sizeof(float*));
 
     for(i = 0; i < start; ++i){
         train.X.vals[i] = d.X.vals[i];
diff --git a/src/image.c b/src/image.c
index df8e1b8..62ee5f7 100644
--- a/src/image.c
+++ b/src/image.c
@@ -16,7 +16,7 @@
     for(k = 0; k < source.c; ++k){
         for(i = 0; i < source.h; ++i){
             for(j = 0; j < source.w; ++j){
-                double val = get_pixel(source, i,j,k);
+                float val = get_pixel(source, i,j,k);
                 set_pixel(dest, h+i, w+j, k, val);
             }
         }
@@ -45,14 +45,14 @@
 
 void normalize_image(image p)
 {
-    double *min = calloc(p.c, sizeof(double));
-    double *max = calloc(p.c, sizeof(double));
+    float *min = calloc(p.c, sizeof(float));
+    float *max = calloc(p.c, sizeof(float));
     int i,j;
     for(i = 0; i < p.c; ++i) min[i] = max[i] = p.data[i*p.h*p.w];
 
     for(j = 0; j < p.c; ++j){
         for(i = 0; i < p.h*p.w; ++i){
-            double v = p.data[i+j*p.h*p.w];
+            float v = p.data[i+j*p.h*p.w];
             if(v < min[j]) min[j] = v;
             if(v > max[j]) max[j] = v;
         }
@@ -72,17 +72,17 @@
     free(max);
 }
 
-double avg_image_layer(image m, int l)
+float avg_image_layer(image m, int l)
 {
     int i;
-    double sum = 0;
+    float sum = 0;
     for(i = 0; i < m.h*m.w; ++i){
         sum += m.data[l*m.h*m.w + i];
     }
     return sum/(m.h*m.w);
 }
 
-void threshold_image(image p, double t)
+void threshold_image(image p, float t)
 {
     int i;
     for(i = 0; i < p.w*p.h*p.c; ++i){
@@ -93,8 +93,8 @@
 image copy_image(image p)
 {
     image copy = p;
-    copy.data = calloc(p.h*p.w*p.c, sizeof(double));
-    memcpy(copy.data, p.data, p.h*p.w*p.c*sizeof(double));
+    copy.data = calloc(p.h*p.w*p.c, sizeof(float));
+    memcpy(copy.data, p.data, p.h*p.w*p.c*sizeof(float));
     return copy;
 }
 
@@ -168,11 +168,11 @@
 image make_image(int h, int w, int c)
 {
     image out = make_empty_image(h,w,c);
-    out.data = calloc(h*w*c, sizeof(double));
+    out.data = calloc(h*w*c, sizeof(float));
     return out;
 }
 
-image double_to_image(int h, int w, int c, double *data)
+image float_to_image(int h, int w, int c, float *data)
 {
     image out = make_empty_image(h,w,c);
     out.data = data;
@@ -181,12 +181,12 @@
 
 void zero_image(image m)
 {
-    memset(m.data, 0, m.h*m.w*m.c*sizeof(double));
+    memset(m.data, 0, m.h*m.w*m.c*sizeof(float));
 }
 
 void zero_channel(image m, int c)
 {
-    memset(&(m.data[c*m.h*m.w]), 0, m.h*m.w*sizeof(double));
+    memset(&(m.data[c*m.h*m.w]), 0, m.h*m.w*sizeof(float));
 }
 
 void rotate_image(image m)
@@ -194,7 +194,7 @@
     int i,j;
     for(j = 0; j < m.c; ++j){
         for(i = 0; i < m.h*m.w/2; ++i){
-            double swap = m.data[j*m.h*m.w + i];
+            float swap = m.data[j*m.h*m.w + i];
             m.data[j*m.h*m.w + i] = m.data[j*m.h*m.w + (m.h*m.w-1 - i)];
             m.data[j*m.h*m.w + (m.h*m.w-1 - i)] = swap;
         }
@@ -212,19 +212,19 @@
     return out;
 }
 
-void add_scalar_image(image m, double s)
+void add_scalar_image(image m, float s)
 {
     int i;
     for(i = 0; i < m.h*m.w*m.c; ++i) m.data[i] += s;
 }
 
-void scale_image(image m, double s)
+void scale_image(image m, float s)
 {
     int i;
     for(i = 0; i < m.h*m.w*m.c; ++i) m.data[i] *= s;
 }
 
-image make_random_kernel(int size, int c, double scale)
+image make_random_kernel(int size, int c, float scale)
 {
     int pad;
     if((pad=(size%2==0))) ++size;
@@ -280,34 +280,34 @@
     return out;
 }
 
-double get_pixel(image m, int x, int y, int c)
+float get_pixel(image m, int x, int y, int c)
 {
     assert(x < m.h && y < m.w && c < m.c);
     return m.data[c*m.h*m.w + x*m.w + y];
 }
-double get_pixel_extend(image m, int x, int y, int c)
+float get_pixel_extend(image m, int x, int y, int c)
 {
     if(x < 0 || x >= m.h || y < 0 || y >= m.w || c < 0 || c >= m.c) return 0;
     return get_pixel(m, x, y, c);
 }
-void set_pixel(image m, int x, int y, int c, double val)
+void set_pixel(image m, int x, int y, int c, float val)
 {
     assert(x < m.h && y < m.w && c < m.c);
     m.data[c*m.h*m.w + x*m.w + y] = val;
 }
-void set_pixel_extend(image m, int x, int y, int c, double val)
+void set_pixel_extend(image m, int x, int y, int c, float val)
 {
     if(x < 0 || x >= m.h || y < 0 || y >= m.w || c < 0 || c >= m.c) return;
     set_pixel(m, x, y, c, val);
 }
 
-void add_pixel(image m, int x, int y, int c, double val)
+void add_pixel(image m, int x, int y, int c, float val)
 {
     assert(x < m.h && y < m.w && c < m.c);
     m.data[c*m.h*m.w + x*m.w + y] += val;
 }
 
-void add_pixel_extend(image m, int x, int y, int c, double val)
+void add_pixel_extend(image m, int x, int y, int c, float val)
 {
     if(x < 0 || x >= m.h || y < 0 || y >= m.w || c < 0 || c >= m.c) return;
     add_pixel(m, x, y, c, val);
@@ -329,7 +329,7 @@
     }
     for(x = xstart; x < xend; x += stride){
         for(y = ystart; y < yend; y += stride){
-            double sum = 0;
+            float sum = 0;
             for(i = 0; i < kernel.h; ++i){
                 for(j = 0; j < kernel.w; ++j){
                     sum += get_pixel(kernel, i, j, kc)*get_pixel_extend(m, x+i-kernel.h/2, y+j-kernel.w/2, mc);
@@ -340,9 +340,9 @@
     }
 }
 
-double single_convolve(image m, image kernel, int x, int y)
+float single_convolve(image m, image kernel, int x, int y)
 {
-    double sum = 0;
+    float sum = 0;
     int i, j, k;
     for(i = 0; i < kernel.h; ++i){
         for(j = 0; j < kernel.w; ++j){
@@ -366,7 +366,7 @@
     int j;
     for(i = 0; i < m.h; i += stride){
         for(j = 0; j < m.w; j += stride){
-            double val = single_convolve(m, kernel, i, j);
+            float val = single_convolve(m, kernel, i, j);
             set_pixel(out, i/stride, j/stride, channel, val);
         }
     }
@@ -380,20 +380,20 @@
     for(k = 0; k < m.c; ++k){
         for(i = 0; i < m.h; ++i){
             for(j = 0; j< m.w; ++j){
-                double val = get_pixel(m, i, j, k);
+                float val = get_pixel(m, i, j, k);
                 set_pixel(out, i*stride, j*stride, k, val);
             }
         }
     }
 }
 
-void single_update(image m, image update, int x, int y, double error)
+void single_update(image m, image update, int x, int y, float error)
 {
     int i, j, k;
     for(i = 0; i < update.h; ++i){
         for(j = 0; j < update.w; ++j){
             for(k = 0; k < update.c; ++k){
-                double val = get_pixel_extend(m, x+i-update.h/2, y+j-update.w/2, k);
+                float val = get_pixel_extend(m, x+i-update.h/2, y+j-update.w/2, k);
                 add_pixel(update, i, j, k, val*error);
             }
         }
@@ -417,7 +417,7 @@
     }
     for(i = istart; i < iend; i += stride){
         for(j = jstart; j < jend; j += stride){
-            double error = get_pixel(out, (i-istart)/stride, (j-jstart)/stride, channel);
+            float error = get_pixel(out, (i-istart)/stride, (j-jstart)/stride, channel);
             single_update(m, update, i, j, error);
         }
     }
@@ -428,13 +428,13 @@
     */
 }
 
-void single_back_convolve(image m, image kernel, int x, int y, double val)
+void single_back_convolve(image m, image kernel, int x, int y, float val)
 {
     int i, j, k;
     for(i = 0; i < kernel.h; ++i){
         for(j = 0; j < kernel.w; ++j){
             for(k = 0; k < kernel.c; ++k){
-                double pval = get_pixel(kernel, i, j, k) * val;
+                float pval = get_pixel(kernel, i, j, k) * val;
                 add_pixel_extend(m, x+i-kernel.h/2, y+j-kernel.w/2, k, pval);
             }
         }
@@ -457,7 +457,7 @@
     }
     for(i = istart; i < iend; i += stride){
         for(j = jstart; j < jend; j += stride){
-            double val = get_pixel(out, (i-istart)/stride, (j-jstart)/stride, channel);
+            float val = get_pixel(out, (i-istart)/stride, (j-jstart)/stride, channel);
             single_back_convolve(m, kernel, i, j, val);
         }
     }
diff --git a/src/image.h b/src/image.h
index 1865857..72c4b2c 100644
--- a/src/image.h
+++ b/src/image.h
@@ -7,18 +7,18 @@
     int h;
     int w;
     int c;
-    double *data;
+    float *data;
 } image;
 
-void scale_image(image m, double s);
-void add_scalar_image(image m, double s);
+void scale_image(image m, float s);
+void add_scalar_image(image m, float s);
 void normalize_image(image p);
 void z_normalize_image(image p);
-void threshold_image(image p, double t);
+void threshold_image(image p, float t);
 void zero_image(image m);
 void rotate_image(image m);
 void subtract_image(image a, image b);
-double avg_image_layer(image m, int l);
+float avg_image_layer(image m, int l);
 void embed_image(image source, image dest, int h, int w);
 image collapse_image_layers(image source, int border);
 
@@ -30,14 +30,14 @@
 image make_image(int h, int w, int c);
 image make_empty_image(int h, int w, int c);
 image make_random_image(int h, int w, int c);
-image make_random_kernel(int size, int c, double scale);
-image double_to_image(int h, int w, int c, double *data);
+image make_random_kernel(int size, int c, float scale);
+image float_to_image(int h, int w, int c, float *data);
 image copy_image(image p);
 image load_image(char *filename);
 
-double get_pixel(image m, int x, int y, int c);
-double get_pixel_extend(image m, int x, int y, int c);
-void set_pixel(image m, int x, int y, int c, double val);
+float get_pixel(image m, int x, int y, int c);
+float get_pixel_extend(image m, int x, int y, int c);
+void set_pixel(image m, int x, int y, int c, float val);
 
 
 image get_image_layer(image m, int l);
diff --git a/src/matrix.c b/src/matrix.c
index 68e6f8d..96bd332 100644
--- a/src/matrix.c
+++ b/src/matrix.c
@@ -13,7 +13,7 @@
     free(m.vals);
 }
 
-double matrix_accuracy(matrix truth, matrix guess)
+float matrix_accuracy(matrix truth, matrix guess)
 {
     int k = truth.cols;
     int i;
@@ -22,7 +22,7 @@
         int class = max_index(guess.vals[i], k);
         if(truth.vals[i][class]) ++count;
     }
-    return (double)count/truth.rows;
+    return (float)count/truth.rows;
 }
 
 void matrix_add_matrix(matrix from, matrix to)
@@ -42,9 +42,9 @@
     matrix m;
     m.rows = rows;
     m.cols = cols;
-    m.vals = calloc(m.rows, sizeof(double *));
+    m.vals = calloc(m.rows, sizeof(float *));
     for(i = 0; i < m.rows; ++i){
-        m.vals[i] = calloc(m.cols, sizeof(double));
+        m.vals[i] = calloc(m.cols, sizeof(float));
     }
     return m;
 }
@@ -55,7 +55,7 @@
     matrix h;
     h.rows = n;
     h.cols = m->cols;
-    h.vals = calloc(h.rows, sizeof(double *));
+    h.vals = calloc(h.rows, sizeof(float *));
     for(i = 0; i < n; ++i){
         int index = rand()%m->rows;
         h.vals[i] = m->vals[index];
@@ -64,9 +64,9 @@
     return h;
 }
 
-double *pop_column(matrix *m, int c)
+float *pop_column(matrix *m, int c)
 {
-    double *col = calloc(m->rows, sizeof(double));
+    float *col = calloc(m->rows, sizeof(float));
     int i, j;
     for(i = 0; i < m->rows; ++i){
         col[i] = m->vals[i][c];
@@ -90,18 +90,18 @@
 
 	int n = 0;
 	int size = 1024;
-	m.vals = calloc(size, sizeof(double*));
+	m.vals = calloc(size, sizeof(float*));
 	while((line = fgetl(fp))){
         if(m.cols == -1) m.cols = count_fields(line);
 		if(n == size){
 			size *= 2;
-			m.vals = realloc(m.vals, size*sizeof(double*));
+			m.vals = realloc(m.vals, size*sizeof(float*));
 		}
 		m.vals[n] = parse_fields(line, m.cols);
 		free(line);
 		++n;
 	}
-	m.vals = realloc(m.vals, n*sizeof(double*));
+	m.vals = realloc(m.vals, n*sizeof(float*));
     m.rows = n;
 	return m;
 }
diff --git a/src/matrix.h b/src/matrix.h
index 098eb9e..01d825d 100644
--- a/src/matrix.h
+++ b/src/matrix.h
@@ -2,7 +2,7 @@
 #define MATRIX_H
 typedef struct matrix{
     int rows, cols;
-    double **vals;
+    float **vals;
 } matrix;
 
 matrix make_matrix(int rows, int cols);
@@ -11,9 +11,9 @@
 
 matrix csv_to_matrix(char *filename);
 matrix hold_out_matrix(matrix *m, int n);
-double matrix_accuracy(matrix truth, matrix guess);
+float matrix_accuracy(matrix truth, matrix guess);
 void matrix_add_matrix(matrix from, matrix to);
 
-double *pop_column(matrix *m, int c);
+float *pop_column(matrix *m, int c);
 
 #endif
diff --git a/src/maxpool_layer.c b/src/maxpool_layer.c
index ccf9bee..8c409b9 100644
--- a/src/maxpool_layer.c
+++ b/src/maxpool_layer.c
@@ -6,7 +6,7 @@
     int h = (layer.h-1)/layer.stride + 1;
     int w = (layer.w-1)/layer.stride + 1;
     int c = layer.c;
-    return double_to_image(h,w,c,layer.output);
+    return float_to_image(h,w,c,layer.output);
 }
 
 image get_maxpool_delta(maxpool_layer layer)
@@ -14,7 +14,7 @@
     int h = (layer.h-1)/layer.stride + 1;
     int w = (layer.w-1)/layer.stride + 1;
     int c = layer.c;
-    return double_to_image(h,w,c,layer.delta);
+    return float_to_image(h,w,c,layer.delta);
 }
 
 maxpool_layer *make_maxpool_layer(int h, int w, int c, int stride)
@@ -25,41 +25,41 @@
     layer->w = w;
     layer->c = c;
     layer->stride = stride;
-    layer->output = calloc(((h-1)/stride+1) * ((w-1)/stride+1) * c, sizeof(double));
-    layer->delta = calloc(((h-1)/stride+1) * ((w-1)/stride+1) * c, sizeof(double));
+    layer->output = calloc(((h-1)/stride+1) * ((w-1)/stride+1) * c, sizeof(float));
+    layer->delta = calloc(((h-1)/stride+1) * ((w-1)/stride+1) * c, sizeof(float));
     return layer;
 }
 
-void forward_maxpool_layer(const maxpool_layer layer, double *in)
+void forward_maxpool_layer(const maxpool_layer layer, float *in)
 {
-    image input = double_to_image(layer.h, layer.w, layer.c, in);
+    image input = float_to_image(layer.h, layer.w, layer.c, in);
     image output = get_maxpool_image(layer);
     int i,j,k;
     for(i = 0; i < output.h*output.w*output.c; ++i) output.data[i] = -DBL_MAX;
     for(k = 0; k < input.c; ++k){
         for(i = 0; i < input.h; ++i){
             for(j = 0; j < input.w; ++j){
-                double val = get_pixel(input, i, j, k);
-                double cur = get_pixel(output, i/layer.stride, j/layer.stride, k);
+                float val = get_pixel(input, i, j, k);
+                float cur = get_pixel(output, i/layer.stride, j/layer.stride, k);
                 if(val > cur) set_pixel(output, i/layer.stride, j/layer.stride, k, val);
             }
         }
     }
 }
 
-void backward_maxpool_layer(const maxpool_layer layer, double *in, double *delta)
+void backward_maxpool_layer(const maxpool_layer layer, float *in, float *delta)
 {
-    image input = double_to_image(layer.h, layer.w, layer.c, in);
-    image input_delta = double_to_image(layer.h, layer.w, layer.c, delta);
+    image input = float_to_image(layer.h, layer.w, layer.c, in);
+    image input_delta = float_to_image(layer.h, layer.w, layer.c, delta);
     image output_delta = get_maxpool_delta(layer);
     image output = get_maxpool_image(layer);
     int i,j,k;
     for(k = 0; k < input.c; ++k){
         for(i = 0; i < input.h; ++i){
             for(j = 0; j < input.w; ++j){
-                double val = get_pixel(input, i, j, k);
-                double cur = get_pixel(output, i/layer.stride, j/layer.stride, k);
-                double d = get_pixel(output_delta, i/layer.stride, j/layer.stride, k);
+                float val = get_pixel(input, i, j, k);
+                float cur = get_pixel(output, i/layer.stride, j/layer.stride, k);
+                float d = get_pixel(output_delta, i/layer.stride, j/layer.stride, k);
                 if(val == cur) {
                     set_pixel(input_delta, i, j, k, d);
                 }
diff --git a/src/maxpool_layer.h b/src/maxpool_layer.h
index 0afe68a..27d6f55 100644
--- a/src/maxpool_layer.h
+++ b/src/maxpool_layer.h
@@ -6,14 +6,14 @@
 typedef struct {
     int h,w,c;
     int stride;
-    double *delta;
-    double *output;
+    float *delta;
+    float *output;
 } maxpool_layer;
 
 image get_maxpool_image(maxpool_layer layer);
 maxpool_layer *make_maxpool_layer(int h, int w, int c, int stride);
-void forward_maxpool_layer(const maxpool_layer layer, double *in);
-void backward_maxpool_layer(const maxpool_layer layer, double *in, double *delta);
+void forward_maxpool_layer(const maxpool_layer layer, float *in);
+void backward_maxpool_layer(const maxpool_layer layer, float *in, float *delta);
 
 #endif
 
diff --git a/src/mini_blas.c b/src/mini_blas.c
index 3af36e5..b9a4304 100644
--- a/src/mini_blas.c
+++ b/src/mini_blas.c
@@ -1,8 +1,10 @@
 
 #include <stdlib.h>
+#include <stdio.h>
 #include <math.h>
+#include <time.h>
 
-void pm(int M, int N, double *A)
+void pm(int M, int N, float *A)
 {
     int i,j;
     for(i =0 ; i < M; ++i){
@@ -14,28 +16,37 @@
     printf("\n");
 }
 
-void gemm(int TA, int TB, int M, int N, int K, double ALPHA, 
-                    double *A, int lda, 
-                    double *B, int ldb,
-                    double BETA,
-                    double *C, int ldc)
+void gemm(int TA, int TB, int M, int N, int K, float ALPHA, 
+                    float *A, int lda, 
+                    float *B, int ldb,
+                    float BETA,
+                    float *C, int ldc)
 {
-    // Assume TA = 0, beta = 1 LULZ
+    // Assume beta = 1 LULZ
     int i,j,k;
     if(TB && !TA){
         for(i = 0; i < M; ++i){
             for(j = 0; j < N; ++j){
-                register double sum = 0;
+                register float sum = 0;
                 for(k = 0; k < K; ++k){
                     sum += ALPHA*A[i*lda+k]*B[k+j*ldb];
                 }
                 C[i*ldc+j] += sum;
             }
         }
+    }else if(TA && !TB){
+        for(i = 0; i < M; ++i){
+            for(k = 0; k < K; ++k){
+                register float A_PART = ALPHA*A[k*lda+i];
+                for(j = 0; j < N; ++j){
+                    C[i*ldc+j] += A_PART*B[k*ldb+j];
+                }
+            }
+        }
     }else{
         for(i = 0; i < M; ++i){
             for(k = 0; k < K; ++k){
-                register double A_PART = ALPHA*A[i*lda+k];
+                register float A_PART = ALPHA*A[i*lda+k];
                 for(j = 0; j < N; ++j){
                     C[i*ldc+j] += A_PART*B[k*ldb+j];
                 }
@@ -44,7 +55,7 @@
     }
 }
 
-void im2row(double *image, int h, int w, int c, int size, int stride, double *matrix)
+void im2row(float *image, int h, int w, int c, int size, int stride, float *matrix)
 {
     int i;
     int mc = c;
@@ -64,7 +75,7 @@
         matrix[i] = image[pc*h*w+ph*w+pw];
     }
 }
-void im2col(double *image, int h, int w, int c, int size, int stride, double *matrix)
+void im2col(float *image, int h, int w, int c, int size, int stride, float *matrix)
 {
     int b,p;
     int blocks = ((h-size)/stride+1)*((w-size)/stride+1);
@@ -84,9 +95,9 @@
 }
 
 //From Berkeley Vision's Caffe!
-void im2col_cpu(double* data_im, const int channels,
+void im2col_cpu(float* data_im, const int channels,
         const int height, const int width, const int ksize, const int stride,
-        double* data_col) 
+        float* data_col) 
 {
     int c,h,w;
     int height_col = (height - ksize) / stride + 1;
@@ -106,3 +117,59 @@
     }
 }
 
+void col2im_cpu(float* data_col, const int channels,
+        const int height, const int width, const int ksize, const int stride,
+        float* data_im) 
+{
+    int c,h,w;
+    int height_col = (height - ksize) / stride + 1;
+    int width_col = (width - ksize) / stride + 1;
+    int channels_col = channels * ksize * ksize;
+    for ( c = 0; c < channels_col; ++c) {
+        int w_offset = c % ksize;
+        int h_offset = (c / ksize) % ksize;
+        int c_im = c / ksize / ksize;
+        for ( h = 0; h < height_col; ++h) {
+            for ( w = 0; w < width_col; ++w) {
+                data_im[(c_im * height + h * stride + h_offset) * width
+                    + w * stride + w_offset]+= data_col[(c * height_col + h) * width_col + w];
+            }
+        }
+    }
+}
+
+float *random_matrix(int rows, int cols)
+{
+    int i;
+    float *m = calloc(rows*cols, sizeof(float));
+    for(i = 0; i < rows*cols; ++i){
+        m[i] = (float)rand()/RAND_MAX;
+    }
+    return m;
+}
+
+void time_random_matrix(int TA, int TB, int m, int k, int n)
+{
+    float *a = random_matrix(m,k);
+    float *b = random_matrix(k,n);
+    float *c = random_matrix(m,n);
+    int i;
+    clock_t start = clock(), end;
+    for(i = 0; i<1000; ++i){
+        gemm(TA,TB,m,n,k,1,a,k,b,n,1,c,n);
+    }
+    end = clock();
+    printf("Matrix Multiplication %dx%d * %dx%d, TA=%d, TB=%d: %lf ms\n",m,k,k,n, TA, TB, (double)(end-start)/CLOCKS_PER_SEC);
+}
+
+void test_blas()
+{
+    time_random_matrix(0,0,100,100,100); 
+    time_random_matrix(1,0,100,100,100); 
+    time_random_matrix(0,1,100,100,100); 
+
+    time_random_matrix(0,1,1000,100,100); 
+    time_random_matrix(1,0,1000,100,100); 
+
+}
+
diff --git a/src/mini_blas.h b/src/mini_blas.h
index 46a37d3..ff82a60 100644
--- a/src/mini_blas.h
+++ b/src/mini_blas.h
@@ -1,11 +1,15 @@
-void pm(int M, int N, double *A);
-void gemm(int TA, int TB, int M, int N, int K, double ALPHA, 
-                    double *A, int lda, 
-                    double *B, int ldb,
-                    double BETA,
-                    double *C, int ldc);
-void im2row(double *image, int h, int w, int c, int size, int stride, double *matrix);
-void im2col(double *image, int h, int w, int c, int size, int stride, double *matrix);
-void im2col_cpu(double* data_im, const int channels,
+void pm(int M, int N, float *A);
+void gemm(int TA, int TB, int M, int N, int K, float ALPHA, 
+                    float *A, int lda, 
+                    float *B, int ldb,
+                    float BETA,
+                    float *C, int ldc);
+void im2row(float *image, int h, int w, int c, int size, int stride, float *matrix);
+void im2col(float *image, int h, int w, int c, int size, int stride, float *matrix);
+void im2col_cpu(float* data_im, const int channels,
         const int height, const int width, const int ksize, const int stride,
-        double* data_col);
+        float* data_col);
+void col2im_cpu(float* data_col, const int channels,
+        const int height, const int width, const int ksize, const int stride,
+        float* data_im);
+void test_blas();
diff --git a/src/network.c b/src/network.c
index 2ce13d8..29e22e4 100644
--- a/src/network.c
+++ b/src/network.c
@@ -21,7 +21,7 @@
     return net;
 }
 
-void forward_network(network net, double *input)
+void forward_network(network net, float *input)
 {
     int i;
     for(i = 0; i < net.n; ++i){
@@ -48,7 +48,7 @@
     }
 }
 
-void update_network(network net, double step, double momentum, double decay)
+void update_network(network net, float step, float momentum, float decay)
 {
     int i;
     for(i = 0; i < net.n; ++i){
@@ -69,7 +69,7 @@
     }
 }
 
-double *get_network_output_layer(network net, int i)
+float *get_network_output_layer(network net, int i)
 {
     if(net.types[i] == CONVOLUTIONAL){
         convolutional_layer layer = *(convolutional_layer *)net.layers[i];
@@ -86,12 +86,12 @@
     }
     return 0;
 }
-double *get_network_output(network net)
+float *get_network_output(network net)
 {
     return get_network_output_layer(net, net.n-1);
 }
 
-double *get_network_delta_layer(network net, int i)
+float *get_network_delta_layer(network net, int i)
 {
     if(net.types[i] == CONVOLUTIONAL){
         convolutional_layer layer = *(convolutional_layer *)net.layers[i];
@@ -109,16 +109,16 @@
     return 0;
 }
 
-double *get_network_delta(network net)
+float *get_network_delta(network net)
 {
     return get_network_delta_layer(net, net.n-1);
 }
 
-double calculate_error_network(network net, double *truth)
+float calculate_error_network(network net, float *truth)
 {
-    double sum = 0;
-    double *delta = get_network_delta(net);
-    double *out = get_network_output(net);
+    float sum = 0;
+    float *delta = get_network_delta(net);
+    float *out = get_network_output(net);
     int i, k = get_network_output_size(net);
     for(i = 0; i < k; ++i){
         delta[i] = truth[i] - out[i];
@@ -129,17 +129,17 @@
 
 int get_predicted_class_network(network net)
 {
-    double *out = get_network_output(net);
+    float *out = get_network_output(net);
     int k = get_network_output_size(net);
     return max_index(out, k);
 }
 
-double backward_network(network net, double *input, double *truth)
+float backward_network(network net, float *input, float *truth)
 {
-    double error = calculate_error_network(net, truth);
+    float error = calculate_error_network(net, truth);
     int i;
-    double *prev_input;
-    double *prev_delta;
+    float *prev_input;
+    float *prev_delta;
     for(i = net.n-1; i >= 0; --i){
         if(i == 0){
             prev_input = input;
@@ -152,7 +152,7 @@
             convolutional_layer layer = *(convolutional_layer *)net.layers[i];
             learn_convolutional_layer(layer);
             //learn_convolutional_layer(layer);
-            //if(i != 0) backward_convolutional_layer(layer, prev_input, prev_delta);
+            if(i != 0) backward_convolutional_layer(layer, prev_delta);
         }
         else if(net.types[i] == MAXPOOL){
             maxpool_layer layer = *(maxpool_layer *)net.layers[i];
@@ -171,49 +171,49 @@
     return error;
 }
 
-double train_network_datum(network net, double *x, double *y, double step, double momentum, double decay)
+float train_network_datum(network net, float *x, float *y, float step, float momentum, float decay)
 {
         forward_network(net, x);
         int class = get_predicted_class_network(net);
-        double error = backward_network(net, x, y);
+        float error = backward_network(net, x, y);
         update_network(net, step, momentum, decay);
         //return (y[class]?1:0);
         return error;
 }
 
-double train_network_sgd(network net, data d, int n, double step, double momentum,double decay)
+float train_network_sgd(network net, data d, int n, float step, float momentum,float decay)
 {
     int i;
-    double error = 0;
+    float error = 0;
     for(i = 0; i < n; ++i){
         int index = rand()%d.X.rows;
         error += train_network_datum(net, d.X.vals[index], d.y.vals[index], step, momentum, decay);
         //if((i+1)%10 == 0){
-        //    printf("%d: %f\n", (i+1), (double)correct/(i+1));
+        //    printf("%d: %f\n", (i+1), (float)correct/(i+1));
         //}
     }
     return error/n;
 }
-double train_network_batch(network net, data d, int n, double step, double momentum,double decay)
+float train_network_batch(network net, data d, int n, float step, float momentum,float decay)
 {
     int i;
     int correct = 0;
     for(i = 0; i < n; ++i){
         int index = rand()%d.X.rows;
-        double *x = d.X.vals[index];
-        double *y = d.y.vals[index];
+        float *x = d.X.vals[index];
+        float *y = d.y.vals[index];
         forward_network(net, x);
         int class = get_predicted_class_network(net);
         backward_network(net, x, y);
         correct += (y[class]?1:0);
     }
     update_network(net, step, momentum, decay);
-    return (double)correct/n;
+    return (float)correct/n;
 
 }
 
 
-void train_network(network net, data d, double step, double momentum, double decay)
+void train_network(network net, data d, float step, float momentum, float decay)
 {
     int i;
     int correct = 0;
@@ -226,7 +226,7 @@
     }
     visualize_network(net);
     cvWaitKey(100);
-    printf("Accuracy: %f\n", (double)correct/d.X.rows);
+    printf("Accuracy: %f\n", (float)correct/d.X.rows);
 }
 
 int get_network_output_size_layer(network net, int i)
@@ -294,10 +294,10 @@
     } 
 }
 
-double *network_predict(network net, double *input)
+float *network_predict(network net, float *input)
 {
     forward_network(net, input);
-    double *out = get_network_output(net);
+    float *out = get_network_output(net);
     return out;
 }
 
@@ -307,7 +307,7 @@
     int k = get_network_output_size(net);
     matrix pred = make_matrix(test.X.rows, k);
     for(i = 0; i < test.X.rows; ++i){
-        double *out = network_predict(net, test.X.vals[i]);
+        float *out = network_predict(net, test.X.vals[i]);
         for(j = 0; j < k; ++j){
             pred.vals[i][j] = out[j];
         }
@@ -319,7 +319,7 @@
 {
     int i,j;
     for(i = 0; i < net.n; ++i){
-        double *output = 0;
+        float *output = 0;
         int n = 0;
         if(net.types[i] == CONVOLUTIONAL){
             convolutional_layer layer = *(convolutional_layer *)net.layers[i];
@@ -343,8 +343,8 @@
             output = layer.output;
             n = layer.inputs;
         }
-        double mean = mean_array(output, n);
-        double vari = variance_array(output, n);
+        float mean = mean_array(output, n);
+        float vari = variance_array(output, n);
         fprintf(stderr, "Layer %d - Mean: %f, Variance: %f\n",i,mean, vari);
         if(n > 100) n = 100;
         for(j = 0; j < n; ++j) fprintf(stderr, "%f, ", output[j]);
@@ -353,10 +353,10 @@
     }
 }
 
-double network_accuracy(network net, data d)
+float network_accuracy(network net, data d)
 {
     matrix guess = network_predict_data(net, d);
-    double acc = matrix_accuracy(d.y, guess);
+    float acc = matrix_accuracy(d.y, guess);
     free_matrix(guess);
     return acc;
 }
diff --git a/src/network.h b/src/network.h
index fa109dd..17cc10b 100644
--- a/src/network.h
+++ b/src/network.h
@@ -17,22 +17,22 @@
     void **layers;
     LAYER_TYPE *types;
     int outputs;
-    double *output;
+    float *output;
 } network;
 
 network make_network(int n);
-void forward_network(network net, double *input);
-double backward_network(network net, double *input, double *truth);
-void update_network(network net, double step, double momentum, double decay);
-double train_network_sgd(network net, data d, int n, double step, double momentum,double decay);
-double train_network_batch(network net, data d, int n, double step, double momentum,double decay);
-void train_network(network net, data d, double step, double momentum, double decay);
+void forward_network(network net, float *input);
+float backward_network(network net, float *input, float *truth);
+void update_network(network net, float step, float momentum, float decay);
+float train_network_sgd(network net, data d, int n, float step, float momentum,float decay);
+float train_network_batch(network net, data d, int n, float step, float momentum,float decay);
+void train_network(network net, data d, float step, float momentum, float decay);
 matrix network_predict_data(network net, data test);
-double network_accuracy(network net, data d);
-double *get_network_output(network net);
-double *get_network_output_layer(network net, int i);
-double *get_network_delta_layer(network net, int i);
-double *get_network_delta(network net);
+float network_accuracy(network net, data d);
+float *get_network_output(network net);
+float *get_network_output_layer(network net, int i);
+float *get_network_delta_layer(network net, int i);
+float *get_network_delta(network net);
 int get_network_output_size_layer(network net, int i);
 int get_network_output_size(network net);
 image get_network_image(network net);
diff --git a/src/option_list.c b/src/option_list.c
index 1b32ebb..7902cd9 100644
--- a/src/option_list.c
+++ b/src/option_list.c
@@ -59,7 +59,7 @@
     return def;
 }
 
-double option_find_double(list *l, char *key, double def)
+float option_find_float(list *l, char *key, float def)
 {
     char *v = option_find(l, key);
     if(v) return atof(v);
diff --git a/src/option_list.h b/src/option_list.h
index 0270465..60e37fe 100644
--- a/src/option_list.h
+++ b/src/option_list.h
@@ -6,7 +6,7 @@
 char *option_find(list *l, char *key);
 char *option_find_str(list *l, char *key, char *def);
 int option_find_int(list *l, char *key, int def);
-double option_find_double(list *l, char *key, double def);
+float option_find_float(list *l, char *key, float def);
 void option_unused(list *l);
 
 #endif
diff --git a/src/softmax_layer.c b/src/softmax_layer.c
index b213e5b..1e01bd2 100644
--- a/src/softmax_layer.c
+++ b/src/softmax_layer.c
@@ -8,15 +8,16 @@
     fprintf(stderr, "Softmax Layer: %d inputs\n", inputs);
     softmax_layer *layer = calloc(1, sizeof(softmax_layer));
     layer->inputs = inputs;
-    layer->output = calloc(inputs, sizeof(double));
-    layer->delta = calloc(inputs, sizeof(double));
+    layer->output = calloc(inputs, sizeof(float));
+    layer->delta = calloc(inputs, sizeof(float));
     return layer;
 }
 
-void forward_softmax_layer(const softmax_layer layer, double *input)
+/* UNSTABLE!
+void forward_softmax_layer(const softmax_layer layer, float *input)
 {
     int i;
-    double sum = 0;
+    float sum = 0;
     for(i = 0; i < layer.inputs; ++i){
         sum += exp(input[i]);
     }
@@ -24,8 +25,25 @@
         layer.output[i] = exp(input[i])/sum;
     }
 }
+*/
+void forward_softmax_layer(const softmax_layer layer, float *input)
+{
+    int i;
+    float sum = 0;
+    float largest = 0;
+    for(i = 0; i < layer.inputs; ++i){
+        if(input[i] > largest) largest = input[i];
+    }
+    for(i = 0; i < layer.inputs; ++i){
+        sum += exp(input[i]-largest);
+    }
+    sum = largest+log(sum);
+    for(i = 0; i < layer.inputs; ++i){
+        layer.output[i] = exp(input[i]-sum);
+    }
+}
 
-void backward_softmax_layer(const softmax_layer layer, double *input, double *delta)
+void backward_softmax_layer(const softmax_layer layer, float *input, float *delta)
 {
     int i;
     for(i = 0; i < layer.inputs; ++i){
diff --git a/src/softmax_layer.h b/src/softmax_layer.h
index 1a0d760..bfcd390 100644
--- a/src/softmax_layer.h
+++ b/src/softmax_layer.h
@@ -3,12 +3,12 @@
 
 typedef struct {
     int inputs;
-    double *delta;
-    double *output;
+    float *delta;
+    float *output;
 } softmax_layer;
 
 softmax_layer *make_softmax_layer(int inputs);
-void forward_softmax_layer(const softmax_layer layer, double *input);
-void backward_softmax_layer(const softmax_layer layer, double *input, double *delta);
+void forward_softmax_layer(const softmax_layer layer, float *input);
+void backward_softmax_layer(const softmax_layer layer, float *input, float *delta);
 
 #endif
diff --git a/src/tests.c b/src/tests.c
index af22ddb..00cd1a1 100644
--- a/src/tests.c
+++ b/src/tests.c
@@ -14,6 +14,9 @@
 #include <stdlib.h>
 #include <stdio.h>
 
+#define _GNU_SOURCE
+#include <fenv.h>
+
 void test_convolve()
 {
     image dog = load_image("dog.jpg");
@@ -26,7 +29,7 @@
         convolve(dog, kernel, 1, 0, edge, 1);
     }
     end = clock();
-    printf("Convolutions: %lf seconds\n", (double)(end-start)/CLOCKS_PER_SEC);
+    printf("Convolutions: %lf seconds\n", (float)(end-start)/CLOCKS_PER_SEC);
     show_image_layers(edge, "Test Convolve");
 }
 
@@ -38,11 +41,11 @@
     int size = 11;
     int stride = 4;
     int n = 40;
-    double *filters = make_random_image(size, size, dog.c*n).data;
+    float *filters = make_random_image(size, size, dog.c*n).data;
 
     int mw = ((dog.h-size)/stride+1)*((dog.w-size)/stride+1);
     int mh = (size*size*dog.c);
-    double *matrix = calloc(mh*mw, sizeof(double));
+    float *matrix = calloc(mh*mw, sizeof(float));
 
     image edge = make_image((dog.h-size)/stride+1, (dog.w-size)/stride+1, n);
 
@@ -54,7 +57,7 @@
         gemm(0,0,n,mw,mh,1,filters,mh,matrix,mw,1,edge.data,mw);
     }
     end = clock();
-    printf("Convolutions: %lf seconds\n", (double)(end-start)/CLOCKS_PER_SEC);
+    printf("Convolutions: %lf seconds\n", (float)(end-start)/CLOCKS_PER_SEC);
     show_image_layers(edge, "Test Convolve");
     cvWaitKey(0);
 }
@@ -72,11 +75,11 @@
     int n = 1;
     int stride = 1;
     int size = 3;
-    double eps = .00000001;
+    float eps = .00000001;
     image test = make_random_image(5,5, 1);
     convolutional_layer layer = *make_convolutional_layer(test.h,test.w,test.c, n, size, stride, RELU);
     image out = get_convolutional_image(layer);
-    double **jacobian = calloc(test.h*test.w*test.c, sizeof(double));
+    float **jacobian = calloc(test.h*test.w*test.c, sizeof(float));
     
     forward_convolutional_layer(layer, test.data);
     image base = copy_image(out);
@@ -90,19 +93,19 @@
         jacobian[i] = partial.data;
         test.data[i] -= eps;
     }
-    double **jacobian2 = calloc(out.h*out.w*out.c, sizeof(double));
+    float **jacobian2 = calloc(out.h*out.w*out.c, sizeof(float));
     image in_delta = make_image(test.h, test.w, test.c);
     image out_delta = get_convolutional_delta(layer);
     for(i = 0; i < out.h*out.w*out.c; ++i){
         out_delta.data[i] = 1;
-        //backward_convolutional_layer(layer, test.data, in_delta.data);
+        backward_convolutional_layer(layer, in_delta.data);
         image partial = copy_image(in_delta);
         jacobian2[i] = partial.data;
         out_delta.data[i] = 0;
     }
     int j;
-    double *j1 = calloc(test.h*test.w*test.c*out.h*out.w*out.c, sizeof(double));
-    double *j2 = calloc(test.h*test.w*test.c*out.h*out.w*out.c, sizeof(double));
+    float *j1 = calloc(test.h*test.w*test.c*out.h*out.w*out.c, sizeof(float));
+    float *j2 = calloc(test.h*test.w*test.c*out.h*out.w*out.c, sizeof(float));
     for(i = 0; i < test.h*test.w*test.c; ++i){
         for(j =0 ; j < out.h*out.w*out.c; ++j){
             j1[i*out.h*out.w*out.c + j] = jacobian[i][j];
@@ -112,12 +115,11 @@
     }
 
 
-    image mj1 = double_to_image(test.w*test.h*test.c, out.w*out.h*out.c, 1, j1);
-    image mj2 = double_to_image(test.w*test.h*test.c, out.w*out.h*out.c, 1, j2);
+    image mj1 = float_to_image(test.w*test.h*test.c, out.w*out.h*out.c, 1, j1);
+    image mj2 = float_to_image(test.w*test.h*test.c, out.w*out.h*out.c, 1, j2);
     printf("%f %f\n", avg_image_layer(mj1,0), avg_image_layer(mj2,0));
     show_image(mj1, "forward jacobian");
     show_image(mj2, "backward jacobian");
-    
 }
 
 void test_load()
@@ -145,7 +147,7 @@
         rotate_image(dog);
     }
     end = clock();
-    printf("Rotations: %lf seconds\n", (double)(end-start)/CLOCKS_PER_SEC);
+    printf("Rotations: %lf seconds\n", (float)(end-start)/CLOCKS_PER_SEC);
     show_image(dog, "Test Rotate");
 
     image random = make_random_image(3,3,3);
@@ -159,18 +161,18 @@
 void test_parser()
 {
     network net = parse_network_cfg("test_parser.cfg");
-    double input[1];
+    float input[1];
     int count = 0;
         
-    double avgerr = 0;
+    float avgerr = 0;
     while(++count < 100000000){
-        double v = ((double)rand()/RAND_MAX);
-        double truth = v*v;
+        float v = ((float)rand()/RAND_MAX);
+        float truth = v*v;
         input[0] = v;
         forward_network(net, input);
-        double *out = get_network_output(net);
-        double *delta = get_network_delta(net);
-        double err = pow((out[0]-truth),2.);
+        float *out = get_network_output(net);
+        float *delta = get_network_delta(net);
+        float err = pow((out[0]-truth),2.);
         avgerr = .99 * avgerr + .01 * err;
         if(count % 1000000 == 0) printf("%f %f :%f AVG %f \n", truth, out[0], err, avgerr);
         delta[0] = truth - out[0];
@@ -192,9 +194,9 @@
     srand(0);
     int i = 0;
     char *labels[] = {"cat","dog"};
-    double lr = .00001;
-    double momentum = .9;
-    double decay = 0.01;
+    float lr = .00001;
+    float momentum = .9;
+    float decay = 0.01;
     while(i++ < 1000 || 1){
         data train = load_data_image_pathfile_random("train_paths.txt", 1000, labels, 2);
         train_network(net, train, lr, momentum, decay);
@@ -207,32 +209,33 @@
 {
     srand(444444);
     srand(888888);
-    network net = parse_network_cfg("nist_basic.cfg");
+    network net = parse_network_cfg("nist.cfg");
     data train = load_categorical_data_csv("mnist/mnist_train.csv", 0, 10);
     data test = load_categorical_data_csv("mnist/mnist_test.csv",0,10);
     normalize_data_rows(train);
     normalize_data_rows(test);
     //randomize_data(train);
     int count = 0;
-    double lr = .0005;
-    double momentum = .9;
-    double decay = 0.01;
+    float lr = .0005;
+    float momentum = .9;
+    float decay = 0.01;
     clock_t start = clock(), end;
     while(++count <= 100){
-        visualize_network(net);
-        double loss = train_network_sgd(net, train, 10000, lr, momentum, decay);
+        //visualize_network(net);
+        float loss = train_network_sgd(net, train, 1000, lr, momentum, decay);
         printf("%5d Training Loss: %lf, Params: %f %f %f, ",count*100, loss, lr, momentum, decay);
         end = clock();
-        printf("Time: %lf seconds\n", (double)(end-start)/CLOCKS_PER_SEC);
+        printf("Time: %lf seconds\n", (float)(end-start)/CLOCKS_PER_SEC);
         start=end;
         cvWaitKey(100);
         //lr /= 2; 
         if(count%5 == 0){
-            double train_acc = network_accuracy(net, train);
+            float train_acc = network_accuracy(net, train);
             fprintf(stderr, "\nTRAIN: %f\n", train_acc);
-            double test_acc = network_accuracy(net, test);
+            float test_acc = network_accuracy(net, test);
             fprintf(stderr, "TEST: %f\n\n", test_acc);
             printf("%d, %f, %f\n", count, train_acc, test_acc);
+            lr *= .5;
         }
     }
 }
@@ -253,24 +256,24 @@
     int n = 30;
     for(i = 0; i < n; ++i){
         int count = 0;
-        double lr = .0005;
-        double momentum = .9;
-        double decay = .01;
+        float lr = .0005;
+        float momentum = .9;
+        float decay = .01;
         network net = parse_network_cfg("nist.cfg");
         while(++count <= 15){
-            double acc = train_network_sgd(net, train, train.X.rows, lr, momentum, decay);
+            float acc = train_network_sgd(net, train, train.X.rows, lr, momentum, decay);
             printf("Training Accuracy: %lf Learning Rate: %f Momentum: %f Decay: %f\n", acc, lr, momentum, decay );
             lr /= 2; 
         }
         matrix partial = network_predict_data(net, test);
-        double acc = matrix_accuracy(test.y, partial);
+        float acc = matrix_accuracy(test.y, partial);
         printf("Model Accuracy: %lf\n", acc);
         matrix_add_matrix(partial, prediction);
         acc = matrix_accuracy(test.y, prediction);
         printf("Current Ensemble Accuracy: %lf\n", acc);
         free_matrix(partial);
     }
-    double acc = matrix_accuracy(test.y, prediction);
+    float acc = matrix_accuracy(test.y, prediction);
     printf("Full Ensemble Accuracy: %lf\n", acc);
 }
 
@@ -279,19 +282,19 @@
     network net = parse_network_cfg("connected.cfg");
     matrix m = csv_to_matrix("train.csv");
     //matrix ho = hold_out_matrix(&m, 2500);
-    double *truth = pop_column(&m, 0);
-    //double *ho_truth = pop_column(&ho, 0);
+    float *truth = pop_column(&m, 0);
+    //float *ho_truth = pop_column(&ho, 0);
     int i;
     clock_t start = clock(), end;
     int count = 0;
     while(++count <= 300){
         for(i = 0; i < m.rows; ++i){
             int index = rand()%m.rows;
-            //image p = double_to_image(1690,1,1,m.vals[index]);
+            //image p = float_to_image(1690,1,1,m.vals[index]);
             //normalize_image(p);
             forward_network(net, m.vals[index]);
-            double *out = get_network_output(net);
-            double *delta = get_network_delta(net);
+            float *out = get_network_output(net);
+            float *delta = get_network_delta(net);
             //printf("%f\n", out[0]);
             delta[0] = truth[index] - out[0];
             // printf("%f\n", delta[0]);
@@ -299,8 +302,8 @@
             //backward_network(net, m.vals[index], );
             update_network(net, .00001, 0,0);
         }
-        //double test_acc = error_network(net, m, truth);
-        //double valid_acc = error_network(net, ho, ho_truth);
+        //float test_acc = error_network(net, m, truth);
+        //float valid_acc = error_network(net, ho, ho_truth);
         //printf("%f, %f\n", test_acc, valid_acc);
         //fprintf(stderr, "%5d: %f Valid: %f\n",count, test_acc, valid_acc);
         //if(valid_acc > .70) break;
@@ -311,12 +314,12 @@
     truth = pop_column(&test, 0);
     for(i = 0; i < test.rows; ++i){
         forward_network(net, test.vals[i]);
-        double *out = get_network_output(net);
+        float *out = get_network_output(net);
         if(fabs(out[0]) < .5) fprintf(fp, "0\n");
         else fprintf(fp, "1\n");
     }
     fclose(fp);
-    printf("Neural Net Learning: %lf seconds\n", (double)(end-start)/CLOCKS_PER_SEC);
+    printf("Neural Net Learning: %lf seconds\n", (float)(end-start)/CLOCKS_PER_SEC);
 }
 
 void test_split()
@@ -326,30 +329,6 @@
     printf("%d, %d, %d\n", train.X.rows, split[0].X.rows, split[1].X.rows);
 }
 
-double *random_matrix(int rows, int cols)
-{
-    int i, j;
-    double *m = calloc(rows*cols, sizeof(double));
-    for(i = 0; i < rows; ++i){
-        for(j = 0; j < cols; ++j){
-            m[i*cols+j] = (double)rand()/RAND_MAX;
-        }
-    }
-    return m;
-}
-
-void test_blas()
-{
-    int m = 1000, n = 1000, k = 1000;
-    double *a = random_matrix(m,k);
-    double *b = random_matrix(k,n);
-    double *c = random_matrix(m,n);
-    int i;
-    for(i = 0; i<1000; ++i){
-        gemm(0,0,m,n,k,1,a,k,b,n,1,c,n);
-    }
-}
-
 void test_im2row()
 {
     int h = 20;
@@ -362,16 +341,18 @@
     int mw = ((h-size)/stride+1)*((w-size)/stride+1);
     int mh = (size*size*c);
     int msize = mc*mw*mh;
-    double *matrix = calloc(msize, sizeof(double));
+    float *matrix = calloc(msize, sizeof(float));
     int i;
     for(i = 0; i < 1000; ++i){
         im2col_cpu(test.data,  c,  h,  w,  size,  stride, matrix);
-        image render = double_to_image(mh, mw, mc, matrix);
+        image render = float_to_image(mh, mw, mc, matrix);
     }
 }
 
 int main()
 {
+    //feenableexcept(FE_DIVBYZERO | FE_INVALID | FE_OVERFLOW);
+
     //test_blas();
     //test_convolve_matrix();
     //    test_im2row();
diff --git a/src/utils.c b/src/utils.c
index 5180fe6..41ee768 100644
--- a/src/utils.c
+++ b/src/utils.c
@@ -123,9 +123,9 @@
 	return count;
 }
 
-double *parse_fields(char *line, int n)
+float *parse_fields(char *line, int n)
 {
-	double *field = calloc(n, sizeof(double));
+	float *field = calloc(n, sizeof(float));
 	char *c, *p, *end;
 	int count = 0;
 	int done = 0;
@@ -143,36 +143,36 @@
 	return field;
 }
 
-double mean_array(double *a, int n)
+float mean_array(float *a, int n)
 {
     int i;
-    double sum = 0;
+    float sum = 0;
     for(i = 0; i < n; ++i) sum += a[i];
     return sum/n;
 }
 
-double variance_array(double *a, int n)
+float variance_array(float *a, int n)
 {
     int i;
-    double sum = 0;
-    double mean = mean_array(a, n);
+    float sum = 0;
+    float mean = mean_array(a, n);
     for(i = 0; i < n; ++i) sum += (a[i] - mean)*(a[i]-mean);
-    double variance = sum/n;
+    float variance = sum/n;
     return variance;
 }
 
-double constrain(double a, double max)
+float constrain(float a, float max)
 {
     if(a > abs(max)) return abs(max);
     if(a < -abs(max)) return -abs(max);
     return a;
 }
 
-void normalize_array(double *a, int n)
+void normalize_array(float *a, int n)
 {
     int i;
-    double mu = mean_array(a,n);
-    double sigma = sqrt(variance_array(a,n));
+    float mu = mean_array(a,n);
+    float sigma = sqrt(variance_array(a,n));
     for(i = 0; i < n; ++i){
         a[i] = (a[i] - mu)/sigma;
     }
@@ -180,7 +180,7 @@
     sigma = sqrt(variance_array(a,n));
 }
 
-void translate_array(double *a, int n, double s)
+void translate_array(float *a, int n, float s)
 {
     int i;
     for(i = 0; i < n; ++i){
@@ -188,18 +188,18 @@
     }
 }
 
-void scale_array(double *a, int n, double s)
+void scale_array(float *a, int n, float s)
 {
     int i;
     for(i = 0; i < n; ++i){
         a[i] *= s;
     }
 }
-int max_index(double *a, int n)
+int max_index(float *a, int n)
 {
     if(n <= 0) return -1;
     int i, max_i = 0;
-    double max = a[0];
+    float max = a[0];
     for(i = 1; i < n; ++i){
         if(a[i] > max){
             max = a[i];
@@ -209,20 +209,20 @@
     return max_i;
 }
 
-double rand_normal()
+float rand_normal()
 {
     int i;
-    double sum= 0;
-    for(i = 0; i < 12; ++i) sum += (double)rand()/RAND_MAX;
+    float sum= 0;
+    for(i = 0; i < 12; ++i) sum += (float)rand()/RAND_MAX;
     return sum-6.;
 }
 
-double **one_hot_encode(double *a, int n, int k)
+float **one_hot_encode(float *a, int n, int k)
 {
     int i;
-    double **t = calloc(n, sizeof(double*));
+    float **t = calloc(n, sizeof(float*));
     for(i = 0; i < n; ++i){
-        t[i] = calloc(k, sizeof(double));
+        t[i] = calloc(k, sizeof(float));
         int index = (int)a[i];
         t[i][index] = 1;
     }
diff --git a/src/utils.h b/src/utils.h
index cf38016..8185107 100644
--- a/src/utils.h
+++ b/src/utils.h
@@ -13,15 +13,15 @@
 list *parse_csv_line(char *line);
 char *copy_string(char *s);
 int count_fields(char *line);
-double *parse_fields(char *line, int n);
-void normalize_array(double *a, int n);
-void scale_array(double *a, int n, double s);
-void translate_array(double *a, int n, double s);
-int max_index(double *a, int n);
-double constrain(double a, double max);
-double rand_normal();
-double mean_array(double *a, int n);
-double variance_array(double *a, int n);
-double **one_hot_encode(double *a, int n, int k);
+float *parse_fields(char *line, int n);
+void normalize_array(float *a, int n);
+void scale_array(float *a, int n, float s);
+void translate_array(float *a, int n, float s);
+int max_index(float *a, int n);
+float constrain(float a, float max);
+float rand_normal();
+float mean_array(float *a, int n);
+float variance_array(float *a, int n);
+float **one_hot_encode(float *a, int n, int k);
 #endif
 

--
Gitblit v1.10.0