Joseph Redmon
2014-12-18 f88baf4a3a756140cef3ca07be98cabb803d80ae
src/cnn.c
@@ -380,22 +380,24 @@
void train_nist(char *cfgfile)
{
    srand(222222);
    srand(time(0));
    network net = parse_network_cfg(cfgfile);
    // srand(time(0));
    data train = load_categorical_data_csv("data/mnist/mnist_train.csv", 0, 10);
    data test = load_categorical_data_csv("data/mnist/mnist_test.csv",0,10);
    normalize_data_rows(train);
    normalize_data_rows(test);
    network net = parse_network_cfg(cfgfile);
    int count = 0;
    int iters = 60000/net.batch + 1;
    while(++count <= 10){
        clock_t start = clock(), end;
        normalize_data_rows(train);
        normalize_data_rows(test);
        float loss = train_network_sgd(net, train, iters);
        end = clock();
        float test_acc = 0;
        //if(count%1 == 0) test_acc = network_accuracy(net, test);
        if(count%1 == 0) test_acc = network_accuracy(net, test);
        end = clock();
        printf("%d: Loss: %f, Test Acc: %f, Time: %lf seconds\n", count, loss, test_acc,(float)(end-start)/CLOCKS_PER_SEC);
    }
    free_data(train);
    free_data(test);
    char buff[256];
    sprintf(buff, "%s.trained", cfgfile);
    save_network(net, buff);