| | |
| | | free_data(train); |
| | | } |
| | | |
| | | void train_full() |
| | | void train_assira() |
| | | { |
| | | network net = parse_network_cfg("cfg/imagenet.cfg"); |
| | | network net = parse_network_cfg("cfg/assira.cfg"); |
| | | srand(2222222); |
| | | int i = 0; |
| | | char *labels[] = {"cat","dog"}; |
| | | float lr = .00001; |
| | | float momentum = .9; |
| | | float decay = 0.01; |
| | | while(1){ |
| | | i += 1000; |
| | | data train = load_data_image_pathfile_random("images/assira/train.list", 1000, labels, 2, 256, 256); |
| | | //image im = float_to_image(256, 256, 3,train.X.vals[0]); |
| | | //visualize_network(net); |
| | | //cvWaitKey(100); |
| | | //show_image(im, "input"); |
| | | //cvWaitKey(100); |
| | | //scale_data_rows(train, 1./255.); |
| | | data train = load_data_image_pathfile_random("data/assira/train.list", 1000, labels, 2, 256, 256); |
| | | normalize_data_rows(train); |
| | | clock_t start = clock(), end; |
| | | float loss = train_network_sgd(net, train, 1000); |
| | | float loss = train_network_sgd_gpu(net, train, 10); |
| | | end = clock(); |
| | | printf("%d: %f, Time: %lf seconds, LR: %f, Momentum: %f, Decay: %f\n", i, loss, (float)(end-start)/CLOCKS_PER_SEC, lr, momentum, decay); |
| | | printf("%d: %f, Time: %lf seconds\n", i, loss, (float)(end-start)/CLOCKS_PER_SEC ); |
| | | free_data(train); |
| | | if(i%10000==0){ |
| | | char buff[256]; |
| | |
| | | data train = load_all_cifar10(); |
| | | while(++count <= 10000){ |
| | | clock_t start = clock(), end; |
| | | float loss = train_network_sgd(net, train, iters); |
| | | float loss = train_network_sgd_gpu(net, train, iters); |
| | | end = clock(); |
| | | visualize_network(net); |
| | | cvWaitKey(5000); |
| | | //visualize_network(net); |
| | | //cvWaitKey(5000); |
| | | |
| | | //float test_acc = network_accuracy(net, test); |
| | | //printf("%d: Loss: %f, Test Acc: %f, Time: %lf seconds, LR: %f, Momentum: %f, Decay: %f\n", count, loss, test_acc,(float)(end-start)/CLOCKS_PER_SEC, net.learning_rate, net.momentum, net.decay); |
| | |
| | | |
| | | int main(int argc, char *argv[]) |
| | | { |
| | | //train_full(); |
| | | //train_assira(); |
| | | //test_distribution(); |
| | | //feenableexcept(FE_DIVBYZERO | FE_INVALID | FE_OVERFLOW); |
| | | |