| | |
| | | free_data(train); |
| | | } |
| | | |
| | | void compare_nist(char *p1,char *p2) |
| | | { |
| | | srand(222222); |
| | | network n1 = parse_network_cfg(p1); |
| | | network n2 = parse_network_cfg(p2); |
| | | data test = load_categorical_data_csv("data/mnist/mnist_test.csv",0,10); |
| | | normalize_data_rows(test); |
| | | compare_networks(n1, n2, test); |
| | | } |
| | | |
| | | void test_nist(char *path) |
| | | { |
| | | srand(222222); |
| | |
| | | normalize_data_rows(test); |
| | | int count = 0; |
| | | int iters = 60000/net.batch + 1; |
| | | while(++count <= 200){ |
| | | while(++count <= 10){ |
| | | clock_t start = clock(), end; |
| | | float loss = train_network_sgd(net, train, iters); |
| | | end = clock(); |
| | |
| | | else if(0==strcmp(argv[1], "visualize")) test_visualize(argv[2]); |
| | | else if(0==strcmp(argv[1], "valid")) validate_imagenet(argv[2]); |
| | | else if(0==strcmp(argv[1], "testnist")) test_nist(argv[2]); |
| | | else if(argc < 4){ |
| | | fprintf(stderr, "usage: %s <function> <filename> <filename>\n", argv[0]); |
| | | return 0; |
| | | } |
| | | else if(0==strcmp(argv[1], "compare")) compare_nist(argv[2], argv[3]); |
| | | fprintf(stderr, "Success!\n"); |
| | | return 0; |
| | | } |
| | |
| | | } |
| | | } |
| | | |
| | | void compare_networks(network n1, network n2, data test) |
| | | { |
| | | matrix g1 = network_predict_data(n1, test); |
| | | matrix g2 = network_predict_data(n2, test); |
| | | int i; |
| | | int a,b,c,d; |
| | | a = b = c = d = 0; |
| | | for(i = 0; i < g1.rows; ++i){ |
| | | int truth = max_index(test.y.vals[i], test.y.cols); |
| | | int p1 = max_index(g1.vals[i], g1.cols); |
| | | int p2 = max_index(g2.vals[i], g2.cols); |
| | | if(p1 == truth){ |
| | | if(p2 == truth) ++d; |
| | | else ++c; |
| | | }else{ |
| | | if(p2 == truth) ++b; |
| | | else ++a; |
| | | } |
| | | } |
| | | printf("%5d %5d\n%5d %5d\n", a, b, c, d); |
| | | } |
| | | |
| | | float network_accuracy(network net, data d) |
| | | { |
| | | matrix guess = network_predict_data(net, d); |
| | |
| | | float *network_predict_gpu(network net, float *input); |
| | | #endif |
| | | |
| | | void compare_networks(network n1, network n2, data d); |
| | | |
| | | network make_network(int n, int batch); |
| | | void forward_network(network net, float *input, float *truth, int train); |
| | | void backward_network(network net, float *input); |
| | |
| | | a[i] *= s; |
| | | } |
| | | } |
| | | |
| | | int max_index(float *a, int n) |
| | | { |
| | | if(n <= 0) return -1; |