Joseph Redmon
2014-05-09 cd8d53df21f3ad2810add2a8cff766c745f55a17
So there WAS this huge bug. Gone now
14 files modified
8 files added
908 ■■■■ changed files
Makefile 5 ●●●●● patch | view | raw | blame | history
src/activations.c 56 ●●●● patch | view | raw | blame | history
src/activations.cl 28 ●●●●● patch | view | raw | blame | history
src/activations.h 8 ●●●● patch | view | raw | blame | history
src/axpy.c 14 ●●●●● patch | view | raw | blame | history
src/axpy.cl patch | view | raw | blame | history
src/col2im.c patch | view | raw | blame | history
src/col2im.cl patch | view | raw | blame | history
src/connected_layer.c 32 ●●●●● patch | view | raw | blame | history
src/connected_layer.h 7 ●●●●● patch | view | raw | blame | history
src/convolutional_layer.c 107 ●●●● patch | view | raw | blame | history
src/convolutional_layer.h 23 ●●●●● patch | view | raw | blame | history
src/gemm.c 283 ●●●●● patch | view | raw | blame | history
src/im2col.c 121 ●●●●● patch | view | raw | blame | history
src/im2col.cl 26 ●●●●● patch | view | raw | blame | history
src/mini_blas.h 29 ●●●● patch | view | raw | blame | history
src/network.c 103 ●●●● patch | view | raw | blame | history
src/network.h 9 ●●●● patch | view | raw | blame | history
src/opencl.c 14 ●●●●● patch | view | raw | blame | history
src/opencl.h 8 ●●●● patch | view | raw | blame | history
src/parser.c 3 ●●●● patch | view | raw | blame | history
src/tests.c 32 ●●●●● patch | view | raw | blame | history
Makefile
@@ -1,18 +1,19 @@
CC=gcc
GPU=1
GPU=0
COMMON=-Wall -Werror -Wfatal-errors `pkg-config --cflags opencv` -I/usr/local/cuda/include/
ifeq ($(GPU), 1) 
COMMON+=-DGPU
else
endif
UNAME = $(shell uname)
OPTS=-O3 -flto
OPTS=-Ofast -flto
ifeq ($(UNAME), Darwin)
COMMON+= -isystem /usr/local/Cellar/opencv/2.4.6.1/include/opencv -isystem /usr/local/Cellar/opencv/2.4.6.1/include
ifeq ($(GPU), 1)
LDFLAGS= -framework OpenCL
endif
else
OPTS+= -march=native
ifeq ($(GPU), 1)
LDFLAGS= -lOpenCL
endif
src/activations.c
@@ -2,6 +2,7 @@
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
char *get_activation_string(ACTIVATION a)
@@ -40,27 +41,29 @@
float ramp_activate(float x){return x*(x>0)+.1*x;}
float tanh_activate(float x){return (exp(2*x)-1)/(exp(2*x)+1);}
float activate(float x, ACTIVATION a){
float activate(float x, ACTIVATION a, float dropout)
{
    if((float)rand()/RAND_MAX < dropout) return 0;
    switch(a){
        case LINEAR:
            return linear_activate(x);
            return linear_activate(x)/(1-dropout);
        case SIGMOID:
            return sigmoid_activate(x);
            return sigmoid_activate(x)/(1-dropout);
        case RELU:
            return relu_activate(x);
            return relu_activate(x)/(1-dropout);
        case RAMP:
            return ramp_activate(x);
            return ramp_activate(x)/(1-dropout);
        case TANH:
            return tanh_activate(x);
            return tanh_activate(x)/(1-dropout);
    }
    return 0;
}
void activate_array(float *x, const int n, const ACTIVATION a)
void activate_array(float *x, const int n, const ACTIVATION a, float dropout)
{
    int i;
    for(i = 0; i < n; ++i){
        x[i] = activate(x[i], a);
        x[i] = activate(x[i], a, dropout);
    }
}
@@ -89,3 +92,40 @@
    }
#ifdef GPU
#include "opencl.h"
#include <math.h>
cl_kernel get_activation_kernel()
{
    static int init = 0;
    static cl_kernel kernel;
    if(!init){
        kernel = get_kernel("src/activations.cl", "activate_array", 0);
        init = 1;
    }
    return kernel;
}
void activate_array_ongpu(cl_mem x, int n, ACTIVATION a, float dropout)
{
    cl_setup();
    cl_kernel kernel = get_activation_kernel();
    cl_command_queue queue = cl.queue;
    cl_uint i = 0;
    cl.error = clSetKernelArg(kernel, i++, sizeof(x), (void*) &x);
    cl.error = clSetKernelArg(kernel, i++, sizeof(n), (void*) &n);
    cl.error = clSetKernelArg(kernel, i++, sizeof(a), (void*) &a);
    cl.error = clSetKernelArg(kernel, i++, sizeof(dropout),
        (void*) &dropout);
    check_error(cl);
    size_t gsize = n;
    clEnqueueNDRangeKernel(queue, kernel, 1, 0, &gsize, 0, 0, 0, 0);
    check_error(cl);
}
#endif
src/activations.cl
New file
@@ -0,0 +1,28 @@
typedef enum{
    SIGMOID, RELU, LINEAR, RAMP, TANH
}ACTIVATION;
float activate(float x, ACTIVATION a, float dropout)
{
    //if((float)rand()/RAND_MAX < dropout) return 0;
    switch(a){
        case LINEAR:
            return linear_activate(x)/(1-dropout);
        case SIGMOID:
            return sigmoid_activate(x)/(1-dropout);
        case RELU:
            return relu_activate(x)/(1-dropout);
        case RAMP:
            return ramp_activate(x)/(1-dropout);
        case TANH:
            return tanh_activate(x)/(1-dropout);
    }
    return 0;
}
__kernel void activate_array(__global float *x,
    const int n, const ACTIVATION a, const float dropout)
{
    int i = get_global_id(0);
    x[i] = activate(x[i], a, dropout);
}
src/activations.h
@@ -1,3 +1,4 @@
#include "opencl.h"
#ifndef ACTIVATIONS_H
#define ACTIVATIONS_H
@@ -8,10 +9,13 @@
ACTIVATION get_activation(char *s);
char *get_activation_string(ACTIVATION a);
float activate(float x, ACTIVATION a);
float activate(float x, ACTIVATION a, float dropout);
float gradient(float x, ACTIVATION a);
void gradient_array(const float *x, const int n, const ACTIVATION a, float *delta);
void activate_array(float *x, const int n, const ACTIVATION a);
void activate_array(float *x, const int n, const ACTIVATION a, float dropout);
#ifdef GPU
void activate_array_ongpu(cl_mem x, int n, ACTIVATION a, float dropout);
#endif
#endif
src/axpy.c
New file
@@ -0,0 +1,14 @@
#include "mini_blas.h"
void axpy_cpu(int N, float ALPHA, float *X, int INCX, float *Y, int INCY)
{
    int i;
    for(i = 0; i < N; ++i) Y[i*INCY] += ALPHA*X[i*INCX];
}
void scal_cpu(int N, float ALPHA, float *X, int INCX)
{
    int i;
    for(i = 0; i < N; ++i) X[i*INCX] *= ALPHA;
}
src/axpy.cl
src/col2im.c
src/col2im.cl
src/connected_layer.c
@@ -7,7 +7,7 @@
#include <stdlib.h>
#include <string.h>
connected_layer *make_connected_layer(int batch, int inputs, int outputs, ACTIVATION activation)
connected_layer *make_connected_layer(int batch, int inputs, int outputs, float dropout, ACTIVATION activation)
{
    fprintf(stderr, "Connected Layer: %d inputs, %d outputs\n", inputs, outputs);
    int i;
@@ -15,6 +15,7 @@
    layer->inputs = inputs;
    layer->outputs = outputs;
    layer->batch=batch;
    layer->dropout = dropout;
    layer->output = calloc(batch*outputs, sizeof(float*));
    layer->delta = calloc(batch*outputs, sizeof(float*));
@@ -54,9 +55,9 @@
    memset(layer.weight_updates, 0, layer.outputs*layer.inputs*sizeof(float));
}
void forward_connected_layer(connected_layer layer, float *input)
void forward_connected_layer(connected_layer layer, float *input, int train)
{
    int i;
    if(!train) layer.dropout = 0;
    memcpy(layer.output, layer.biases, layer.outputs*sizeof(float));
    int m = layer.batch;
    int k = layer.inputs;
@@ -65,17 +66,15 @@
    float *b = layer.weights;
    float *c = layer.output;
    gemm(0,0,m,n,k,1,a,k,b,n,1,c,n);
    for(i = 0; i < layer.outputs*layer.batch; ++i){
        layer.output[i] = activate(layer.output[i], layer.activation);
    }
    activate_array(layer.output, layer.outputs*layer.batch, layer.activation, layer.dropout);
}
void learn_connected_layer(connected_layer layer, float *input)
void backward_connected_layer(connected_layer layer, float *input, float *delta)
{
    int i;
    for(i = 0; i < layer.outputs*layer.batch; ++i){
        layer.delta[i] *= gradient(layer.output[i], layer.activation);
        layer.bias_updates[i%layer.batch] += layer.delta[i]/layer.batch;
        layer.bias_updates[i%layer.batch] += layer.delta[i];
    }
    int m = layer.inputs;
    int k = layer.batch;
@@ -84,18 +83,15 @@
    float *b = layer.delta;
    float *c = layer.weight_updates;
    gemm(0,0,m,n,k,1,a,k,b,n,1,c,n);
}
void backward_connected_layer(connected_layer layer, float *input, float *delta)
{
    int m = layer.inputs;
    int k = layer.outputs;
    int n = layer.batch;
    m = layer.inputs;
    k = layer.outputs;
    n = layer.batch;
    float *a = layer.weights;
    float *b = layer.delta;
    float *c = delta;
    a = layer.weights;
    b = layer.delta;
    c = delta;
    gemm(0,0,m,n,k,1,a,k,b,n,0,c,n);
    if(c) gemm(0,0,m,n,k,1,a,k,b,n,0,c,n);
}
src/connected_layer.h
@@ -22,15 +22,16 @@
    float *output;
    float *delta;
    float dropout;
    ACTIVATION activation;
} connected_layer;
connected_layer *make_connected_layer(int batch, int inputs, int outputs, ACTIVATION activation);
connected_layer *make_connected_layer(int batch, int inputs, int outputs, float dropout, ACTIVATION activation);
void forward_connected_layer(connected_layer layer, float *input);
void forward_connected_layer(connected_layer layer, float *input, int train);
void backward_connected_layer(connected_layer layer, float *input, float *delta);
void learn_connected_layer(connected_layer layer, float *input);
void update_connected_layer(connected_layer layer, float step, float momentum, float decay);
src/convolutional_layer.c
@@ -55,7 +55,7 @@
    for(i = 0; i < c*n*size*size; ++i) layer->filters[i] = scale*(rand_uniform());
    for(i = 0; i < n; ++i){
        //layer->biases[i] = rand_normal()*scale + scale;
        layer->biases[i] = 0;
        layer->biases[i] = .5;
    }
    int out_h = convolutional_out_height(*layer);
    int out_w = convolutional_out_width(*layer);
@@ -63,6 +63,8 @@
    layer->col_image = calloc(layer->batch*out_h*out_w*size*size*c, sizeof(float));
    layer->output = calloc(layer->batch*out_h * out_w * n, sizeof(float));
    layer->delta  = calloc(layer->batch*out_h * out_w * n, sizeof(float));
    #ifdef GPU
    #endif
    layer->activation = activation;
    fprintf(stderr, "Convolutional Layer: %d x %d x %d image, %d filters -> %d x %d x %d image\n", h,w,c,n, out_h, out_w, n);
@@ -87,42 +89,70 @@
                                layer->batch*out_h * out_w * layer->n*sizeof(float));
}
void bias_output(const convolutional_layer layer)
{
    int i,j;
    int out_h = convolutional_out_height(layer);
    int out_w = convolutional_out_width(layer);
    for(i = 0; i < layer.n; ++i){
        for(j = 0; j < out_h*out_w; ++j){
            layer.output[i*out_h*out_w + j] = layer.biases[i];
        }
    }
}
void forward_convolutional_layer(const convolutional_layer layer, float *in)
{
    int i;
    int out_h = convolutional_out_height(layer);
    int out_w = convolutional_out_width(layer);
    int m = layer.n;
    int k = layer.size*layer.size*layer.c;
    int n = out_h*out_w*layer.batch;
    float *a = layer.filters;
    float *b = layer.col_image;
    float *c = layer.output;
    im2col_cpu(in,layer.batch, layer.c, layer.h, layer.w,
        layer.size, layer.stride, b);
    bias_output(layer);
    gemm(0,0,m,n,k,1,a,k,b,n,1,c,n);
    activate_array(layer.output, m*n, layer.activation, 0.);
}
#ifdef GPU
void forward_convolutional_layer_gpu(convolutional_layer layer, cl_mem in)
{
    int m = layer.n;
    int k = layer.size*layer.size*layer.c;
    int n = convolutional_out_height(layer)*
            convolutional_out_width(layer)*
            layer.batch;
    float *a = layer.filters;
    float *b = layer.col_image;
    float *c = layer.output;
    for(i = 0; i < layer.batch; ++i){
        im2col_gpu(in+i*(n/layer.batch),  layer.c,  layer.h,  layer.w,  layer.size,  layer.stride, b+i*(n/layer.batch));
    cl_write_array(layer.filters_cl, layer.filters, m*k);
    cl_mem a = layer.filters_cl;
    cl_mem b = layer.col_image_cl;
    cl_mem c = layer.output_cl;
    im2col_ongpu(in, layer.batch, layer.c,  layer.h,  layer.w,  layer.size,  layer.stride, b);
    gemm_ongpu(0,0,m,n,k,1,a,k,b,n,0,c,n);
    activate_array_ongpu(layer.output_cl, m*n, layer.activation, 0.);
    cl_read_array(layer.output_cl, layer.output, m*n);
    }
    gemm(0,0,m,n,k,1,a,k,b,n,0,c,n);
    activate_array(layer.output, m*n, layer.activation);
}
#endif
void learn_bias_convolutional_layer(convolutional_layer layer)
{
    int i,j,b;
    int i,b;
    int size = convolutional_out_height(layer)
                *convolutional_out_width(layer);
    for(b = 0; b < layer.batch; ++b){
        for(i = 0; i < layer.n; ++i){
            float sum = 0;
            for(j = 0; j < size; ++j){
                sum += layer.delta[j+size*(i+b*layer.n)];
            }
            layer.bias_updates[i] += sum/size;
            layer.bias_updates[i] += mean_array(layer.delta+size*(i+b*layer.n), size);
        }
    }
}
void learn_convolutional_layer(convolutional_layer layer)
void backward_convolutional_layer(convolutional_layer layer, float *delta)
{
    int m = layer.n;
    int n = layer.size*layer.size*layer.c;
@@ -137,20 +167,18 @@
    float *c = layer.filter_updates;
    gemm(0,1,m,n,k,1,a,k,b,k,1,c,n);
}
void backward_convolutional_layer(convolutional_layer layer, float *delta)
{
    if(delta){
    int i;
    int m = layer.size*layer.size*layer.c;
    int k = layer.n;
    int n = convolutional_out_height(layer)*
        m = layer.size*layer.size*layer.c;
        k = layer.n;
        n = convolutional_out_height(layer)*
            convolutional_out_width(layer)*
            layer.batch;
    float *a = layer.filters;
    float *b = layer.delta;
    float *c = layer.col_image;
        a = layer.filters;
        b = layer.delta;
        c = layer.col_image;
    gemm(1,0,m,n,k,1,a,m,b,n,0,c,n);
@@ -159,6 +187,7 @@
        col2im_cpu(c+i*n/layer.batch,  layer.c,  layer.h,  layer.w,  layer.size,  layer.stride, delta+i*n/layer.batch);
    }
}
}
void update_convolutional_layer(convolutional_layer layer, float step, float momentum, float decay)
{
@@ -171,32 +200,6 @@
    scal_cpu(size, momentum, layer.filter_updates, 1);
}
void test_convolutional_layer()
{
    convolutional_layer l = *make_convolutional_layer(1,4,4,1,1,3,1,LINEAR);
    float input[] =    {1,2,3,4,
                        5,6,7,8,
                        9,10,11,12,
                        13,14,15,16};
    float filter[] =   {.5, 0, .3,
                        0  , 1,  0,
                        .2 , 0,  1};
    float delta[] =    {1, 2,
                        3,  4};
    float in_delta[] = {.5,1,.3,.6,
                        5,6,7,8,
                        9,10,11,12,
                        13,14,15,16};
    l.filters = filter;
    forward_convolutional_layer(l, input);
    l.delta = delta;
    learn_convolutional_layer(l);
    image filter_updates = float_to_image(3,3,1,l.filter_updates);
    print_image(filter_updates);
    printf("Delta:\n");
    backward_convolutional_layer(l, in_delta);
    pm(4,4,in_delta);
}
image get_convolutional_filter(convolutional_layer layer, int i)
{
src/convolutional_layer.h
@@ -1,6 +1,10 @@
#ifndef CONVOLUTIONAL_LAYER_H
#define CONVOLUTIONAL_LAYER_H
#ifdef GPU
#include "opencl.h"
#endif
#include "image.h"
#include "activations.h"
@@ -22,13 +26,30 @@
    float *delta;
    float *output;
    #ifdef GPU
    cl_mem filters_cl;
    cl_mem filter_updates_cl;
    cl_mem filter_momentum_cl;
    cl_mem biases_cl;
    cl_mem bias_updates_cl;
    cl_mem bias_momentum_cl;
    cl_mem col_image_cl;
    cl_mem delta_cl;
    cl_mem output_cl;
    #endif
    ACTIVATION activation;
} convolutional_layer;
#ifdef GPU
void forward_convolutional_layer_gpu(convolutional_layer layer, cl_mem in);
#endif
convolutional_layer *make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, ACTIVATION activation);
void resize_convolutional_layer(convolutional_layer *layer, int h, int w, int c);
void forward_convolutional_layer(const convolutional_layer layer, float *in);
void learn_convolutional_layer(convolutional_layer layer);
void update_convolutional_layer(convolutional_layer layer, float step, float momentum, float decay);
image *visualize_convolutional_layer(convolutional_layer layer, char *window, image *prev_filters);
src/gemm.c
New file
@@ -0,0 +1,283 @@
#include "mini_blas.h"
void gemm(int TA, int TB, int M, int N, int K, float ALPHA,
        float *A, int lda,
        float *B, int ldb,
        float BETA,
        float *C, int ldc)
{
#ifdef GPU
    gemm_gpu( TA,  TB,  M, N, K, ALPHA,A,lda, B, ldb,BETA,C,ldc);
#else
    gemm_cpu( TA,  TB,  M, N, K, ALPHA,A,lda, B, ldb,BETA,C,ldc);
#endif
}
void gemm_nn(int M, int N, int K, float ALPHA,
        float *A, int lda,
        float *B, int ldb,
        float *C, int ldc)
{
    int i,j,k;
    for(i = 0; i < M; ++i){
        for(k = 0; k < K; ++k){
            register float A_PART = ALPHA*A[i*lda+k];
            for(j = 0; j < N; ++j){
                C[i*ldc+j] += A_PART*B[k*ldb+j];
            }
        }
    }
}
void gemm_nt(int M, int N, int K, float ALPHA,
        float *A, int lda,
        float *B, int ldb,
        float *C, int ldc)
{
    int i,j,k;
    for(i = 0; i < M; ++i){
        for(j = 0; j < N; ++j){
            register float sum = 0;
            for(k = 0; k < K; ++k){
                sum += ALPHA*A[i*lda+k]*B[k+j*ldb];
            }
            C[i*ldc+j] += sum;
        }
    }
}
void gemm_tn(int M, int N, int K, float ALPHA,
        float *A, int lda,
        float *B, int ldb,
        float *C, int ldc)
{
    int i,j,k;
    for(i = 0; i < M; ++i){
        for(k = 0; k < K; ++k){
            register float A_PART = ALPHA*A[k*lda+i];
            for(j = 0; j < N; ++j){
                C[i*ldc+j] += A_PART*B[k*ldb+j];
            }
        }
    }
}
void gemm_tt(int M, int N, int K, float ALPHA,
        float *A, int lda,
        float *B, int ldb,
        float *C, int ldc)
{
    int i,j,k;
    for(i = 0; i < M; ++i){
        for(j = 0; j < N; ++j){
            for(k = 0; k < K; ++k){
                C[i*ldc+j] += ALPHA*A[i+k*lda]*B[k+j*ldb];
            }
        }
    }
}
void gemm_cpu(int TA, int TB, int M, int N, int K, float ALPHA,
        float *A, int lda,
        float *B, int ldb,
        float BETA,
        float *C, int ldc)
{
    int i, j;
    for(i = 0; i < M; ++i){
        for(j = 0; j < N; ++j){
            C[i*ldc + j] *= BETA;
        }
    }
    if(!TA && !TB)
        gemm_nn(M, N, K, ALPHA,A,lda, B, ldb,C,ldc);
    else if(TA && !TB)
        gemm_tn(M, N, K, ALPHA,A,lda, B, ldb,C,ldc);
    else if(!TA && TB)
        gemm_nt(M, N, K, ALPHA,A,lda, B, ldb,C,ldc);
    else
        gemm_tt(M, N, K, ALPHA,A,lda, B, ldb,C,ldc);
}
#ifdef GPU
#include "opencl.h"
#include <math.h>
#define STR_HELPER(x) #x
#define STR(x) STR_HELPER(x)
#define BLOCK 8
cl_kernel get_gemm_kernel()
{
    static int init = 0;
    static cl_kernel gemm_kernel;
    if(!init){
        gemm_kernel = get_kernel("src/gemm.cl", "gemm", "-D BLOCK=" STR(BLOCK) );
        init = 1;
    }
    return gemm_kernel;
}
void gemm_ongpu(int TA, int TB, int M, int N, int K, float ALPHA,
        cl_mem A_gpu, int lda,
        cl_mem B_gpu, int ldb,
        float BETA,
        cl_mem C_gpu, int ldc)
{
    cl_setup();
    cl_kernel gemm_kernel = get_gemm_kernel();
    cl_command_queue queue = cl.queue;
    cl_uint i = 0;
    cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(TA), (void*) &TA);
    cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(TB), (void*) &TB);
    cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(M), (void*) &M);
    cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(N), (void*) &N);
    cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(K), (void*) &K);
    cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(ALPHA), (void*) &ALPHA);
    cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(A_gpu), (void*) &A_gpu);
    cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(lda), (void*) &lda);
    cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(B_gpu), (void*) &B_gpu);
    cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(ldb), (void*) &ldb);
    cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(BETA), (void*) &BETA);
    cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(C_gpu), (void*) &C_gpu);
    cl.error = clSetKernelArg(gemm_kernel, i++, sizeof(ldc), (void*) &ldc);
    check_error(cl);
    const size_t global_size[] = {ceil((float)M/BLOCK)*BLOCK, ceil((float)N/BLOCK)*BLOCK};
    const size_t local_size[] = {BLOCK, BLOCK};
    clEnqueueNDRangeKernel(queue, gemm_kernel, 2, 0, global_size, local_size, 0, 0, 0);
    check_error(cl);
}
void gemm_gpu(int TA, int TB, int M, int N, int K, float ALPHA,
        float *A, int lda,
        float *B, int ldb,
        float BETA,
        float *C, int ldc)
{
    cl_setup();
    cl_context context = cl.context;
    cl_command_queue queue = cl.queue;
    size_t size = sizeof(float)*(TA ? lda*K:lda*M);
    cl_mem A_gpu = clCreateBuffer(context,
            CL_MEM_READ_ONLY|CL_MEM_COPY_HOST_PTR,
            size, A, &cl.error);
    check_error(cl);
    size = sizeof(float)*(TB ? ldb*N:ldb*K);
    cl_mem B_gpu = clCreateBuffer(context,
            CL_MEM_READ_ONLY|CL_MEM_COPY_HOST_PTR,
            size, B, &cl.error);
    check_error(cl);
    size = sizeof(float)*(ldc*M);
    cl_mem C_gpu = clCreateBuffer(context,
            CL_MEM_READ_WRITE|CL_MEM_COPY_HOST_PTR,
            size, C, &cl.error);
    check_error(cl);
    gemm_ongpu(TA, TB, M, N, K, ALPHA, A_gpu, lda, B_gpu, ldb, BETA, C_gpu, ldc);
    clEnqueueReadBuffer(queue, C_gpu, CL_TRUE, 0, size, C, 0, 0, 0);
    check_error(cl);
    clReleaseMemObject(A_gpu);
    clReleaseMemObject(B_gpu);
    clReleaseMemObject(C_gpu);
}
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
void time_gpu_random_matrix(int TA, int TB, int m, int k, int n)
{
    float *a;
    if(!TA) a = random_matrix(m,k);
    else a = random_matrix(k,m);
    int lda = (!TA)?k:m;
    float *b;
    if(!TB) b = random_matrix(k,n);
    else b = random_matrix(n,k);
    int ldb = (!TB)?n:k;
    float *c = random_matrix(m,n);
    int i;
    clock_t start = clock(), end;
    for(i = 0; i<1000; ++i){
        gemm_gpu(TA,TB,m,n,k,1,a,lda,b,ldb,1,c,n);
    }
    end = clock();
    printf("Matrix Multiplication %dx%d * %dx%d, TA=%d, TB=%d: %lf ms\n",m,k,k,n, TA, TB, (float)(end-start)/CLOCKS_PER_SEC);
    free(a);
    free(b);
    free(c);
}
void test_gpu_accuracy(int TA, int TB, int m, int k, int n)
{
    srand(0);
    float *a;
    if(!TA) a = random_matrix(m,k);
    else a = random_matrix(k,m);
    int lda = (!TA)?k:m;
    float *b;
    if(!TB) b = random_matrix(k,n);
    else b = random_matrix(n,k);
    int ldb = (!TB)?n:k;
    float *c = random_matrix(m,n);
    float *c_gpu = random_matrix(m,n);
    memset(c, 0, m*n*sizeof(float));
    memset(c_gpu, 0, m*n*sizeof(float));
    int i;
    //pm(m,k,b);
    gemm_gpu(TA,TB,m,n,k,1,a,lda,b,ldb,1,c_gpu,n);
    //pm(m, n, c_gpu);
    gemm_cpu(TA,TB,m,n,k,1,a,lda,b,ldb,1,c,n);
    //pm(m, n, c);
    double sse = 0;
    for(i = 0; i < m*n; ++i) {
        //printf("%f %f\n", c[i], c_gpu[i]);
        sse += pow(c[i]-c_gpu[i], 2);
    }
    printf("Matrix Multiplication %dx%d * %dx%d, TA=%d, TB=%d: %g MSE\n",m,k,k,n, TA, TB, sse/(m*n));
    free(a);
    free(b);
    free(c);
}
void test_gpu_blas()
{
    test_gpu_accuracy(0,0,17,10,10);
    test_gpu_accuracy(1,0,17,10,10);
    test_gpu_accuracy(0,1,17,10,10);
    test_gpu_accuracy(1,1,17,10,10);
    test_gpu_accuracy(0,0,1000,10,100);
    test_gpu_accuracy(1,0,1000,10,100);
    test_gpu_accuracy(0,1,1000,10,100);
    test_gpu_accuracy(1,1,1000,10,100);
    time_gpu_random_matrix(0,0,1000,1000,100);
    time_random_matrix(0,0,1000,1000,100);
    time_gpu_random_matrix(0,1,1000,1000,100);
    time_random_matrix(0,1,1000,1000,100);
    time_gpu_random_matrix(1,0,1000,1000,100);
    time_random_matrix(1,0,1000,1000,100);
    time_gpu_random_matrix(1,1,1000,1000,100);
    time_random_matrix(1,1,1000,1000,100);
}
#endif
src/im2col.c
New file
@@ -0,0 +1,121 @@
#include "mini_blas.h"
//From Berkeley Vision's Caffe!
//https://github.com/BVLC/caffe/blob/master/LICENSE
void im2col_cpu(float* data_im,
    const int batch, const int channels, const int height, const int width,
    const int ksize, const int stride, float* data_col)
{
    int c,h,w,b;
    int height_col = (height - ksize) / stride + 1;
    int width_col = (width - ksize) / stride + 1;
    int channels_col = channels * ksize * ksize;
    int im_size = height*width*channels;
    int col_size = height_col*width_col*channels_col;
    for(b = 0; b < batch; ++b){
        for ( c = 0; c < channels_col; ++c) {
            int w_offset = c % ksize;
            int h_offset = (c / ksize) % ksize;
            int c_im = c / ksize / ksize;
            for ( h = 0; h < height_col; ++h) {
                for ( w = 0; w < width_col; ++w) {
                    data_col[(c * height_col + h) * width_col + w] =
                    data_im[(c_im * height + h * stride + h_offset) * width
                        + w * stride + w_offset];
                }
            }
        }
        data_im += im_size;
        data_col+= col_size;
    }
}
#ifdef GPU
#include "opencl.h"
#include <math.h>
cl_kernel get_im2col_kernel()
{
    static int init = 0;
    static cl_kernel im2col_kernel;
    if(!init){
        im2col_kernel = get_kernel("src/im2col.cl", "im2col", 0);
        init = 1;
    }
    return im2col_kernel;
}
void im2col_ongpu(cl_mem data_im, const int batch,
        const int channels, const int height, const int width,
        const int ksize, const int stride, cl_mem data_col)
{
    cl_setup();
    cl_kernel im2col_kernel = get_im2col_kernel();
    cl_command_queue queue = cl.queue;
    cl_uint i = 0;
    cl.error = clSetKernelArg(im2col_kernel, i++, sizeof(data_im), (void*) &data_im);
    cl.error = clSetKernelArg(im2col_kernel, i++, sizeof(batch), (void*) &batch);
    cl.error = clSetKernelArg(im2col_kernel, i++, sizeof(channels), (void*) &channels);
    cl.error = clSetKernelArg(im2col_kernel, i++, sizeof(height), (void*) &height);
    cl.error = clSetKernelArg(im2col_kernel, i++, sizeof(width), (void*) &width);
    cl.error = clSetKernelArg(im2col_kernel, i++, sizeof(ksize), (void*) &ksize);
    cl.error = clSetKernelArg(im2col_kernel, i++, sizeof(stride), (void*) &stride);
    cl.error = clSetKernelArg(im2col_kernel, i++, sizeof(data_col), (void*) &data_col);
    check_error(cl);
    int height_col = (height - ksize) / stride + 1;
    int width_col = (width - ksize) / stride + 1;
    int channels_col = channels * ksize * ksize;
    size_t global_size[2];
    size_t local_size[2];
    global_size[0] = batch;
    global_size[1] = channels_col;
    local_size[0] = height_col;
    local_size[1] = width_col;
    clEnqueueNDRangeKernel(queue, im2col_kernel, 2, 0,
            global_size, local_size, 0, 0, 0);
    check_error(cl);
}
void im2col_gpu(float *data_im,
        const int batch, const int channels, const int height, const int width,
        const int ksize, const int stride,
        float *data_col)
{
    cl_setup();
    cl_context context = cl.context;
    cl_command_queue queue = cl.queue;
    size_t size = sizeof(float)*(channels*height*width*batch);
    cl_mem im_gpu = clCreateBuffer(context,
            CL_MEM_READ_ONLY|CL_MEM_COPY_HOST_PTR,
            size, data_im, &cl.error);
    check_error(cl);
    int height_col = (height - ksize) / stride + 1;
    int width_col = (width - ksize) / stride + 1;
    int channels_col = channels * ksize * ksize;
    size = sizeof(float)*(height_col*width_col*channels_col*batch);
    cl_mem col_gpu = clCreateBuffer(context,
            CL_MEM_WRITE_ONLY|CL_MEM_COPY_HOST_PTR,
            size, data_col, &cl.error);
    check_error(cl);
    im2col_ongpu(im_gpu, batch, channels, height, width,
            ksize, stride, col_gpu);
    clEnqueueReadBuffer(queue, col_gpu, CL_TRUE, 0, size, data_col, 0, 0, 0);
    check_error(cl);
    clReleaseMemObject(col_gpu);
    clReleaseMemObject(im_gpu);
}
#endif
src/im2col.cl
New file
@@ -0,0 +1,26 @@
__kernel void im2col(__global float *data_im,
    const int batch, const int channels, const int height, const int width,
    const int ksize, const int stride, __global float *data_col)
{
    int b = get_global_id(0);
    int c = get_global_id(1);
    int h = get_local_id(0);
    int w = get_local_id(1);
    int height_col = (height - ksize) / stride + 1;
    int width_col = (width - ksize) / stride + 1;
    int channels_col = channels * ksize * ksize;
    int im_offset = height*width*channels*b;
    int col_offset = height_col*width_col*channels_col*b;
    int w_offset = c % ksize;
    int h_offset = (c / ksize) % ksize;
    int c_im = c / ksize / ksize;
    data_col[(c * height_col + h) * width_col + w + col_offset] =
        data_im[(c_im * height + h * stride + h_offset) * width
        + w * stride + w_offset + im_offset];
}
src/mini_blas.h
@@ -1,3 +1,5 @@
#include "opencl.h"
void pm(int M, int N, float *A);
void gemm(int TA, int TB, int M, int N, int K, float ALPHA, 
                    float *A, int lda, 
@@ -6,12 +8,27 @@
                    float *C, int ldc);
float *random_matrix(int rows, int cols);
void time_random_matrix(int TA, int TB, int m, int k, int n);
void im2col_gpu(float* data_im, const int channels,
        const int height, const int width, const int ksize, const int stride,
        float* data_col);
void im2col_cpu(float* data_im, const int channels,
        const int height, const int width, const int ksize, const int stride,
        float* data_col);
#ifdef GPU
void im2col_ongpu(cl_mem data_im, const int batch,
        const int channels, const int height, const int width,
        const int ksize, const int stride, cl_mem data_col);
void im2col_gpu(float *data_im,
    const int batch, const int channels, const int height, const int width,
    const int ksize, const int stride, float *data_col);
void gemm_ongpu(int TA, int TB, int M, int N, int K, float ALPHA,
        cl_mem A_gpu, int lda,
        cl_mem B_gpu, int ldb,
        float BETA,
        cl_mem C_gpu, int ldc);
#endif
void im2col_cpu(float* data_im,
    const int batch, const int channels, const int height, const int width,
    const int ksize, const int stride, float* data_col);
void col2im_cpu(float* data_col, const int channels,
        const int height, const int width, const int ksize, const int stride,
        float* data_im);
src/network.c
@@ -19,6 +19,9 @@
    net.types = calloc(net.n, sizeof(LAYER_TYPE));
    net.outputs = 0;
    net.output = 0;
    #ifdef GPU
    net.input_cl = 0;
    #endif
    return net;
}
@@ -40,17 +43,6 @@
    fprintf(fp, "data=");
    for(i = 0; i < l->n; ++i) fprintf(fp, "%g,", l->biases[i]);
    for(i = 0; i < l->n*l->c*l->size*l->size; ++i) fprintf(fp, "%g,", l->filters[i]);
    /*
    int j,k;
    for(i = 0; i < l->n; ++i) fprintf(fp, "%g,", l->biases[i]);
    for(i = 0; i < l->n; ++i){
        for(j = l->c-1; j >= 0; --j){
            for(k = 0; k < l->size*l->size; ++k){
                fprintf(fp, "%g,", l->filters[i*(l->c*l->size*l->size)+j*l->size*l->size+k]);
            }
        }
    }
    */
    fprintf(fp, "\n\n");
}
void print_connected_cfg(FILE *fp, connected_layer *l, int first)
@@ -121,18 +113,34 @@
    fclose(fp);
}
void forward_network(network net, float *input)
void forward_network(network net, float *input, int train)
{
    int i;
    #ifdef GPU
    cl_setup();
    size_t size = get_network_input_size(net);
    if(!net.input_cl){
        net.input_cl = clCreateBuffer(cl.context,
            CL_MEM_READ_WRITE, size*sizeof(float), 0, &cl.error);
        check_error(cl);
    }
    cl_write_array(net.input_cl, input, size);
    cl_mem input_cl = net.input_cl;
    #endif
    for(i = 0; i < net.n; ++i){
        if(net.types[i] == CONVOLUTIONAL){
            convolutional_layer layer = *(convolutional_layer *)net.layers[i];
            #ifdef GPU
            forward_convolutional_layer_gpu(layer, input_cl);
            input_cl = layer.output_cl;
            #else
            forward_convolutional_layer(layer, input);
            #endif
            input = layer.output;
        }
        else if(net.types[i] == CONNECTED){
            connected_layer layer = *(connected_layer *)net.layers[i];
            forward_connected_layer(layer, input);
            forward_connected_layer(layer, input, train);
            input = layer.output;
        }
        else if(net.types[i] == SOFTMAX){
@@ -263,9 +271,7 @@
        }
        if(net.types[i] == CONVOLUTIONAL){
            convolutional_layer layer = *(convolutional_layer *)net.layers[i];
            learn_convolutional_layer(layer);
            //learn_convolutional_layer(layer);
            if(i != 0) backward_convolutional_layer(layer, prev_delta);
            backward_convolutional_layer(layer, prev_delta);
        }
        else if(net.types[i] == MAXPOOL){
            maxpool_layer layer = *(maxpool_layer *)net.layers[i];
@@ -281,8 +287,7 @@
        }
        else if(net.types[i] == CONNECTED){
            connected_layer layer = *(connected_layer *)net.layers[i];
            learn_connected_layer(layer, prev_input);
            if(i != 0) backward_connected_layer(layer, prev_input, prev_delta);
            backward_connected_layer(layer, prev_input, prev_delta);
        }
    }
    return error;
@@ -290,7 +295,7 @@
float train_network_datum(network net, float *x, float *y, float step, float momentum, float decay)
{
    forward_network(net, x);
    forward_network(net, x, 1);
    //int class = get_predicted_class_network(net);
    float error = backward_network(net, x, y);
    update_network(net, step, momentum, decay);
@@ -332,7 +337,7 @@
        int index = rand()%d.X.rows;
        float *x = d.X.vals[index];
        float *y = d.y.vals[index];
        forward_network(net, x);
        forward_network(net, x, 1);
        int class = get_predicted_class_network(net);
        backward_network(net, x, y);
        correct += (y[class]?1:0);
@@ -359,6 +364,27 @@
    fprintf(stderr, "Accuracy: %f\n", (float)correct/d.X.rows);
}
int get_network_input_size_layer(network net, int i)
{
    if(net.types[i] == CONVOLUTIONAL){
        convolutional_layer layer = *(convolutional_layer *)net.layers[i];
        return layer.h*layer.w*layer.c;
    }
    else if(net.types[i] == MAXPOOL){
        maxpool_layer layer = *(maxpool_layer *)net.layers[i];
        return layer.h*layer.w*layer.c;
    }
    else if(net.types[i] == CONNECTED){
        connected_layer layer = *(connected_layer *)net.layers[i];
        return layer.inputs;
    }
    else if(net.types[i] == SOFTMAX){
        softmax_layer layer = *(softmax_layer *)net.layers[i];
        return layer.inputs;
    }
    return 0;
}
int get_network_output_size_layer(network net, int i)
{
    if(net.types[i] == CONVOLUTIONAL){
@@ -382,36 +408,6 @@
    return 0;
}
/*
   int resize_network(network net, int h, int w, int c)
   {
   int i;
   for (i = 0; i < net.n; ++i){
   if(net.types[i] == CONVOLUTIONAL){
   convolutional_layer *layer = (convolutional_layer *)net.layers[i];
   layer->h = h;
   layer->w = w;
   layer->c = c;
   image output = get_convolutional_image(*layer);
   h = output.h;
   w = output.w;
   c = output.c;
   }
   else if(net.types[i] == MAXPOOL){
   maxpool_layer *layer = (maxpool_layer *)net.layers[i];
   layer->h = h;
   layer->w = w;
   layer->c = c;
   image output = get_maxpool_image(*layer);
   h = output.h;
   w = output.w;
   c = output.c;
   }
   }
   return 0;
   }
 */
int resize_network(network net, int h, int w, int c)
{
    int i;
@@ -450,6 +446,11 @@
    return get_network_output_size_layer(net, i);
}
int get_network_input_size(network net)
{
    return get_network_output_size_layer(net, 0);
}
image get_network_image_layer(network net, int i)
{
    if(net.types[i] == CONVOLUTIONAL){
@@ -497,7 +498,7 @@
float *network_predict(network net, float *input)
{
    forward_network(net, input);
    forward_network(net, input, 0);
    float *out = get_network_output(net);
    return out;
}
src/network.h
@@ -2,6 +2,7 @@
#ifndef NETWORK_H
#define NETWORK_H
#include "opencl.h"
#include "image.h"
#include "data.h"
@@ -20,10 +21,15 @@
    LAYER_TYPE *types;
    int outputs;
    float *output;
    #ifdef GPU
    cl_mem input_cl;
    cl_mem output_cl;
    #endif
} network;
network make_network(int n, int batch);
void forward_network(network net, float *input);
void forward_network(network net, float *input, int train);
float backward_network(network net, float *input, float *truth);
void update_network(network net, float step, float momentum, float decay);
float train_network_sgd(network net, data d, int n, float step, float momentum,float decay);
@@ -44,6 +50,7 @@
void visualize_network(network net);
void save_network(network net, char *filename);
int resize_network(network net, int h, int w, int c);
int get_network_input_size(network net);
#endif
src/opencl.c
@@ -88,4 +88,18 @@
    return kernel;
}
void cl_read_array(cl_mem mem, float *x, int n)
{
    cl_setup();
    clEnqueueReadBuffer(cl.queue, mem, CL_TRUE, 0, sizeof(float)*n,x,0,0,0);
    check_error(cl);
}
void cl_write_array(cl_mem mem, float *x, int n)
{
    cl_setup();
    clEnqueueWriteBuffer(cl.queue, mem, CL_TRUE, 0,sizeof(float)*n,x,0,0,0);
    check_error(cl);
}
#endif
src/opencl.h
@@ -1,3 +1,6 @@
#ifdef GPU
#ifndef OPENCL_H
#define OPENCL_H
#ifdef __APPLE__
#include <OpenCL/opencl.h>
#else
@@ -18,4 +21,7 @@
void cl_setup();
void check_error(cl_info info);
cl_kernel get_kernel(char *filename, char *kernelname, char *options);
void cl_read_array(cl_mem mem, float *x, int n);
void cl_write_array(cl_mem mem, float *x, int n);
#endif
#endif
src/parser.c
@@ -89,6 +89,7 @@
    int i;
    int input;
    int output = option_find_int(options, "output",1);
    float dropout = option_find_float(options, "dropout", 0.);
    char *activation_s = option_find_str(options, "activation", "sigmoid");
    ACTIVATION activation = get_activation(activation_s);
    if(count == 0){
@@ -97,7 +98,7 @@
    }else{
        input =  get_network_output_size_layer(net, count-1);
    }
    connected_layer *layer = make_connected_layer(net.batch, input, output, activation);
    connected_layer *layer = make_connected_layer(net.batch, input, output, dropout, activation);
    char *data = option_find_str(options, "data", 0);
    if(data){
        char *curr = data;
src/tests.c
@@ -52,7 +52,7 @@
    int i;
    clock_t start = clock(), end;
    for(i = 0; i < 1000; ++i){
        im2col_cpu(dog.data,  dog.c,  dog.h,  dog.w,  size,  stride, matrix);
        im2col_cpu(dog.data,  1, dog.c,  dog.h,  dog.w,  size,  stride, matrix);
        gemm(0,0,n,mw,mh,1,filters,mh,matrix,mw,1,edge.data,mw);
    }
    end = clock();
@@ -168,7 +168,7 @@
        float v = ((float)rand()/RAND_MAX);
        float truth = v*v;
        input[0] = v;
        forward_network(net, input);
        forward_network(net, input, 1);
        float *out = get_network_output(net);
        float *delta = get_network_delta(net);
        float err = pow((out[0]-truth),2.);
@@ -245,7 +245,7 @@
        normalize_data_rows(test);
        for(j = 0; j < test.X.rows; ++j){
            float *x = test.X.vals[j];
            forward_network(net, x);
            forward_network(net, x, 0);
            int class = get_predicted_class_network(net);
            fprintf(fp, "%d\n", class);
        }
@@ -317,22 +317,14 @@
    int batch = 10000;
    while(++count <= 10000){
        float loss = train_network_sgd(net, train, batch, lr, momentum, decay);
        printf("%5f %5f\n",(double)count*batch/train.X.rows, loss);
        float test_acc = network_accuracy(net, test);
        printf("%3d %5f %5f\n",count, loss, test_acc);
        //printf("%5d Training Loss: %lf, Params: %f %f %f, ",count*1000, loss, lr, momentum, decay);
        //end = clock();
        //printf("Time: %lf seconds\n", (float)(end-start)/CLOCKS_PER_SEC);
        //start=end;
        /*
           if(count%5 == 0){
           float train_acc = network_accuracy(net, train);
           fprintf(stderr, "\nTRAIN: %f\n", train_acc);
           float test_acc = network_accuracy(net, test);
           fprintf(stderr, "TEST: %f\n\n", test_acc);
           printf("%d, %f, %f\n", count, train_acc, test_acc);
        //lr *= .5;
        }
        */
    }
}
void test_ensemble()
@@ -387,7 +379,7 @@
            int index = rand()%m.rows;
            //image p = float_to_image(1690,1,1,m.vals[index]);
            //normalize_image(p);
            forward_network(net, m.vals[index]);
            forward_network(net, m.vals[index], 1);
            float *out = get_network_output(net);
            float *delta = get_network_delta(net);
            //printf("%f\n", out[0]);
@@ -408,7 +400,7 @@
    matrix test = csv_to_matrix("test.csv");
    truth = pop_column(&test, 0);
    for(i = 0; i < test.rows; ++i){
        forward_network(net, test.vals[i]);
        forward_network(net, test.vals[i], 0);
        float *out = get_network_output(net);
        if(fabs(out[0]) < .5) fprintf(fp, "0\n");
        else fprintf(fp, "1\n");
@@ -439,7 +431,7 @@
    float *matrix = calloc(msize, sizeof(float));
    int i;
    for(i = 0; i < 1000; ++i){
        im2col_cpu(test.data,  c,  h,  w,  size,  stride, matrix);
        im2col_cpu(test.data, 1, c,  h,  w,  size,  stride, matrix);
        //image render = float_to_image(mh, mw, mc, matrix);
    }
}
@@ -506,7 +498,7 @@
    //normalize_array(im.data, im.h*im.w*im.c);
    translate_image(im, -144);
    resize_network(net, im.h, im.w, im.c);
    forward_network(net, im.data);
    forward_network(net, im.data, 0);
    image out = get_network_image(net);
    free_image(im);
    cvReleaseImage(&sized);
@@ -558,7 +550,7 @@
        resize_network(net, im.h, im.w, im.c);
        //scale_image(im, 1./255);
        translate_image(im, -144);
        forward_network(net, im.data);
        forward_network(net, im.data, 0);
        image out = get_network_image(net);
        int dh = (im.h - h)/(out.h-1);
@@ -620,7 +612,7 @@
        image im = load_image(image_path, 0, 0);
        printf("Processing %dx%d image\n", im.h, im.w);
        resize_network(net, im.h, im.w, im.c);
        forward_network(net, im.data);
        forward_network(net, im.data, 0);
        image out = get_network_image(net);
        int dh = (im.h - h)/h;
@@ -653,7 +645,7 @@
    image im = load_image("data/cat.png", 0, 0);
    printf("Processing %dx%d image\n", im.h, im.w);
    resize_network(net, im.h, im.w, im.c);
    forward_network(net, im.data);
    forward_network(net, im.data, 0);
    visualize_network(net);
    cvWaitKey(0);