From aa5996d58e68edfbefe51061856aecd549dd09c4 Mon Sep 17 00:00:00 2001 From: Joseph Redmon <pjreddie@gmail.com> Date: Tue, 13 Jan 2015 01:27:08 +0000 Subject: [PATCH] Faster --- src/network.h | 21 +++++++++++++-------- 1 files changed, 13 insertions(+), 8 deletions(-) diff --git a/src/network.h b/src/network.h index 22e277c..7a401bd 100644 --- a/src/network.h +++ b/src/network.h @@ -36,24 +36,27 @@ } network; #ifdef GPU -void forward_network_gpu(network net, cl_mem input, cl_mem truth, int train); -void backward_network_gpu(network net, cl_mem input); -void update_network_gpu(network net); -cl_mem get_network_output_cl_layer(network net, int i); -cl_mem get_network_delta_cl_layer(network net, int i); -float train_network_sgd_gpu(network net, data d, int n); +float train_network_datum_gpu(network net, float *x, float *y); +float *network_predict_gpu(network net, float *input); #endif +void compare_networks(network n1, network n2, data d); + network make_network(int n, int batch); void forward_network(network net, float *input, float *truth, int train); void backward_network(network net, float *input); void update_network(network net); -float train_network_sgd(network net, data d, int n); + +float train_network(network net, data d); float train_network_batch(network net, data d, int n); -void train_network(network net, data d); +float train_network_sgd(network net, data d, int n); + matrix network_predict_data(network net, data test); +float *network_predict(network net, float *input); float network_accuracy(network net, data d); +float *network_accuracies(network net, data d); float network_accuracy_multi(network net, data d, int n); +void top_predictions(network net, int n, int *index); float *get_network_output(network net); float *get_network_output_layer(network net, int i); float *get_network_delta_layer(network net, int i); @@ -66,6 +69,8 @@ void print_network(network net); void visualize_network(network net); int resize_network(network net, int h, int w, int c); +void set_batch_network(network *net, int b); +void set_learning_network(network *net, float rate, float momentum, float decay); int get_network_input_size(network net); float get_network_cost(network net); -- Gitblit v1.10.0