| | |
| | | void train_assira() |
| | | { |
| | | network net = parse_network_cfg("cfg/assira.cfg"); |
| | | int imgs = 1000/net.batch+1; |
| | | //imgs = 1; |
| | | srand(2222222); |
| | | int i = 0; |
| | | char *labels[] = {"cat","dog"}; |
| | | while(1){ |
| | | i += 1000; |
| | | data train = load_data_image_pathfile_random("data/assira/train.list", 1000, labels, 2, 256, 256); |
| | | data train = load_data_image_pathfile_random("data/assira/train.list", imgs*net.batch, labels, 2, 256, 256); |
| | | normalize_data_rows(train); |
| | | clock_t start = clock(), end; |
| | | float loss = train_network_sgd_gpu(net, train, 10); |
| | | float loss = train_network_sgd_gpu(net, train, imgs); |
| | | end = clock(); |
| | | printf("%d: %f, Time: %lf seconds\n", i, loss, (float)(end-start)/CLOCKS_PER_SEC ); |
| | | free_data(train); |
| | |
| | | data train = load_all_cifar10(); |
| | | while(++count <= 10000){ |
| | | clock_t start = clock(), end; |
| | | float loss = train_network_sgd_gpu(net, train, iters); |
| | | float loss = train_network_sgd(net, train, iters); |
| | | end = clock(); |
| | | //visualize_network(net); |
| | | //cvWaitKey(5000); |
| | |
| | | float test_acc = network_accuracy(net, test); |
| | | printf("%d: Loss: %f, Test Acc: %f, Time: %lf seconds, LR: %f, Momentum: %f, Decay: %f\n", count, loss, test_acc,(float)(end-start)/CLOCKS_PER_SEC, net.learning_rate, net.momentum, net.decay); |
| | | char buff[256]; |
| | | sprintf(buff, "/home/pjreddie/cifar/cifar2_%d.cfg", count); |
| | | sprintf(buff, "/home/pjreddie/cifar/cifar10_2_%d.cfg", count); |
| | | save_network(net, buff); |
| | | }else{ |
| | | printf("%d: Loss: %f, Time: %lf seconds, LR: %f, Momentum: %f, Decay: %f\n", count, loss, (float)(end-start)/CLOCKS_PER_SEC, net.learning_rate, net.momentum, net.decay); |
| | |
| | | int iters = 10000/net.batch; |
| | | while(++count <= 2000){ |
| | | clock_t start = clock(), end; |
| | | float loss = train_network_sgd(net, train, iters); |
| | | float loss = train_network_sgd_gpu(net, train, iters); |
| | | end = clock(); |
| | | float test_acc = network_accuracy(net, test); |
| | | //float test_acc = 0; |
| | |
| | | |
| | | int main(int argc, char *argv[]) |
| | | { |
| | | //train_assira(); |
| | | //test_blas(); |
| | | train_assira(); |
| | | //test_distribution(); |
| | | //feenableexcept(FE_DIVBYZERO | FE_INVALID | FE_OVERFLOW); |
| | | |
| | |
| | | //test_ensemble(); |
| | | //test_nist_single(); |
| | | //test_nist(); |
| | | train_nist(); |
| | | //train_nist(); |
| | | //test_convolutional_layer(); |
| | | //test_col2im(); |
| | | //test_cifar10(); |