From ad9dbfe16495204453b1b7f8593d320751f76ca0 Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Tue, 10 Dec 2013 18:30:42 +0000
Subject: [PATCH] CSE546 submission
---
src/tests.c | 46 ++++++++++++++++++++++++++--------------------
1 files changed, 26 insertions(+), 20 deletions(-)
diff --git a/src/tests.c b/src/tests.c
index 0b9b5db..4638645 100644
--- a/src/tests.c
+++ b/src/tests.c
@@ -195,26 +195,31 @@
void test_nist()
{
srand(444444);
+ srand(888888);
network net = parse_network_cfg("nist.cfg");
data train = load_categorical_data_csv("mnist/mnist_train.csv", 0, 10);
data test = load_categorical_data_csv("mnist/mnist_test.csv",0,10);
normalize_data_rows(train);
normalize_data_rows(test);
- randomize_data(train);
+ //randomize_data(train);
int count = 0;
double lr = .0005;
- while(++count <= 1){
- double acc = train_network_sgd(net, train, 10000, lr, .9, .001);
- printf("Training Accuracy: %lf\n", acc);
- lr /= 2;
+ double momentum = .9;
+ double decay = 0.01;
+ while(++count <= 1000){
+ double acc = train_network_sgd(net, train, 1000, lr, momentum, decay);
+ printf("Training Accuracy: %lf, Params: %f %f %f\n", acc, lr, momentum, decay);
+ visualize_network(net);
+ cvWaitKey(100);
+ //lr /= 2;
+ if(count%5 == 0 && 0){
+ double train_acc = network_accuracy(net, train);
+ fprintf(stderr, "\nTRAIN: %f\n", train_acc);
+ double test_acc = network_accuracy(net, test);
+ fprintf(stderr, "TEST: %f\n\n", test_acc);
+ printf("%d, %f, %f\n", count, train_acc, test_acc);
+ }
}
- double train_acc = network_accuracy(net, train);
- fprintf(stderr, "\nTRAIN: %f\n", train_acc);
- double test_acc = network_accuracy(net, test);
- fprintf(stderr, "TEST: %f\n\n", test_acc);
- printf("%d, %f, %f\n", count, train_acc, test_acc);
- //end = clock();
- //printf("Neural Net Learning: %lf seconds\n", (double)(end-start)/CLOCKS_PER_SEC);
}
void test_ensemble()
@@ -223,24 +228,25 @@
srand(888888);
data d = load_categorical_data_csv("mnist/mnist_train.csv", 0, 10);
normalize_data_rows(d);
- randomize_data(d);
data test = load_categorical_data_csv("mnist/mnist_test.csv", 0,10);
normalize_data_rows(test);
data train = d;
/*
- data *split = split_data(d, 1, 10);
- data train = split[0];
- data test = split[1];
- */
+ data *split = split_data(d, 1, 10);
+ data train = split[0];
+ data test = split[1];
+ */
matrix prediction = make_matrix(test.y.rows, test.y.cols);
int n = 30;
for(i = 0; i < n; ++i){
int count = 0;
double lr = .0005;
+ double momentum = .9;
+ double decay = .01;
network net = parse_network_cfg("nist.cfg");
- while(++count <= 5){
- double acc = train_network_sgd(net, train, train.X.rows, lr, .9, .001);
- printf("Training Accuracy: %lf\n", acc);
+ while(++count <= 15){
+ double acc = train_network_sgd(net, train, train.X.rows, lr, momentum, decay);
+ printf("Training Accuracy: %lf Learning Rate: %f Momentum: %f Decay: %f\n", acc, lr, momentum, decay );
lr /= 2;
}
matrix partial = network_predict_data(net, test);
--
Gitblit v1.10.0