From 08b757a0bf76efe8c76b453063a1bb19315bcaa6 Mon Sep 17 00:00:00 2001 From: Joseph Redmon <pjreddie@gmail.com> Date: Wed, 14 Jan 2015 20:18:57 +0000 Subject: [PATCH] Stable, needs to be way faster --- src/network.h | 21 ++++++++++++--------- 1 files changed, 12 insertions(+), 9 deletions(-) diff --git a/src/network.h b/src/network.h index 7625904..c6c7790 100644 --- a/src/network.h +++ b/src/network.h @@ -36,25 +36,26 @@ } network; #ifdef GPU -void forward_network_gpu(network net, cl_mem input, cl_mem truth, int train); -void backward_network_gpu(network net, cl_mem input); -void update_network_gpu(network net); -cl_mem get_network_output_cl_layer(network net, int i); -cl_mem get_network_delta_cl_layer(network net, int i); -float train_network_sgd_gpu(network net, data d, int n); -float train_network_data_gpu(network net, data d, int n); +float train_network_datum_gpu(network net, float *x, float *y); +float *network_predict_gpu(network net, float *input); #endif +void compare_networks(network n1, network n2, data d); +char *get_layer_string(LAYER_TYPE a); + network make_network(int n, int batch); void forward_network(network net, float *input, float *truth, int train); void backward_network(network net, float *input); void update_network(network net); -float train_network_sgd(network net, data d, int n); + +float train_network(network net, data d); float train_network_batch(network net, data d, int n); -void train_network(network net, data d); +float train_network_sgd(network net, data d, int n); + matrix network_predict_data(network net, data test); float *network_predict(network net, float *input); float network_accuracy(network net, data d); +float *network_accuracies(network net, data d); float network_accuracy_multi(network net, data d, int n); void top_predictions(network net, int n, int *index); float *get_network_output(network net); @@ -69,6 +70,8 @@ void print_network(network net); void visualize_network(network net); int resize_network(network net, int h, int w, int c); +void set_batch_network(network *net, int b); +void set_learning_network(network *net, float rate, float momentum, float decay); int get_network_input_size(network net); float get_network_cost(network net); -- Gitblit v1.10.0