Joseph Redmon
2014-01-29 f7a17f82eb43de864a4f980f235055da9685eef8
Convolutional layers working w/ matrices
26 files modified
822 ■■■■■ changed files
Makefile 2 ●●● patch | view | raw | blame | history
nist_basic.cfg 2 ●●● patch | view | raw | blame | history
src/activations.c 4 ●●●● patch | view | raw | blame | history
src/activations.h 4 ●●●● patch | view | raw | blame | history
src/connected_layer.c 58 ●●●● patch | view | raw | blame | history
src/connected_layer.h 24 ●●●● patch | view | raw | blame | history
src/convolutional_layer.c 90 ●●●●● patch | view | raw | blame | history
src/convolutional_layer.h 24 ●●●● patch | view | raw | blame | history
src/data.c 20 ●●●● patch | view | raw | blame | history
src/image.c 68 ●●●● patch | view | raw | blame | history
src/image.h 20 ●●●● patch | view | raw | blame | history
src/matrix.c 20 ●●●● patch | view | raw | blame | history
src/matrix.h 6 ●●●● patch | view | raw | blame | history
src/maxpool_layer.c 28 ●●●● patch | view | raw | blame | history
src/maxpool_layer.h 8 ●●●● patch | view | raw | blame | history
src/mini_blas.c 93 ●●●● patch | view | raw | blame | history
src/mini_blas.h 24 ●●●●● patch | view | raw | blame | history
src/network.c 70 ●●●● patch | view | raw | blame | history
src/network.h 24 ●●●● patch | view | raw | blame | history
src/option_list.c 2 ●●● patch | view | raw | blame | history
src/option_list.h 2 ●●● patch | view | raw | blame | history
src/softmax_layer.c 28 ●●●● patch | view | raw | blame | history
src/softmax_layer.h 8 ●●●● patch | view | raw | blame | history
src/tests.c 129 ●●●●● patch | view | raw | blame | history
src/utils.c 44 ●●●● patch | view | raw | blame | history
src/utils.h 20 ●●●● patch | view | raw | blame | history
Makefile
@@ -1,6 +1,6 @@
CC=gcc
COMMON=-Wall `pkg-config --cflags opencv`
CFLAGS= $(COMMON) -Ofast -ffast-math -flto
CFLAGS= $(COMMON) -O3 -ffast-math -flto
UNAME = $(shell uname)
ifeq ($(UNAME), Darwin)
COMMON += -isystem /usr/local/Cellar/opencv/2.4.6.1/include/opencv -isystem /usr/local/Cellar/opencv/2.4.6.1/include
nist_basic.cfg
@@ -3,7 +3,7 @@
height=28
channels=1
filters=20
size=5
size=11
stride=1
activation=linear
src/activations.c
@@ -15,7 +15,7 @@
    return RELU;
}
double activate(double x, ACTIVATION a){
float activate(float x, ACTIVATION a){
    switch(a){
        case LINEAR:
            return x;
@@ -30,7 +30,7 @@
    }
    return 0;
}
double gradient(double x, ACTIVATION a){
float gradient(float x, ACTIVATION a){
    switch(a){
        case LINEAR:
            return 1;
src/activations.h
@@ -7,8 +7,8 @@
ACTIVATION get_activation(char *s);
double activate(double x, ACTIVATION a);
double gradient(double x, ACTIVATION a);
float activate(float x, ACTIVATION a);
float gradient(float x, ACTIVATION a);
#endif
src/connected_layer.c
@@ -15,19 +15,19 @@
    layer->inputs = inputs;
    layer->outputs = outputs;
    layer->output = calloc(outputs, sizeof(double*));
    layer->delta = calloc(outputs, sizeof(double*));
    layer->output = calloc(outputs, sizeof(float*));
    layer->delta = calloc(outputs, sizeof(float*));
    layer->weight_updates = calloc(inputs*outputs, sizeof(double));
    layer->weight_momentum = calloc(inputs*outputs, sizeof(double));
    layer->weights = calloc(inputs*outputs, sizeof(double));
    double scale = 2./inputs;
    layer->weight_updates = calloc(inputs*outputs, sizeof(float));
    layer->weight_momentum = calloc(inputs*outputs, sizeof(float));
    layer->weights = calloc(inputs*outputs, sizeof(float));
    float scale = 2./inputs;
    for(i = 0; i < inputs*outputs; ++i)
        layer->weights[i] = rand_normal()*scale;
    layer->bias_updates = calloc(outputs, sizeof(double));
    layer->bias_momentum = calloc(outputs, sizeof(double));
    layer->biases = calloc(outputs, sizeof(double));
    layer->bias_updates = calloc(outputs, sizeof(float));
    layer->bias_momentum = calloc(outputs, sizeof(float));
    layer->biases = calloc(outputs, sizeof(float));
    for(i = 0; i < outputs; ++i)
        //layer->biases[i] = rand_normal()*scale + scale;
        layer->biases[i] = 0;
@@ -36,7 +36,7 @@
    return layer;
}
void update_connected_layer(connected_layer layer, double step, double momentum, double decay)
void update_connected_layer(connected_layer layer, float step, float momentum, float decay)
{
    int i;
    for(i = 0; i < layer.outputs; ++i){
@@ -47,27 +47,27 @@
        layer.weight_momentum[i] = step*(layer.weight_updates[i] - decay*layer.weights[i]) + momentum*layer.weight_momentum[i];
        layer.weights[i] += layer.weight_momentum[i];
    }
    memset(layer.bias_updates, 0, layer.outputs*sizeof(double));
    memset(layer.weight_updates, 0, layer.outputs*layer.inputs*sizeof(double));
    memset(layer.bias_updates, 0, layer.outputs*sizeof(float));
    memset(layer.weight_updates, 0, layer.outputs*layer.inputs*sizeof(float));
}
void forward_connected_layer(connected_layer layer, double *input)
void forward_connected_layer(connected_layer layer, float *input)
{
    int i;
    memcpy(layer.output, layer.biases, layer.outputs*sizeof(double));
    memcpy(layer.output, layer.biases, layer.outputs*sizeof(float));
    int m = 1;
    int k = layer.inputs;
    int n = layer.outputs;
    double *a = input;
    double *b = layer.weights;
    double *c = layer.output;
    float *a = input;
    float *b = layer.weights;
    float *c = layer.output;
    gemm(0,0,m,n,k,1,a,k,b,n,1,c,n);
    for(i = 0; i < layer.outputs; ++i){
        layer.output[i] = activate(layer.output[i], layer.activation);
    }
}
void learn_connected_layer(connected_layer layer, double *input)
void learn_connected_layer(connected_layer layer, float *input)
{
    int i;
    for(i = 0; i < layer.outputs; ++i){
@@ -77,28 +77,28 @@
    int m = layer.inputs;
    int k = 1;
    int n = layer.outputs;
    double *a = input;
    double *b = layer.delta;
    double *c = layer.weight_updates;
    float *a = input;
    float *b = layer.delta;
    float *c = layer.weight_updates;
    gemm(0,0,m,n,k,1,a,k,b,n,1,c,n);
}
void backward_connected_layer(connected_layer layer, double *input, double *delta)
void backward_connected_layer(connected_layer layer, float *input, float *delta)
{
    memset(delta, 0, layer.inputs*sizeof(double));
    memset(delta, 0, layer.inputs*sizeof(float));
    int m = layer.inputs;
    int k = layer.outputs;
    int n = 1;
    double *a = layer.weights;
    double *b = layer.delta;
    double *c = delta;
    float *a = layer.weights;
    float *b = layer.delta;
    float *c = delta;
    gemm(0,0,m,n,k,1,a,k,b,n,1,c,n);
}
/*
   void forward_connected_layer(connected_layer layer, double *input)
   void forward_connected_layer(connected_layer layer, float *input)
   {
   int i, j;
   for(i = 0; i < layer.outputs; ++i){
@@ -109,7 +109,7 @@
   layer.output[i] = activate(layer.output[i], layer.activation);
   }
   }
   void learn_connected_layer(connected_layer layer, double *input)
   void learn_connected_layer(connected_layer layer, float *input)
   {
   int i, j;
   for(i = 0; i < layer.outputs; ++i){
@@ -120,7 +120,7 @@
   }
   }
   }
   void backward_connected_layer(connected_layer layer, double *input, double *delta)
   void backward_connected_layer(connected_layer layer, float *input, float *delta)
   {
   int i, j;
src/connected_layer.h
@@ -6,17 +6,17 @@
typedef struct{
    int inputs;
    int outputs;
    double *weights;
    double *biases;
    float *weights;
    float *biases;
    double *weight_updates;
    double *bias_updates;
    float *weight_updates;
    float *bias_updates;
    double *weight_momentum;
    double *bias_momentum;
    float *weight_momentum;
    float *bias_momentum;
    double *output;
    double *delta;
    float *output;
    float *delta;
    ACTIVATION activation;
@@ -24,10 +24,10 @@
connected_layer *make_connected_layer(int inputs, int outputs, ACTIVATION activation);
void forward_connected_layer(connected_layer layer, double *input);
void backward_connected_layer(connected_layer layer, double *input, double *delta);
void learn_connected_layer(connected_layer layer, double *input);
void update_connected_layer(connected_layer layer, double step, double momentum, double decay);
void forward_connected_layer(connected_layer layer, float *input);
void backward_connected_layer(connected_layer layer, float *input, float *delta);
void learn_connected_layer(connected_layer layer, float *input);
void update_connected_layer(connected_layer layer, float step, float momentum, float decay);
#endif
src/convolutional_layer.c
@@ -9,7 +9,7 @@
    h = layer.out_h;
    w = layer.out_w;
    c = layer.n;
    return double_to_image(h,w,c,layer.output);
    return float_to_image(h,w,c,layer.output);
}
image get_convolutional_delta(convolutional_layer layer)
@@ -18,7 +18,7 @@
    h = layer.out_h;
    w = layer.out_w;
    c = layer.n;
    return double_to_image(h,w,c,layer.delta);
    return float_to_image(h,w,c,layer.delta);
}
convolutional_layer *make_convolutional_layer(int h, int w, int c, int n, int size, int stride, ACTIVATION activation)
@@ -34,14 +34,14 @@
    layer->stride = stride;
    layer->size = size;
    layer->filters = calloc(c*n*size*size, sizeof(double));
    layer->filter_updates = calloc(c*n*size*size, sizeof(double));
    layer->filter_momentum = calloc(c*n*size*size, sizeof(double));
    layer->filters = calloc(c*n*size*size, sizeof(float));
    layer->filter_updates = calloc(c*n*size*size, sizeof(float));
    layer->filter_momentum = calloc(c*n*size*size, sizeof(float));
    layer->biases = calloc(n, sizeof(double));
    layer->bias_updates = calloc(n, sizeof(double));
    layer->bias_momentum = calloc(n, sizeof(double));
    double scale = 2./(size*size);
    layer->biases = calloc(n, sizeof(float));
    layer->bias_updates = calloc(n, sizeof(float));
    layer->bias_momentum = calloc(n, sizeof(float));
    float scale = 2./(size*size);
    for(i = 0; i < c*n*size*size; ++i) layer->filters[i] = rand_normal()*scale;
    for(i = 0; i < n; ++i){
        //layer->biases[i] = rand_normal()*scale + scale;
@@ -50,9 +50,9 @@
    out_h = (h-size)/stride + 1;
    out_w = (w-size)/stride + 1;
    layer->col_image = calloc(out_h*out_w*size*size*c, sizeof(double));
    layer->output = calloc(out_h * out_w * n, sizeof(double));
    layer->delta  = calloc(out_h * out_w * n, sizeof(double));
    layer->col_image = calloc(out_h*out_w*size*size*c, sizeof(float));
    layer->output = calloc(out_h * out_w * n, sizeof(float));
    layer->delta  = calloc(out_h * out_w * n, sizeof(float));
    layer->activation = activation;
    layer->out_h = out_h;
    layer->out_w = out_w;
@@ -63,18 +63,18 @@
    return layer;
}
void forward_convolutional_layer(const convolutional_layer layer, double *in)
void forward_convolutional_layer(const convolutional_layer layer, float *in)
{
    int m = layer.n;
    int k = layer.size*layer.size*layer.c;
    int n = ((layer.h-layer.size)/layer.stride + 1)*
            ((layer.w-layer.size)/layer.stride + 1);
    memset(layer.output, 0, m*n*sizeof(double));
    memset(layer.output, 0, m*n*sizeof(float));
    double *a = layer.filters;
    double *b = layer.col_image;
    double *c = layer.output;
    float *a = layer.filters;
    float *b = layer.col_image;
    float *c = layer.output;
    im2col_cpu(in,  layer.c,  layer.h,  layer.w,  layer.size,  layer.stride, b);
    gemm(0,0,m,n,k,1,a,k,b,n,1,c,n);
@@ -94,7 +94,7 @@
    int i,j;
    int size = layer.out_h*layer.out_w;
    for(i = 0; i < layer.n; ++i){
        double sum = 0;
        float sum = 0;
        for(j = 0; j < size; ++j){
            sum += layer.delta[j+i*size];
        }
@@ -111,14 +111,33 @@
    int k = ((layer.h-layer.size)/layer.stride + 1)*
            ((layer.w-layer.size)/layer.stride + 1);
    double *a = layer.delta;
    double *b = layer.col_image;
    double *c = layer.filter_updates;
    float *a = layer.delta;
    float *b = layer.col_image;
    float *c = layer.filter_updates;
    gemm(0,1,m,n,k,1,a,k,b,k,1,c,n);
}
void update_convolutional_layer(convolutional_layer layer, double step, double momentum, double decay)
void backward_convolutional_layer(convolutional_layer layer, float *delta)
{
    int m = layer.size*layer.size*layer.c;
    int k = layer.n;
    int n = ((layer.h-layer.size)/layer.stride + 1)*
            ((layer.w-layer.size)/layer.stride + 1);
    float *a = layer.filters;
    float *b = layer.delta;
    float *c = layer.col_image;
    memset(c, 0, m*n*sizeof(float));
    gemm(1,0,m,n,k,1,a,m,b,n,1,c,n);
    memset(delta, 0, layer.h*layer.w*layer.c*sizeof(float));
    col2im_cpu(c,  layer.c,  layer.h,  layer.w,  layer.size,  layer.stride, delta);
}
void update_convolutional_layer(convolutional_layer layer, float step, float momentum, float decay)
{
    int i;
    int size = layer.size*layer.size*layer.c*layer.n;
@@ -133,9 +152,9 @@
}
/*
void backward_convolutional_layer2(convolutional_layer layer, double *input, double *delta)
void backward_convolutional_layer2(convolutional_layer layer, float *input, float *delta)
{
    image in_delta = double_to_image(layer.h, layer.w, layer.c, delta);
    image in_delta = float_to_image(layer.h, layer.w, layer.c, delta);
    image out_delta = get_convolutional_delta(layer);
    int i,j;
    for(i = 0; i < layer.n; ++i){
@@ -156,10 +175,10 @@
}
void learn_convolutional_layer(convolutional_layer layer, double *input)
void learn_convolutional_layer(convolutional_layer layer, float *input)
{
    int i;
    image in_image = double_to_image(layer.h, layer.w, layer.c, input);
    image in_image = float_to_image(layer.h, layer.w, layer.c, input);
    image out_delta = get_convolutional_delta(layer);
    gradient_delta_convolutional_layer(layer);
    for(i = 0; i < layer.n; ++i){
@@ -168,7 +187,7 @@
    }
}
void update_convolutional_layer(convolutional_layer layer, double step, double momentum, double decay)
void update_convolutional_layer(convolutional_layer layer, float step, float momentum, float decay)
{
    int i,j;
    for(i = 0; i < layer.n; ++i){
@@ -190,21 +209,28 @@
void test_convolutional_layer()
{
    convolutional_layer l = *make_convolutional_layer(4,4,1,1,3,1,LINEAR);
    double input[] =    {1,2,3,4,
    float input[] =    {1,2,3,4,
                        5,6,7,8,
                        9,10,11,12,
                        13,14,15,16};
    double filter[] =   {.5, 0, .3,
    float filter[] =   {.5, 0, .3,
                        0  , 1,  0,
                        .2 , 0,  1};
    double delta[] =    {1, 2,
    float delta[] =    {1, 2,
                        3,  4};
    float in_delta[] = {.5,1,.3,.6,
                        5,6,7,8,
                        9,10,11,12,
                        13,14,15,16};
    l.filters = filter;
    forward_convolutional_layer(l, input);
    l.delta = delta;
    learn_convolutional_layer(l);
    image filter_updates = double_to_image(3,3,1,l.filter_updates);
    image filter_updates = float_to_image(3,3,1,l.filter_updates);
    print_image(filter_updates);
    printf("Delta:\n");
    backward_convolutional_layer(l, in_delta);
    pm(4,4,in_delta);
}
image get_convolutional_filter(convolutional_layer layer, int i)
@@ -212,7 +238,7 @@
    int h = layer.size;
    int w = layer.size;
    int c = layer.c;
    return double_to_image(h,w,c,layer.filters+i*h*w*c);
    return float_to_image(h,w,c,layer.filters+i*h*w*c);
}
void visualize_convolutional_layer(convolutional_layer layer, char *window)
src/convolutional_layer.h
@@ -10,28 +10,28 @@
    int n;
    int size;
    int stride;
    double *filters;
    double *filter_updates;
    double *filter_momentum;
    float *filters;
    float *filter_updates;
    float *filter_momentum;
    double *biases;
    double *bias_updates;
    double *bias_momentum;
    float *biases;
    float *bias_updates;
    float *bias_momentum;
    double *col_image;
    double *delta;
    double *output;
    float *col_image;
    float *delta;
    float *output;
    ACTIVATION activation;
} convolutional_layer;
convolutional_layer *make_convolutional_layer(int h, int w, int c, int n, int size, int stride, ACTIVATION activation);
void forward_convolutional_layer(const convolutional_layer layer, double *in);
void forward_convolutional_layer(const convolutional_layer layer, float *in);
void learn_convolutional_layer(convolutional_layer layer);
void update_convolutional_layer(convolutional_layer layer, double step, double momentum, double decay);
void update_convolutional_layer(convolutional_layer layer, float step, float momentum, float decay);
void visualize_convolutional_layer(convolutional_layer layer, char *window);
//void backward_convolutional_layer(convolutional_layer layer, double *input, double *delta);
void backward_convolutional_layer(convolutional_layer layer, float *delta);
//void backpropagate_convolutional_layer_convolve(image input, convolutional_layer layer);
//void visualize_convolutional_filters(convolutional_layer layer, char *window);
src/data.c
@@ -19,10 +19,10 @@
    return lines;
}
void fill_truth(char *path, char **labels, int k, double *truth)
void fill_truth(char *path, char **labels, int k, float *truth)
{
    int i;
    memset(truth, 0, k*sizeof(double));
    memset(truth, 0, k*sizeof(float));
    for(i = 0; i < k; ++i){
        if(strstr(path, labels[i])){
            truth[i] = 1;
@@ -36,7 +36,7 @@
    data d;
    d.shallow = 0;
    d.X.rows = n;
    d.X.vals = calloc(d.X.rows, sizeof(double*));
    d.X.vals = calloc(d.X.rows, sizeof(float*));
    d.y = make_matrix(n, k);
    for(i = 0; i < n; ++i){
@@ -106,8 +106,8 @@
    data d;
    d.shallow = 0;
    matrix X = csv_to_matrix(filename);
    double *truth_1d = pop_column(&X, target);
    double **truth = one_hot_encode(truth_1d, X.rows, k);
    float *truth_1d = pop_column(&X, target);
    float **truth = one_hot_encode(truth_1d, X.rows, k);
    matrix y;
    y.rows = X.rows;
    y.cols = k;
@@ -123,7 +123,7 @@
    int i;
    for(i = d.X.rows-1; i > 0; --i){
        int index = rand()%i;
        double *swap = d.X.vals[index];
        float *swap = d.X.vals[index];
        d.X.vals[index] = d.X.vals[i];
        d.X.vals[i] = swap;
@@ -156,10 +156,10 @@
    train.X.cols = test.X.cols = d.X.cols;
    train.y.cols = test.y.cols = d.y.cols;
    train.X.vals = calloc(train.X.rows, sizeof(double*));
    test.X.vals = calloc(test.X.rows, sizeof(double*));
    train.y.vals = calloc(train.y.rows, sizeof(double*));
    test.y.vals = calloc(test.y.rows, sizeof(double*));
    train.X.vals = calloc(train.X.rows, sizeof(float*));
    test.X.vals = calloc(test.X.rows, sizeof(float*));
    train.y.vals = calloc(train.y.rows, sizeof(float*));
    test.y.vals = calloc(test.y.rows, sizeof(float*));
    for(i = 0; i < start; ++i){
        train.X.vals[i] = d.X.vals[i];
src/image.c
@@ -16,7 +16,7 @@
    for(k = 0; k < source.c; ++k){
        for(i = 0; i < source.h; ++i){
            for(j = 0; j < source.w; ++j){
                double val = get_pixel(source, i,j,k);
                float val = get_pixel(source, i,j,k);
                set_pixel(dest, h+i, w+j, k, val);
            }
        }
@@ -45,14 +45,14 @@
void normalize_image(image p)
{
    double *min = calloc(p.c, sizeof(double));
    double *max = calloc(p.c, sizeof(double));
    float *min = calloc(p.c, sizeof(float));
    float *max = calloc(p.c, sizeof(float));
    int i,j;
    for(i = 0; i < p.c; ++i) min[i] = max[i] = p.data[i*p.h*p.w];
    for(j = 0; j < p.c; ++j){
        for(i = 0; i < p.h*p.w; ++i){
            double v = p.data[i+j*p.h*p.w];
            float v = p.data[i+j*p.h*p.w];
            if(v < min[j]) min[j] = v;
            if(v > max[j]) max[j] = v;
        }
@@ -72,17 +72,17 @@
    free(max);
}
double avg_image_layer(image m, int l)
float avg_image_layer(image m, int l)
{
    int i;
    double sum = 0;
    float sum = 0;
    for(i = 0; i < m.h*m.w; ++i){
        sum += m.data[l*m.h*m.w + i];
    }
    return sum/(m.h*m.w);
}
void threshold_image(image p, double t)
void threshold_image(image p, float t)
{
    int i;
    for(i = 0; i < p.w*p.h*p.c; ++i){
@@ -93,8 +93,8 @@
image copy_image(image p)
{
    image copy = p;
    copy.data = calloc(p.h*p.w*p.c, sizeof(double));
    memcpy(copy.data, p.data, p.h*p.w*p.c*sizeof(double));
    copy.data = calloc(p.h*p.w*p.c, sizeof(float));
    memcpy(copy.data, p.data, p.h*p.w*p.c*sizeof(float));
    return copy;
}
@@ -168,11 +168,11 @@
image make_image(int h, int w, int c)
{
    image out = make_empty_image(h,w,c);
    out.data = calloc(h*w*c, sizeof(double));
    out.data = calloc(h*w*c, sizeof(float));
    return out;
}
image double_to_image(int h, int w, int c, double *data)
image float_to_image(int h, int w, int c, float *data)
{
    image out = make_empty_image(h,w,c);
    out.data = data;
@@ -181,12 +181,12 @@
void zero_image(image m)
{
    memset(m.data, 0, m.h*m.w*m.c*sizeof(double));
    memset(m.data, 0, m.h*m.w*m.c*sizeof(float));
}
void zero_channel(image m, int c)
{
    memset(&(m.data[c*m.h*m.w]), 0, m.h*m.w*sizeof(double));
    memset(&(m.data[c*m.h*m.w]), 0, m.h*m.w*sizeof(float));
}
void rotate_image(image m)
@@ -194,7 +194,7 @@
    int i,j;
    for(j = 0; j < m.c; ++j){
        for(i = 0; i < m.h*m.w/2; ++i){
            double swap = m.data[j*m.h*m.w + i];
            float swap = m.data[j*m.h*m.w + i];
            m.data[j*m.h*m.w + i] = m.data[j*m.h*m.w + (m.h*m.w-1 - i)];
            m.data[j*m.h*m.w + (m.h*m.w-1 - i)] = swap;
        }
@@ -212,19 +212,19 @@
    return out;
}
void add_scalar_image(image m, double s)
void add_scalar_image(image m, float s)
{
    int i;
    for(i = 0; i < m.h*m.w*m.c; ++i) m.data[i] += s;
}
void scale_image(image m, double s)
void scale_image(image m, float s)
{
    int i;
    for(i = 0; i < m.h*m.w*m.c; ++i) m.data[i] *= s;
}
image make_random_kernel(int size, int c, double scale)
image make_random_kernel(int size, int c, float scale)
{
    int pad;
    if((pad=(size%2==0))) ++size;
@@ -280,34 +280,34 @@
    return out;
}
double get_pixel(image m, int x, int y, int c)
float get_pixel(image m, int x, int y, int c)
{
    assert(x < m.h && y < m.w && c < m.c);
    return m.data[c*m.h*m.w + x*m.w + y];
}
double get_pixel_extend(image m, int x, int y, int c)
float get_pixel_extend(image m, int x, int y, int c)
{
    if(x < 0 || x >= m.h || y < 0 || y >= m.w || c < 0 || c >= m.c) return 0;
    return get_pixel(m, x, y, c);
}
void set_pixel(image m, int x, int y, int c, double val)
void set_pixel(image m, int x, int y, int c, float val)
{
    assert(x < m.h && y < m.w && c < m.c);
    m.data[c*m.h*m.w + x*m.w + y] = val;
}
void set_pixel_extend(image m, int x, int y, int c, double val)
void set_pixel_extend(image m, int x, int y, int c, float val)
{
    if(x < 0 || x >= m.h || y < 0 || y >= m.w || c < 0 || c >= m.c) return;
    set_pixel(m, x, y, c, val);
}
void add_pixel(image m, int x, int y, int c, double val)
void add_pixel(image m, int x, int y, int c, float val)
{
    assert(x < m.h && y < m.w && c < m.c);
    m.data[c*m.h*m.w + x*m.w + y] += val;
}
void add_pixel_extend(image m, int x, int y, int c, double val)
void add_pixel_extend(image m, int x, int y, int c, float val)
{
    if(x < 0 || x >= m.h || y < 0 || y >= m.w || c < 0 || c >= m.c) return;
    add_pixel(m, x, y, c, val);
@@ -329,7 +329,7 @@
    }
    for(x = xstart; x < xend; x += stride){
        for(y = ystart; y < yend; y += stride){
            double sum = 0;
            float sum = 0;
            for(i = 0; i < kernel.h; ++i){
                for(j = 0; j < kernel.w; ++j){
                    sum += get_pixel(kernel, i, j, kc)*get_pixel_extend(m, x+i-kernel.h/2, y+j-kernel.w/2, mc);
@@ -340,9 +340,9 @@
    }
}
double single_convolve(image m, image kernel, int x, int y)
float single_convolve(image m, image kernel, int x, int y)
{
    double sum = 0;
    float sum = 0;
    int i, j, k;
    for(i = 0; i < kernel.h; ++i){
        for(j = 0; j < kernel.w; ++j){
@@ -366,7 +366,7 @@
    int j;
    for(i = 0; i < m.h; i += stride){
        for(j = 0; j < m.w; j += stride){
            double val = single_convolve(m, kernel, i, j);
            float val = single_convolve(m, kernel, i, j);
            set_pixel(out, i/stride, j/stride, channel, val);
        }
    }
@@ -380,20 +380,20 @@
    for(k = 0; k < m.c; ++k){
        for(i = 0; i < m.h; ++i){
            for(j = 0; j< m.w; ++j){
                double val = get_pixel(m, i, j, k);
                float val = get_pixel(m, i, j, k);
                set_pixel(out, i*stride, j*stride, k, val);
            }
        }
    }
}
void single_update(image m, image update, int x, int y, double error)
void single_update(image m, image update, int x, int y, float error)
{
    int i, j, k;
    for(i = 0; i < update.h; ++i){
        for(j = 0; j < update.w; ++j){
            for(k = 0; k < update.c; ++k){
                double val = get_pixel_extend(m, x+i-update.h/2, y+j-update.w/2, k);
                float val = get_pixel_extend(m, x+i-update.h/2, y+j-update.w/2, k);
                add_pixel(update, i, j, k, val*error);
            }
        }
@@ -417,7 +417,7 @@
    }
    for(i = istart; i < iend; i += stride){
        for(j = jstart; j < jend; j += stride){
            double error = get_pixel(out, (i-istart)/stride, (j-jstart)/stride, channel);
            float error = get_pixel(out, (i-istart)/stride, (j-jstart)/stride, channel);
            single_update(m, update, i, j, error);
        }
    }
@@ -428,13 +428,13 @@
    */
}
void single_back_convolve(image m, image kernel, int x, int y, double val)
void single_back_convolve(image m, image kernel, int x, int y, float val)
{
    int i, j, k;
    for(i = 0; i < kernel.h; ++i){
        for(j = 0; j < kernel.w; ++j){
            for(k = 0; k < kernel.c; ++k){
                double pval = get_pixel(kernel, i, j, k) * val;
                float pval = get_pixel(kernel, i, j, k) * val;
                add_pixel_extend(m, x+i-kernel.h/2, y+j-kernel.w/2, k, pval);
            }
        }
@@ -457,7 +457,7 @@
    }
    for(i = istart; i < iend; i += stride){
        for(j = jstart; j < jend; j += stride){
            double val = get_pixel(out, (i-istart)/stride, (j-jstart)/stride, channel);
            float val = get_pixel(out, (i-istart)/stride, (j-jstart)/stride, channel);
            single_back_convolve(m, kernel, i, j, val);
        }
    }
src/image.h
@@ -7,18 +7,18 @@
    int h;
    int w;
    int c;
    double *data;
    float *data;
} image;
void scale_image(image m, double s);
void add_scalar_image(image m, double s);
void scale_image(image m, float s);
void add_scalar_image(image m, float s);
void normalize_image(image p);
void z_normalize_image(image p);
void threshold_image(image p, double t);
void threshold_image(image p, float t);
void zero_image(image m);
void rotate_image(image m);
void subtract_image(image a, image b);
double avg_image_layer(image m, int l);
float avg_image_layer(image m, int l);
void embed_image(image source, image dest, int h, int w);
image collapse_image_layers(image source, int border);
@@ -30,14 +30,14 @@
image make_image(int h, int w, int c);
image make_empty_image(int h, int w, int c);
image make_random_image(int h, int w, int c);
image make_random_kernel(int size, int c, double scale);
image double_to_image(int h, int w, int c, double *data);
image make_random_kernel(int size, int c, float scale);
image float_to_image(int h, int w, int c, float *data);
image copy_image(image p);
image load_image(char *filename);
double get_pixel(image m, int x, int y, int c);
double get_pixel_extend(image m, int x, int y, int c);
void set_pixel(image m, int x, int y, int c, double val);
float get_pixel(image m, int x, int y, int c);
float get_pixel_extend(image m, int x, int y, int c);
void set_pixel(image m, int x, int y, int c, float val);
image get_image_layer(image m, int l);
src/matrix.c
@@ -13,7 +13,7 @@
    free(m.vals);
}
double matrix_accuracy(matrix truth, matrix guess)
float matrix_accuracy(matrix truth, matrix guess)
{
    int k = truth.cols;
    int i;
@@ -22,7 +22,7 @@
        int class = max_index(guess.vals[i], k);
        if(truth.vals[i][class]) ++count;
    }
    return (double)count/truth.rows;
    return (float)count/truth.rows;
}
void matrix_add_matrix(matrix from, matrix to)
@@ -42,9 +42,9 @@
    matrix m;
    m.rows = rows;
    m.cols = cols;
    m.vals = calloc(m.rows, sizeof(double *));
    m.vals = calloc(m.rows, sizeof(float *));
    for(i = 0; i < m.rows; ++i){
        m.vals[i] = calloc(m.cols, sizeof(double));
        m.vals[i] = calloc(m.cols, sizeof(float));
    }
    return m;
}
@@ -55,7 +55,7 @@
    matrix h;
    h.rows = n;
    h.cols = m->cols;
    h.vals = calloc(h.rows, sizeof(double *));
    h.vals = calloc(h.rows, sizeof(float *));
    for(i = 0; i < n; ++i){
        int index = rand()%m->rows;
        h.vals[i] = m->vals[index];
@@ -64,9 +64,9 @@
    return h;
}
double *pop_column(matrix *m, int c)
float *pop_column(matrix *m, int c)
{
    double *col = calloc(m->rows, sizeof(double));
    float *col = calloc(m->rows, sizeof(float));
    int i, j;
    for(i = 0; i < m->rows; ++i){
        col[i] = m->vals[i][c];
@@ -90,18 +90,18 @@
    int n = 0;
    int size = 1024;
    m.vals = calloc(size, sizeof(double*));
    m.vals = calloc(size, sizeof(float*));
    while((line = fgetl(fp))){
        if(m.cols == -1) m.cols = count_fields(line);
        if(n == size){
            size *= 2;
            m.vals = realloc(m.vals, size*sizeof(double*));
            m.vals = realloc(m.vals, size*sizeof(float*));
        }
        m.vals[n] = parse_fields(line, m.cols);
        free(line);
        ++n;
    }
    m.vals = realloc(m.vals, n*sizeof(double*));
    m.vals = realloc(m.vals, n*sizeof(float*));
    m.rows = n;
    return m;
}
src/matrix.h
@@ -2,7 +2,7 @@
#define MATRIX_H
typedef struct matrix{
    int rows, cols;
    double **vals;
    float **vals;
} matrix;
matrix make_matrix(int rows, int cols);
@@ -11,9 +11,9 @@
matrix csv_to_matrix(char *filename);
matrix hold_out_matrix(matrix *m, int n);
double matrix_accuracy(matrix truth, matrix guess);
float matrix_accuracy(matrix truth, matrix guess);
void matrix_add_matrix(matrix from, matrix to);
double *pop_column(matrix *m, int c);
float *pop_column(matrix *m, int c);
#endif
src/maxpool_layer.c
@@ -6,7 +6,7 @@
    int h = (layer.h-1)/layer.stride + 1;
    int w = (layer.w-1)/layer.stride + 1;
    int c = layer.c;
    return double_to_image(h,w,c,layer.output);
    return float_to_image(h,w,c,layer.output);
}
image get_maxpool_delta(maxpool_layer layer)
@@ -14,7 +14,7 @@
    int h = (layer.h-1)/layer.stride + 1;
    int w = (layer.w-1)/layer.stride + 1;
    int c = layer.c;
    return double_to_image(h,w,c,layer.delta);
    return float_to_image(h,w,c,layer.delta);
}
maxpool_layer *make_maxpool_layer(int h, int w, int c, int stride)
@@ -25,41 +25,41 @@
    layer->w = w;
    layer->c = c;
    layer->stride = stride;
    layer->output = calloc(((h-1)/stride+1) * ((w-1)/stride+1) * c, sizeof(double));
    layer->delta = calloc(((h-1)/stride+1) * ((w-1)/stride+1) * c, sizeof(double));
    layer->output = calloc(((h-1)/stride+1) * ((w-1)/stride+1) * c, sizeof(float));
    layer->delta = calloc(((h-1)/stride+1) * ((w-1)/stride+1) * c, sizeof(float));
    return layer;
}
void forward_maxpool_layer(const maxpool_layer layer, double *in)
void forward_maxpool_layer(const maxpool_layer layer, float *in)
{
    image input = double_to_image(layer.h, layer.w, layer.c, in);
    image input = float_to_image(layer.h, layer.w, layer.c, in);
    image output = get_maxpool_image(layer);
    int i,j,k;
    for(i = 0; i < output.h*output.w*output.c; ++i) output.data[i] = -DBL_MAX;
    for(k = 0; k < input.c; ++k){
        for(i = 0; i < input.h; ++i){
            for(j = 0; j < input.w; ++j){
                double val = get_pixel(input, i, j, k);
                double cur = get_pixel(output, i/layer.stride, j/layer.stride, k);
                float val = get_pixel(input, i, j, k);
                float cur = get_pixel(output, i/layer.stride, j/layer.stride, k);
                if(val > cur) set_pixel(output, i/layer.stride, j/layer.stride, k, val);
            }
        }
    }
}
void backward_maxpool_layer(const maxpool_layer layer, double *in, double *delta)
void backward_maxpool_layer(const maxpool_layer layer, float *in, float *delta)
{
    image input = double_to_image(layer.h, layer.w, layer.c, in);
    image input_delta = double_to_image(layer.h, layer.w, layer.c, delta);
    image input = float_to_image(layer.h, layer.w, layer.c, in);
    image input_delta = float_to_image(layer.h, layer.w, layer.c, delta);
    image output_delta = get_maxpool_delta(layer);
    image output = get_maxpool_image(layer);
    int i,j,k;
    for(k = 0; k < input.c; ++k){
        for(i = 0; i < input.h; ++i){
            for(j = 0; j < input.w; ++j){
                double val = get_pixel(input, i, j, k);
                double cur = get_pixel(output, i/layer.stride, j/layer.stride, k);
                double d = get_pixel(output_delta, i/layer.stride, j/layer.stride, k);
                float val = get_pixel(input, i, j, k);
                float cur = get_pixel(output, i/layer.stride, j/layer.stride, k);
                float d = get_pixel(output_delta, i/layer.stride, j/layer.stride, k);
                if(val == cur) {
                    set_pixel(input_delta, i, j, k, d);
                }
src/maxpool_layer.h
@@ -6,14 +6,14 @@
typedef struct {
    int h,w,c;
    int stride;
    double *delta;
    double *output;
    float *delta;
    float *output;
} maxpool_layer;
image get_maxpool_image(maxpool_layer layer);
maxpool_layer *make_maxpool_layer(int h, int w, int c, int stride);
void forward_maxpool_layer(const maxpool_layer layer, double *in);
void backward_maxpool_layer(const maxpool_layer layer, double *in, double *delta);
void forward_maxpool_layer(const maxpool_layer layer, float *in);
void backward_maxpool_layer(const maxpool_layer layer, float *in, float *delta);
#endif
src/mini_blas.c
@@ -1,8 +1,10 @@
#include <stdlib.h>
#include <stdio.h>
#include <math.h>
#include <time.h>
void pm(int M, int N, double *A)
void pm(int M, int N, float *A)
{
    int i,j;
    for(i =0 ; i < M; ++i){
@@ -14,28 +16,37 @@
    printf("\n");
}
void gemm(int TA, int TB, int M, int N, int K, double ALPHA,
                    double *A, int lda,
                    double *B, int ldb,
                    double BETA,
                    double *C, int ldc)
void gemm(int TA, int TB, int M, int N, int K, float ALPHA,
                    float *A, int lda,
                    float *B, int ldb,
                    float BETA,
                    float *C, int ldc)
{
    // Assume TA = 0, beta = 1 LULZ
    // Assume beta = 1 LULZ
    int i,j,k;
    if(TB && !TA){
        for(i = 0; i < M; ++i){
            for(j = 0; j < N; ++j){
                register double sum = 0;
                register float sum = 0;
                for(k = 0; k < K; ++k){
                    sum += ALPHA*A[i*lda+k]*B[k+j*ldb];
                }
                C[i*ldc+j] += sum;
            }
        }
    }else if(TA && !TB){
        for(i = 0; i < M; ++i){
            for(k = 0; k < K; ++k){
                register float A_PART = ALPHA*A[k*lda+i];
                for(j = 0; j < N; ++j){
                    C[i*ldc+j] += A_PART*B[k*ldb+j];
                }
            }
        }
    }else{
        for(i = 0; i < M; ++i){
            for(k = 0; k < K; ++k){
                register double A_PART = ALPHA*A[i*lda+k];
                register float A_PART = ALPHA*A[i*lda+k];
                for(j = 0; j < N; ++j){
                    C[i*ldc+j] += A_PART*B[k*ldb+j];
                }
@@ -44,7 +55,7 @@
    }
}
void im2row(double *image, int h, int w, int c, int size, int stride, double *matrix)
void im2row(float *image, int h, int w, int c, int size, int stride, float *matrix)
{
    int i;
    int mc = c;
@@ -64,7 +75,7 @@
        matrix[i] = image[pc*h*w+ph*w+pw];
    }
}
void im2col(double *image, int h, int w, int c, int size, int stride, double *matrix)
void im2col(float *image, int h, int w, int c, int size, int stride, float *matrix)
{
    int b,p;
    int blocks = ((h-size)/stride+1)*((w-size)/stride+1);
@@ -84,9 +95,9 @@
}
//From Berkeley Vision's Caffe!
void im2col_cpu(double* data_im, const int channels,
void im2col_cpu(float* data_im, const int channels,
        const int height, const int width, const int ksize, const int stride,
        double* data_col)
        float* data_col)
{
    int c,h,w;
    int height_col = (height - ksize) / stride + 1;
@@ -106,3 +117,59 @@
    }
}
void col2im_cpu(float* data_col, const int channels,
        const int height, const int width, const int ksize, const int stride,
        float* data_im)
{
    int c,h,w;
    int height_col = (height - ksize) / stride + 1;
    int width_col = (width - ksize) / stride + 1;
    int channels_col = channels * ksize * ksize;
    for ( c = 0; c < channels_col; ++c) {
        int w_offset = c % ksize;
        int h_offset = (c / ksize) % ksize;
        int c_im = c / ksize / ksize;
        for ( h = 0; h < height_col; ++h) {
            for ( w = 0; w < width_col; ++w) {
                data_im[(c_im * height + h * stride + h_offset) * width
                    + w * stride + w_offset]+= data_col[(c * height_col + h) * width_col + w];
            }
        }
    }
}
float *random_matrix(int rows, int cols)
{
    int i;
    float *m = calloc(rows*cols, sizeof(float));
    for(i = 0; i < rows*cols; ++i){
        m[i] = (float)rand()/RAND_MAX;
    }
    return m;
}
void time_random_matrix(int TA, int TB, int m, int k, int n)
{
    float *a = random_matrix(m,k);
    float *b = random_matrix(k,n);
    float *c = random_matrix(m,n);
    int i;
    clock_t start = clock(), end;
    for(i = 0; i<1000; ++i){
        gemm(TA,TB,m,n,k,1,a,k,b,n,1,c,n);
    }
    end = clock();
    printf("Matrix Multiplication %dx%d * %dx%d, TA=%d, TB=%d: %lf ms\n",m,k,k,n, TA, TB, (double)(end-start)/CLOCKS_PER_SEC);
}
void test_blas()
{
    time_random_matrix(0,0,100,100,100);
    time_random_matrix(1,0,100,100,100);
    time_random_matrix(0,1,100,100,100);
    time_random_matrix(0,1,1000,100,100);
    time_random_matrix(1,0,1000,100,100);
}
src/mini_blas.h
@@ -1,11 +1,15 @@
void pm(int M, int N, double *A);
void gemm(int TA, int TB, int M, int N, int K, double ALPHA,
                    double *A, int lda,
                    double *B, int ldb,
                    double BETA,
                    double *C, int ldc);
void im2row(double *image, int h, int w, int c, int size, int stride, double *matrix);
void im2col(double *image, int h, int w, int c, int size, int stride, double *matrix);
void im2col_cpu(double* data_im, const int channels,
void pm(int M, int N, float *A);
void gemm(int TA, int TB, int M, int N, int K, float ALPHA,
                    float *A, int lda,
                    float *B, int ldb,
                    float BETA,
                    float *C, int ldc);
void im2row(float *image, int h, int w, int c, int size, int stride, float *matrix);
void im2col(float *image, int h, int w, int c, int size, int stride, float *matrix);
void im2col_cpu(float* data_im, const int channels,
        const int height, const int width, const int ksize, const int stride,
        double* data_col);
        float* data_col);
void col2im_cpu(float* data_col, const int channels,
        const int height, const int width, const int ksize, const int stride,
        float* data_im);
void test_blas();
src/network.c
@@ -21,7 +21,7 @@
    return net;
}
void forward_network(network net, double *input)
void forward_network(network net, float *input)
{
    int i;
    for(i = 0; i < net.n; ++i){
@@ -48,7 +48,7 @@
    }
}
void update_network(network net, double step, double momentum, double decay)
void update_network(network net, float step, float momentum, float decay)
{
    int i;
    for(i = 0; i < net.n; ++i){
@@ -69,7 +69,7 @@
    }
}
double *get_network_output_layer(network net, int i)
float *get_network_output_layer(network net, int i)
{
    if(net.types[i] == CONVOLUTIONAL){
        convolutional_layer layer = *(convolutional_layer *)net.layers[i];
@@ -86,12 +86,12 @@
    }
    return 0;
}
double *get_network_output(network net)
float *get_network_output(network net)
{
    return get_network_output_layer(net, net.n-1);
}
double *get_network_delta_layer(network net, int i)
float *get_network_delta_layer(network net, int i)
{
    if(net.types[i] == CONVOLUTIONAL){
        convolutional_layer layer = *(convolutional_layer *)net.layers[i];
@@ -109,16 +109,16 @@
    return 0;
}
double *get_network_delta(network net)
float *get_network_delta(network net)
{
    return get_network_delta_layer(net, net.n-1);
}
double calculate_error_network(network net, double *truth)
float calculate_error_network(network net, float *truth)
{
    double sum = 0;
    double *delta = get_network_delta(net);
    double *out = get_network_output(net);
    float sum = 0;
    float *delta = get_network_delta(net);
    float *out = get_network_output(net);
    int i, k = get_network_output_size(net);
    for(i = 0; i < k; ++i){
        delta[i] = truth[i] - out[i];
@@ -129,17 +129,17 @@
int get_predicted_class_network(network net)
{
    double *out = get_network_output(net);
    float *out = get_network_output(net);
    int k = get_network_output_size(net);
    return max_index(out, k);
}
double backward_network(network net, double *input, double *truth)
float backward_network(network net, float *input, float *truth)
{
    double error = calculate_error_network(net, truth);
    float error = calculate_error_network(net, truth);
    int i;
    double *prev_input;
    double *prev_delta;
    float *prev_input;
    float *prev_delta;
    for(i = net.n-1; i >= 0; --i){
        if(i == 0){
            prev_input = input;
@@ -152,7 +152,7 @@
            convolutional_layer layer = *(convolutional_layer *)net.layers[i];
            learn_convolutional_layer(layer);
            //learn_convolutional_layer(layer);
            //if(i != 0) backward_convolutional_layer(layer, prev_input, prev_delta);
            if(i != 0) backward_convolutional_layer(layer, prev_delta);
        }
        else if(net.types[i] == MAXPOOL){
            maxpool_layer layer = *(maxpool_layer *)net.layers[i];
@@ -171,49 +171,49 @@
    return error;
}
double train_network_datum(network net, double *x, double *y, double step, double momentum, double decay)
float train_network_datum(network net, float *x, float *y, float step, float momentum, float decay)
{
        forward_network(net, x);
        int class = get_predicted_class_network(net);
        double error = backward_network(net, x, y);
        float error = backward_network(net, x, y);
        update_network(net, step, momentum, decay);
        //return (y[class]?1:0);
        return error;
}
double train_network_sgd(network net, data d, int n, double step, double momentum,double decay)
float train_network_sgd(network net, data d, int n, float step, float momentum,float decay)
{
    int i;
    double error = 0;
    float error = 0;
    for(i = 0; i < n; ++i){
        int index = rand()%d.X.rows;
        error += train_network_datum(net, d.X.vals[index], d.y.vals[index], step, momentum, decay);
        //if((i+1)%10 == 0){
        //    printf("%d: %f\n", (i+1), (double)correct/(i+1));
        //    printf("%d: %f\n", (i+1), (float)correct/(i+1));
        //}
    }
    return error/n;
}
double train_network_batch(network net, data d, int n, double step, double momentum,double decay)
float train_network_batch(network net, data d, int n, float step, float momentum,float decay)
{
    int i;
    int correct = 0;
    for(i = 0; i < n; ++i){
        int index = rand()%d.X.rows;
        double *x = d.X.vals[index];
        double *y = d.y.vals[index];
        float *x = d.X.vals[index];
        float *y = d.y.vals[index];
        forward_network(net, x);
        int class = get_predicted_class_network(net);
        backward_network(net, x, y);
        correct += (y[class]?1:0);
    }
    update_network(net, step, momentum, decay);
    return (double)correct/n;
    return (float)correct/n;
}
void train_network(network net, data d, double step, double momentum, double decay)
void train_network(network net, data d, float step, float momentum, float decay)
{
    int i;
    int correct = 0;
@@ -226,7 +226,7 @@
    }
    visualize_network(net);
    cvWaitKey(100);
    printf("Accuracy: %f\n", (double)correct/d.X.rows);
    printf("Accuracy: %f\n", (float)correct/d.X.rows);
}
int get_network_output_size_layer(network net, int i)
@@ -294,10 +294,10 @@
    } 
}
double *network_predict(network net, double *input)
float *network_predict(network net, float *input)
{
    forward_network(net, input);
    double *out = get_network_output(net);
    float *out = get_network_output(net);
    return out;
}
@@ -307,7 +307,7 @@
    int k = get_network_output_size(net);
    matrix pred = make_matrix(test.X.rows, k);
    for(i = 0; i < test.X.rows; ++i){
        double *out = network_predict(net, test.X.vals[i]);
        float *out = network_predict(net, test.X.vals[i]);
        for(j = 0; j < k; ++j){
            pred.vals[i][j] = out[j];
        }
@@ -319,7 +319,7 @@
{
    int i,j;
    for(i = 0; i < net.n; ++i){
        double *output = 0;
        float *output = 0;
        int n = 0;
        if(net.types[i] == CONVOLUTIONAL){
            convolutional_layer layer = *(convolutional_layer *)net.layers[i];
@@ -343,8 +343,8 @@
            output = layer.output;
            n = layer.inputs;
        }
        double mean = mean_array(output, n);
        double vari = variance_array(output, n);
        float mean = mean_array(output, n);
        float vari = variance_array(output, n);
        fprintf(stderr, "Layer %d - Mean: %f, Variance: %f\n",i,mean, vari);
        if(n > 100) n = 100;
        for(j = 0; j < n; ++j) fprintf(stderr, "%f, ", output[j]);
@@ -353,10 +353,10 @@
    }
}
double network_accuracy(network net, data d)
float network_accuracy(network net, data d)
{
    matrix guess = network_predict_data(net, d);
    double acc = matrix_accuracy(d.y, guess);
    float acc = matrix_accuracy(d.y, guess);
    free_matrix(guess);
    return acc;
}
src/network.h
@@ -17,22 +17,22 @@
    void **layers;
    LAYER_TYPE *types;
    int outputs;
    double *output;
    float *output;
} network;
network make_network(int n);
void forward_network(network net, double *input);
double backward_network(network net, double *input, double *truth);
void update_network(network net, double step, double momentum, double decay);
double train_network_sgd(network net, data d, int n, double step, double momentum,double decay);
double train_network_batch(network net, data d, int n, double step, double momentum,double decay);
void train_network(network net, data d, double step, double momentum, double decay);
void forward_network(network net, float *input);
float backward_network(network net, float *input, float *truth);
void update_network(network net, float step, float momentum, float decay);
float train_network_sgd(network net, data d, int n, float step, float momentum,float decay);
float train_network_batch(network net, data d, int n, float step, float momentum,float decay);
void train_network(network net, data d, float step, float momentum, float decay);
matrix network_predict_data(network net, data test);
double network_accuracy(network net, data d);
double *get_network_output(network net);
double *get_network_output_layer(network net, int i);
double *get_network_delta_layer(network net, int i);
double *get_network_delta(network net);
float network_accuracy(network net, data d);
float *get_network_output(network net);
float *get_network_output_layer(network net, int i);
float *get_network_delta_layer(network net, int i);
float *get_network_delta(network net);
int get_network_output_size_layer(network net, int i);
int get_network_output_size(network net);
image get_network_image(network net);
src/option_list.c
@@ -59,7 +59,7 @@
    return def;
}
double option_find_double(list *l, char *key, double def)
float option_find_float(list *l, char *key, float def)
{
    char *v = option_find(l, key);
    if(v) return atof(v);
src/option_list.h
@@ -6,7 +6,7 @@
char *option_find(list *l, char *key);
char *option_find_str(list *l, char *key, char *def);
int option_find_int(list *l, char *key, int def);
double option_find_double(list *l, char *key, double def);
float option_find_float(list *l, char *key, float def);
void option_unused(list *l);
#endif
src/softmax_layer.c
@@ -8,15 +8,16 @@
    fprintf(stderr, "Softmax Layer: %d inputs\n", inputs);
    softmax_layer *layer = calloc(1, sizeof(softmax_layer));
    layer->inputs = inputs;
    layer->output = calloc(inputs, sizeof(double));
    layer->delta = calloc(inputs, sizeof(double));
    layer->output = calloc(inputs, sizeof(float));
    layer->delta = calloc(inputs, sizeof(float));
    return layer;
}
void forward_softmax_layer(const softmax_layer layer, double *input)
/* UNSTABLE!
void forward_softmax_layer(const softmax_layer layer, float *input)
{
    int i;
    double sum = 0;
    float sum = 0;
    for(i = 0; i < layer.inputs; ++i){
        sum += exp(input[i]);
    }
@@ -24,8 +25,25 @@
        layer.output[i] = exp(input[i])/sum;
    }
}
*/
void forward_softmax_layer(const softmax_layer layer, float *input)
{
    int i;
    float sum = 0;
    float largest = 0;
    for(i = 0; i < layer.inputs; ++i){
        if(input[i] > largest) largest = input[i];
    }
    for(i = 0; i < layer.inputs; ++i){
        sum += exp(input[i]-largest);
    }
    sum = largest+log(sum);
    for(i = 0; i < layer.inputs; ++i){
        layer.output[i] = exp(input[i]-sum);
    }
}
void backward_softmax_layer(const softmax_layer layer, double *input, double *delta)
void backward_softmax_layer(const softmax_layer layer, float *input, float *delta)
{
    int i;
    for(i = 0; i < layer.inputs; ++i){
src/softmax_layer.h
@@ -3,12 +3,12 @@
typedef struct {
    int inputs;
    double *delta;
    double *output;
    float *delta;
    float *output;
} softmax_layer;
softmax_layer *make_softmax_layer(int inputs);
void forward_softmax_layer(const softmax_layer layer, double *input);
void backward_softmax_layer(const softmax_layer layer, double *input, double *delta);
void forward_softmax_layer(const softmax_layer layer, float *input);
void backward_softmax_layer(const softmax_layer layer, float *input, float *delta);
#endif
src/tests.c
@@ -14,6 +14,9 @@
#include <stdlib.h>
#include <stdio.h>
#define _GNU_SOURCE
#include <fenv.h>
void test_convolve()
{
    image dog = load_image("dog.jpg");
@@ -26,7 +29,7 @@
        convolve(dog, kernel, 1, 0, edge, 1);
    }
    end = clock();
    printf("Convolutions: %lf seconds\n", (double)(end-start)/CLOCKS_PER_SEC);
    printf("Convolutions: %lf seconds\n", (float)(end-start)/CLOCKS_PER_SEC);
    show_image_layers(edge, "Test Convolve");
}
@@ -38,11 +41,11 @@
    int size = 11;
    int stride = 4;
    int n = 40;
    double *filters = make_random_image(size, size, dog.c*n).data;
    float *filters = make_random_image(size, size, dog.c*n).data;
    int mw = ((dog.h-size)/stride+1)*((dog.w-size)/stride+1);
    int mh = (size*size*dog.c);
    double *matrix = calloc(mh*mw, sizeof(double));
    float *matrix = calloc(mh*mw, sizeof(float));
    image edge = make_image((dog.h-size)/stride+1, (dog.w-size)/stride+1, n);
@@ -54,7 +57,7 @@
        gemm(0,0,n,mw,mh,1,filters,mh,matrix,mw,1,edge.data,mw);
    }
    end = clock();
    printf("Convolutions: %lf seconds\n", (double)(end-start)/CLOCKS_PER_SEC);
    printf("Convolutions: %lf seconds\n", (float)(end-start)/CLOCKS_PER_SEC);
    show_image_layers(edge, "Test Convolve");
    cvWaitKey(0);
}
@@ -72,11 +75,11 @@
    int n = 1;
    int stride = 1;
    int size = 3;
    double eps = .00000001;
    float eps = .00000001;
    image test = make_random_image(5,5, 1);
    convolutional_layer layer = *make_convolutional_layer(test.h,test.w,test.c, n, size, stride, RELU);
    image out = get_convolutional_image(layer);
    double **jacobian = calloc(test.h*test.w*test.c, sizeof(double));
    float **jacobian = calloc(test.h*test.w*test.c, sizeof(float));
    
    forward_convolutional_layer(layer, test.data);
    image base = copy_image(out);
@@ -90,19 +93,19 @@
        jacobian[i] = partial.data;
        test.data[i] -= eps;
    }
    double **jacobian2 = calloc(out.h*out.w*out.c, sizeof(double));
    float **jacobian2 = calloc(out.h*out.w*out.c, sizeof(float));
    image in_delta = make_image(test.h, test.w, test.c);
    image out_delta = get_convolutional_delta(layer);
    for(i = 0; i < out.h*out.w*out.c; ++i){
        out_delta.data[i] = 1;
        //backward_convolutional_layer(layer, test.data, in_delta.data);
        backward_convolutional_layer(layer, in_delta.data);
        image partial = copy_image(in_delta);
        jacobian2[i] = partial.data;
        out_delta.data[i] = 0;
    }
    int j;
    double *j1 = calloc(test.h*test.w*test.c*out.h*out.w*out.c, sizeof(double));
    double *j2 = calloc(test.h*test.w*test.c*out.h*out.w*out.c, sizeof(double));
    float *j1 = calloc(test.h*test.w*test.c*out.h*out.w*out.c, sizeof(float));
    float *j2 = calloc(test.h*test.w*test.c*out.h*out.w*out.c, sizeof(float));
    for(i = 0; i < test.h*test.w*test.c; ++i){
        for(j =0 ; j < out.h*out.w*out.c; ++j){
            j1[i*out.h*out.w*out.c + j] = jacobian[i][j];
@@ -112,12 +115,11 @@
    }
    image mj1 = double_to_image(test.w*test.h*test.c, out.w*out.h*out.c, 1, j1);
    image mj2 = double_to_image(test.w*test.h*test.c, out.w*out.h*out.c, 1, j2);
    image mj1 = float_to_image(test.w*test.h*test.c, out.w*out.h*out.c, 1, j1);
    image mj2 = float_to_image(test.w*test.h*test.c, out.w*out.h*out.c, 1, j2);
    printf("%f %f\n", avg_image_layer(mj1,0), avg_image_layer(mj2,0));
    show_image(mj1, "forward jacobian");
    show_image(mj2, "backward jacobian");
}
void test_load()
@@ -145,7 +147,7 @@
        rotate_image(dog);
    }
    end = clock();
    printf("Rotations: %lf seconds\n", (double)(end-start)/CLOCKS_PER_SEC);
    printf("Rotations: %lf seconds\n", (float)(end-start)/CLOCKS_PER_SEC);
    show_image(dog, "Test Rotate");
    image random = make_random_image(3,3,3);
@@ -159,18 +161,18 @@
void test_parser()
{
    network net = parse_network_cfg("test_parser.cfg");
    double input[1];
    float input[1];
    int count = 0;
        
    double avgerr = 0;
    float avgerr = 0;
    while(++count < 100000000){
        double v = ((double)rand()/RAND_MAX);
        double truth = v*v;
        float v = ((float)rand()/RAND_MAX);
        float truth = v*v;
        input[0] = v;
        forward_network(net, input);
        double *out = get_network_output(net);
        double *delta = get_network_delta(net);
        double err = pow((out[0]-truth),2.);
        float *out = get_network_output(net);
        float *delta = get_network_delta(net);
        float err = pow((out[0]-truth),2.);
        avgerr = .99 * avgerr + .01 * err;
        if(count % 1000000 == 0) printf("%f %f :%f AVG %f \n", truth, out[0], err, avgerr);
        delta[0] = truth - out[0];
@@ -192,9 +194,9 @@
    srand(0);
    int i = 0;
    char *labels[] = {"cat","dog"};
    double lr = .00001;
    double momentum = .9;
    double decay = 0.01;
    float lr = .00001;
    float momentum = .9;
    float decay = 0.01;
    while(i++ < 1000 || 1){
        data train = load_data_image_pathfile_random("train_paths.txt", 1000, labels, 2);
        train_network(net, train, lr, momentum, decay);
@@ -207,32 +209,33 @@
{
    srand(444444);
    srand(888888);
    network net = parse_network_cfg("nist_basic.cfg");
    network net = parse_network_cfg("nist.cfg");
    data train = load_categorical_data_csv("mnist/mnist_train.csv", 0, 10);
    data test = load_categorical_data_csv("mnist/mnist_test.csv",0,10);
    normalize_data_rows(train);
    normalize_data_rows(test);
    //randomize_data(train);
    int count = 0;
    double lr = .0005;
    double momentum = .9;
    double decay = 0.01;
    float lr = .0005;
    float momentum = .9;
    float decay = 0.01;
    clock_t start = clock(), end;
    while(++count <= 100){
        visualize_network(net);
        double loss = train_network_sgd(net, train, 10000, lr, momentum, decay);
        //visualize_network(net);
        float loss = train_network_sgd(net, train, 1000, lr, momentum, decay);
        printf("%5d Training Loss: %lf, Params: %f %f %f, ",count*100, loss, lr, momentum, decay);
        end = clock();
        printf("Time: %lf seconds\n", (double)(end-start)/CLOCKS_PER_SEC);
        printf("Time: %lf seconds\n", (float)(end-start)/CLOCKS_PER_SEC);
        start=end;
        cvWaitKey(100);
        //lr /= 2; 
        if(count%5 == 0){
            double train_acc = network_accuracy(net, train);
            float train_acc = network_accuracy(net, train);
            fprintf(stderr, "\nTRAIN: %f\n", train_acc);
            double test_acc = network_accuracy(net, test);
            float test_acc = network_accuracy(net, test);
            fprintf(stderr, "TEST: %f\n\n", test_acc);
            printf("%d, %f, %f\n", count, train_acc, test_acc);
            lr *= .5;
        }
    }
}
@@ -253,24 +256,24 @@
    int n = 30;
    for(i = 0; i < n; ++i){
        int count = 0;
        double lr = .0005;
        double momentum = .9;
        double decay = .01;
        float lr = .0005;
        float momentum = .9;
        float decay = .01;
        network net = parse_network_cfg("nist.cfg");
        while(++count <= 15){
            double acc = train_network_sgd(net, train, train.X.rows, lr, momentum, decay);
            float acc = train_network_sgd(net, train, train.X.rows, lr, momentum, decay);
            printf("Training Accuracy: %lf Learning Rate: %f Momentum: %f Decay: %f\n", acc, lr, momentum, decay );
            lr /= 2; 
        }
        matrix partial = network_predict_data(net, test);
        double acc = matrix_accuracy(test.y, partial);
        float acc = matrix_accuracy(test.y, partial);
        printf("Model Accuracy: %lf\n", acc);
        matrix_add_matrix(partial, prediction);
        acc = matrix_accuracy(test.y, prediction);
        printf("Current Ensemble Accuracy: %lf\n", acc);
        free_matrix(partial);
    }
    double acc = matrix_accuracy(test.y, prediction);
    float acc = matrix_accuracy(test.y, prediction);
    printf("Full Ensemble Accuracy: %lf\n", acc);
}
@@ -279,19 +282,19 @@
    network net = parse_network_cfg("connected.cfg");
    matrix m = csv_to_matrix("train.csv");
    //matrix ho = hold_out_matrix(&m, 2500);
    double *truth = pop_column(&m, 0);
    //double *ho_truth = pop_column(&ho, 0);
    float *truth = pop_column(&m, 0);
    //float *ho_truth = pop_column(&ho, 0);
    int i;
    clock_t start = clock(), end;
    int count = 0;
    while(++count <= 300){
        for(i = 0; i < m.rows; ++i){
            int index = rand()%m.rows;
            //image p = double_to_image(1690,1,1,m.vals[index]);
            //image p = float_to_image(1690,1,1,m.vals[index]);
            //normalize_image(p);
            forward_network(net, m.vals[index]);
            double *out = get_network_output(net);
            double *delta = get_network_delta(net);
            float *out = get_network_output(net);
            float *delta = get_network_delta(net);
            //printf("%f\n", out[0]);
            delta[0] = truth[index] - out[0];
            // printf("%f\n", delta[0]);
@@ -299,8 +302,8 @@
            //backward_network(net, m.vals[index], );
            update_network(net, .00001, 0,0);
        }
        //double test_acc = error_network(net, m, truth);
        //double valid_acc = error_network(net, ho, ho_truth);
        //float test_acc = error_network(net, m, truth);
        //float valid_acc = error_network(net, ho, ho_truth);
        //printf("%f, %f\n", test_acc, valid_acc);
        //fprintf(stderr, "%5d: %f Valid: %f\n",count, test_acc, valid_acc);
        //if(valid_acc > .70) break;
@@ -311,12 +314,12 @@
    truth = pop_column(&test, 0);
    for(i = 0; i < test.rows; ++i){
        forward_network(net, test.vals[i]);
        double *out = get_network_output(net);
        float *out = get_network_output(net);
        if(fabs(out[0]) < .5) fprintf(fp, "0\n");
        else fprintf(fp, "1\n");
    }
    fclose(fp);
    printf("Neural Net Learning: %lf seconds\n", (double)(end-start)/CLOCKS_PER_SEC);
    printf("Neural Net Learning: %lf seconds\n", (float)(end-start)/CLOCKS_PER_SEC);
}
void test_split()
@@ -326,30 +329,6 @@
    printf("%d, %d, %d\n", train.X.rows, split[0].X.rows, split[1].X.rows);
}
double *random_matrix(int rows, int cols)
{
    int i, j;
    double *m = calloc(rows*cols, sizeof(double));
    for(i = 0; i < rows; ++i){
        for(j = 0; j < cols; ++j){
            m[i*cols+j] = (double)rand()/RAND_MAX;
        }
    }
    return m;
}
void test_blas()
{
    int m = 1000, n = 1000, k = 1000;
    double *a = random_matrix(m,k);
    double *b = random_matrix(k,n);
    double *c = random_matrix(m,n);
    int i;
    for(i = 0; i<1000; ++i){
        gemm(0,0,m,n,k,1,a,k,b,n,1,c,n);
    }
}
void test_im2row()
{
    int h = 20;
@@ -362,16 +341,18 @@
    int mw = ((h-size)/stride+1)*((w-size)/stride+1);
    int mh = (size*size*c);
    int msize = mc*mw*mh;
    double *matrix = calloc(msize, sizeof(double));
    float *matrix = calloc(msize, sizeof(float));
    int i;
    for(i = 0; i < 1000; ++i){
        im2col_cpu(test.data,  c,  h,  w,  size,  stride, matrix);
        image render = double_to_image(mh, mw, mc, matrix);
        image render = float_to_image(mh, mw, mc, matrix);
    }
}
int main()
{
    //feenableexcept(FE_DIVBYZERO | FE_INVALID | FE_OVERFLOW);
    //test_blas();
    //test_convolve_matrix();
    //    test_im2row();
src/utils.c
@@ -123,9 +123,9 @@
    return count;
}
double *parse_fields(char *line, int n)
float *parse_fields(char *line, int n)
{
    double *field = calloc(n, sizeof(double));
    float *field = calloc(n, sizeof(float));
    char *c, *p, *end;
    int count = 0;
    int done = 0;
@@ -143,36 +143,36 @@
    return field;
}
double mean_array(double *a, int n)
float mean_array(float *a, int n)
{
    int i;
    double sum = 0;
    float sum = 0;
    for(i = 0; i < n; ++i) sum += a[i];
    return sum/n;
}
double variance_array(double *a, int n)
float variance_array(float *a, int n)
{
    int i;
    double sum = 0;
    double mean = mean_array(a, n);
    float sum = 0;
    float mean = mean_array(a, n);
    for(i = 0; i < n; ++i) sum += (a[i] - mean)*(a[i]-mean);
    double variance = sum/n;
    float variance = sum/n;
    return variance;
}
double constrain(double a, double max)
float constrain(float a, float max)
{
    if(a > abs(max)) return abs(max);
    if(a < -abs(max)) return -abs(max);
    return a;
}
void normalize_array(double *a, int n)
void normalize_array(float *a, int n)
{
    int i;
    double mu = mean_array(a,n);
    double sigma = sqrt(variance_array(a,n));
    float mu = mean_array(a,n);
    float sigma = sqrt(variance_array(a,n));
    for(i = 0; i < n; ++i){
        a[i] = (a[i] - mu)/sigma;
    }
@@ -180,7 +180,7 @@
    sigma = sqrt(variance_array(a,n));
}
void translate_array(double *a, int n, double s)
void translate_array(float *a, int n, float s)
{
    int i;
    for(i = 0; i < n; ++i){
@@ -188,18 +188,18 @@
    }
}
void scale_array(double *a, int n, double s)
void scale_array(float *a, int n, float s)
{
    int i;
    for(i = 0; i < n; ++i){
        a[i] *= s;
    }
}
int max_index(double *a, int n)
int max_index(float *a, int n)
{
    if(n <= 0) return -1;
    int i, max_i = 0;
    double max = a[0];
    float max = a[0];
    for(i = 1; i < n; ++i){
        if(a[i] > max){
            max = a[i];
@@ -209,20 +209,20 @@
    return max_i;
}
double rand_normal()
float rand_normal()
{
    int i;
    double sum= 0;
    for(i = 0; i < 12; ++i) sum += (double)rand()/RAND_MAX;
    float sum= 0;
    for(i = 0; i < 12; ++i) sum += (float)rand()/RAND_MAX;
    return sum-6.;
}
double **one_hot_encode(double *a, int n, int k)
float **one_hot_encode(float *a, int n, int k)
{
    int i;
    double **t = calloc(n, sizeof(double*));
    float **t = calloc(n, sizeof(float*));
    for(i = 0; i < n; ++i){
        t[i] = calloc(k, sizeof(double));
        t[i] = calloc(k, sizeof(float));
        int index = (int)a[i];
        t[i][index] = 1;
    }
src/utils.h
@@ -13,15 +13,15 @@
list *parse_csv_line(char *line);
char *copy_string(char *s);
int count_fields(char *line);
double *parse_fields(char *line, int n);
void normalize_array(double *a, int n);
void scale_array(double *a, int n, double s);
void translate_array(double *a, int n, double s);
int max_index(double *a, int n);
double constrain(double a, double max);
double rand_normal();
double mean_array(double *a, int n);
double variance_array(double *a, int n);
double **one_hot_encode(double *a, int n, int k);
float *parse_fields(char *line, int n);
void normalize_array(float *a, int n);
void scale_array(float *a, int n, float s);
void translate_array(float *a, int n, float s);
int max_index(float *a, int n);
float constrain(float a, float max);
float rand_normal();
float mean_array(float *a, int n);
float variance_array(float *a, int n);
float **one_hot_encode(float *a, int n, int k);
#endif