| | |
| | | free_data(train); |
| | | } |
| | | |
| | | void compare_nist(char *p1,char *p2) |
| | | { |
| | | srand(222222); |
| | | network n1 = parse_network_cfg(p1); |
| | | network n2 = parse_network_cfg(p2); |
| | | data test = load_categorical_data_csv("data/mnist/mnist_test.csv",0,10); |
| | | normalize_data_rows(test); |
| | | compare_networks(n1, n2, test); |
| | | } |
| | | |
| | | void test_nist(char *path) |
| | | { |
| | | srand(222222); |
| | |
| | | normalize_data_rows(test); |
| | | int count = 0; |
| | | int iters = 60000/net.batch + 1; |
| | | while(++count <= 200){ |
| | | while(++count <= 10){ |
| | | clock_t start = clock(), end; |
| | | float loss = train_network_sgd(net, train, iters); |
| | | end = clock(); |
| | |
| | | else if(0==strcmp(argv[1], "visualize")) test_visualize(argv[2]); |
| | | else if(0==strcmp(argv[1], "valid")) validate_imagenet(argv[2]); |
| | | else if(0==strcmp(argv[1], "testnist")) test_nist(argv[2]); |
| | | else if(argc < 4){ |
| | | fprintf(stderr, "usage: %s <function> <filename> <filename>\n", argv[0]); |
| | | return 0; |
| | | } |
| | | else if(0==strcmp(argv[1], "compare")) compare_nist(argv[2], argv[3]); |
| | | fprintf(stderr, "Success!\n"); |
| | | return 0; |
| | | } |