Joseph Redmon
2014-01-23 1d53b6414e0cd81043d7c76aa89f4f97da5e479f
Stable on MNIST, about to change a lot
5 files modified
44 ■■■■ changed files
full.cfg 4 ●●●● patch | view | raw | blame | history
nist.cfg 4 ●●●● patch | view | raw | blame | history
src/network.c 18 ●●●●● patch | view | raw | blame | history
src/network.h 1 ●●●● patch | view | raw | blame | history
src/tests.c 17 ●●●● patch | view | raw | blame | history
full.cfg
@@ -11,10 +11,6 @@
stride=2
[conn]
output = 100
activation=ramp
[conn]
output = 2
activation=ramp
nist.cfg
@@ -2,7 +2,7 @@
width=28
height=28
channels=1
filters=5
filters=20
size=5
stride=1
activation=ramp
@@ -20,7 +20,7 @@
stride=2
[conn]
output = 100
output = 500
activation=ramp
[conn]
src/network.c
@@ -187,6 +187,24 @@
    }
    return (double)correct/n;
}
double train_network_batch(network net, data d, int n, double step, double momentum,double decay)
{
    int i;
    int correct = 0;
    for(i = 0; i < n; ++i){
        int index = rand()%d.X.rows;
        double *x = d.X.vals[index];
        double *y = d.y.vals[index];
        forward_network(net, x);
        int class = get_predicted_class_network(net);
        backward_network(net, x, y);
        correct += (y[class]?1:0);
    }
    update_network(net, step, momentum, decay);
    return (double)correct/n;
}
void train_network(network net, data d, double step, double momentum, double decay)
{
src/network.h
@@ -25,6 +25,7 @@
void backward_network(network net, double *input, double *truth);
void update_network(network net, double step, double momentum, double decay);
double train_network_sgd(network net, data d, int n, double step, double momentum,double decay);
double train_network_batch(network net, data d, int n, double step, double momentum,double decay);
void train_network(network net, data d, double step, double momentum, double decay);
matrix network_predict_data(network net, data test);
double network_accuracy(network net, data d);
src/tests.c
@@ -184,9 +184,12 @@
    srand(0);
    int i = 0;
    char *labels[] = {"cat","dog"};
    double lr = .00001;
    double momentum = .9;
    double decay = 0.01;
    while(i++ < 1000 || 1){
        data train = load_data_image_pathfile_random("train_paths.txt", 1000, labels, 2);
        train_network(net, train, .0005, 0, 0);
        train_network(net, train, lr, momentum, decay);
        free_data(train);
        printf("Round %d\n", i);
    }
@@ -206,9 +209,13 @@
    double lr = .0005;
    double momentum = .9;
    double decay = 0.01;
    clock_t start = clock(), end;
    while(++count <= 1000){
        double acc = train_network_sgd(net, train, 1000, lr, momentum, decay);
        printf("Training Accuracy: %lf, Params: %f %f %f\n", acc, lr, momentum, decay);
        double acc = train_network_sgd(net, train, 6400, lr, momentum, decay);
        printf("%5d Training Loss: %lf, Params: %f %f %f, ",count*100, 1.-acc, lr, momentum, decay);
        end = clock();
        printf("Time: %lf seconds\n", (double)(end-start)/CLOCKS_PER_SEC);
        start=end;
        visualize_network(net);
        cvWaitKey(100);
        //lr /= 2; 
@@ -334,8 +341,8 @@
{
    //test_kernel_update();
    //test_split();
    test_ensemble();
    //test_nist();
    //test_ensemble();
    test_nist();
    //test_full();
    //test_random_preprocess();
    //test_random_classify();