| | |
| | | return (y[class]?1:0); |
| | | } |
| | | |
| | | double train_network_sgd(network net, data d, double step, double momentum,double decay) |
| | | double train_network_sgd(network net, data d, int n, double step, double momentum,double decay) |
| | | { |
| | | int i; |
| | | int correct = 0; |
| | | for(i = 0; i < d.X.rows; ++i){ |
| | | for(i = 0; i < n; ++i){ |
| | | int index = rand()%d.X.rows; |
| | | correct += train_network_datum(net, d.X.vals[index], d.y.vals[index], step, momentum, decay); |
| | | if((i+1)%10 == 0){ |
| | | printf("%d: %f\n", (i+1), (double)correct/(i+1)); |
| | | } |
| | | //if((i+1)%10 == 0){ |
| | | // printf("%d: %f\n", (i+1), (double)correct/(i+1)); |
| | | //} |
| | | } |
| | | return (double)correct/d.X.rows; |
| | | return (double)correct/n; |
| | | } |
| | | |
| | | void train_network(network net, data d, double step, double momentum, double decay) |
| | |
| | | } |
| | | } |
| | | |
| | | double *network_predict(network net, double *input) |
| | | { |
| | | forward_network(net, input); |
| | | double *out = get_network_output(net); |
| | | return out; |
| | | } |
| | | |
| | | matrix network_predict_data(network net, data test) |
| | | { |
| | | int i,j; |
| | | int k = get_network_output_size(net); |
| | | matrix pred = make_matrix(test.X.rows, k); |
| | | for(i = 0; i < test.X.rows; ++i){ |
| | | double *out = network_predict(net, test.X.vals[i]); |
| | | for(j = 0; j < k; ++j){ |
| | | pred.vals[i][j] = out[j]; |
| | | } |
| | | } |
| | | return pred; |
| | | } |
| | | |
| | | void print_network(network net) |
| | | { |
| | | int i,j; |
| | |
| | | fprintf(stderr, "\n"); |
| | | } |
| | | } |
| | | |
| | | double network_accuracy(network net, data d) |
| | | { |
| | | int i; |
| | | int correct = 0; |
| | | int k = get_network_output_size(net); |
| | | for(i = 0; i < d.X.rows; ++i){ |
| | | forward_network(net, d.X.vals[i]); |
| | | double *out = get_network_output(net); |
| | | int guess = max_index(out, k); |
| | | if(d.y.vals[i][guess]) ++correct; |
| | | } |
| | | return (double)correct/d.X.rows; |
| | | matrix guess = network_predict_data(net, d); |
| | | double acc = matrix_accuracy(d.y, guess); |
| | | free_matrix(guess); |
| | | return acc; |
| | | } |
| | | |