From 1a35f49ab3ae9d74b636de8580e07ab072846ea9 Mon Sep 17 00:00:00 2001
From: Alexey <AlexeyAB@users.noreply.github.com>
Date: Thu, 31 May 2018 11:19:42 +0000
Subject: [PATCH] Merge pull request #950 from IlyaOvodov/Fix_size_PR
---
src/network_kernels.cu | 374 +++++++++++++++++++++++++++++-----------------------
1 files changed, 209 insertions(+), 165 deletions(-)
diff --git a/src/network_kernels.cu b/src/network_kernels.cu
index 3e0c2b6..a11d61f 100644
--- a/src/network_kernels.cu
+++ b/src/network_kernels.cu
@@ -22,7 +22,6 @@
#include "region_layer.h"
#include "convolutional_layer.h"
#include "activation_layer.h"
-#include "deconvolutional_layer.h"
#include "maxpool_layer.h"
#include "reorg_layer.h"
#include "avgpool_layer.h"
@@ -37,6 +36,10 @@
#include "blas.h"
}
+#ifdef OPENCV
+#include "opencv2/highgui/highgui_c.h"
+#endif
+
float * get_network_output_gpu_layer(network net, int i);
float * get_network_delta_gpu_layer(network net, int i);
float * get_network_output_gpu(network net);
@@ -51,50 +54,25 @@
if(l.delta_gpu){
fill_ongpu(l.outputs * l.batch, 0, l.delta_gpu, 1);
}
- if(l.type == CONVOLUTIONAL){
- forward_convolutional_layer_gpu(l, state);
- } else if(l.type == DECONVOLUTIONAL){
- forward_deconvolutional_layer_gpu(l, state);
- } else if(l.type == ACTIVE){
- forward_activation_layer_gpu(l, state);
- } else if(l.type == LOCAL){
- forward_local_layer_gpu(l, state);
- } else if(l.type == DETECTION){
- forward_detection_layer_gpu(l, state);
- } else if(l.type == REGION){
- forward_region_layer_gpu(l, state);
- } else if(l.type == CONNECTED){
- forward_connected_layer_gpu(l, state);
- } else if(l.type == RNN){
- forward_rnn_layer_gpu(l, state);
- } else if(l.type == GRU){
- forward_gru_layer_gpu(l, state);
- } else if(l.type == CRNN){
- forward_crnn_layer_gpu(l, state);
- } else if(l.type == CROP){
- forward_crop_layer_gpu(l, state);
- } else if(l.type == COST){
- forward_cost_layer_gpu(l, state);
- } else if(l.type == SOFTMAX){
- forward_softmax_layer_gpu(l, state);
- } else if(l.type == NORMALIZATION){
- forward_normalization_layer_gpu(l, state);
- } else if(l.type == BATCHNORM){
- forward_batchnorm_layer_gpu(l, state);
- } else if(l.type == MAXPOOL){
- forward_maxpool_layer_gpu(l, state);
- } else if(l.type == REORG){
- forward_reorg_layer_gpu(l, state);
- } else if(l.type == AVGPOOL){
- forward_avgpool_layer_gpu(l, state);
- } else if(l.type == DROPOUT){
- forward_dropout_layer_gpu(l, state);
- } else if(l.type == ROUTE){
- forward_route_layer_gpu(l, net);
- } else if(l.type == SHORTCUT){
- forward_shortcut_layer_gpu(l, state);
- }
+ l.forward_gpu(l, state);
+ if(net.wait_stream)
+ cudaStreamSynchronize(get_cuda_stream());
state.input = l.output_gpu;
+/*
+ cuda_pull_array(l.output_gpu, l.output, l.batch*l.outputs);
+ if (l.out_w >= 0 && l.out_h >= 1 && l.c >= 3) {
+ int j;
+ for (j = 0; j < l.out_c; ++j) {
+ image img = make_image(l.out_w, l.out_h, 3);
+ memcpy(img.data, l.output+ l.out_w*l.out_h*j, l.out_w*l.out_h * 1 * sizeof(float));
+ char buff[256];
+ sprintf(buff, "layer-%d slice-%d", i, j);
+ show_image(img, buff);
+ }
+ cvWaitKey(0); // wait press-key in console
+ cvDestroyAllWindows();
+ }
+*/
}
}
@@ -107,6 +85,7 @@
for(i = net.n-1; i >= 0; --i){
state.index = i;
layer l = net.layers[i];
+ if (l.stopbackward) break;
if(i == 0){
state.input = original_input;
state.delta = original_delta;
@@ -115,71 +94,21 @@
state.input = prev.output_gpu;
state.delta = prev.delta_gpu;
}
- if(l.type == CONVOLUTIONAL){
- backward_convolutional_layer_gpu(l, state);
- } else if(l.type == DECONVOLUTIONAL){
- backward_deconvolutional_layer_gpu(l, state);
- } else if(l.type == ACTIVE){
- backward_activation_layer_gpu(l, state);
- } else if(l.type == LOCAL){
- backward_local_layer_gpu(l, state);
- } else if(l.type == MAXPOOL){
- if(i != 0) backward_maxpool_layer_gpu(l, state);
- } else if(l.type == REORG){
- backward_reorg_layer_gpu(l, state);
- } else if(l.type == AVGPOOL){
- if(i != 0) backward_avgpool_layer_gpu(l, state);
- } else if(l.type == DROPOUT){
- backward_dropout_layer_gpu(l, state);
- } else if(l.type == DETECTION){
- backward_detection_layer_gpu(l, state);
- } else if(l.type == REGION){
- backward_region_layer_gpu(l, state);
- } else if(l.type == NORMALIZATION){
- backward_normalization_layer_gpu(l, state);
- } else if(l.type == BATCHNORM){
- backward_batchnorm_layer_gpu(l, state);
- } else if(l.type == SOFTMAX){
- if(i != 0) backward_softmax_layer_gpu(l, state);
- } else if(l.type == CONNECTED){
- backward_connected_layer_gpu(l, state);
- } else if(l.type == RNN){
- backward_rnn_layer_gpu(l, state);
- } else if(l.type == GRU){
- backward_gru_layer_gpu(l, state);
- } else if(l.type == CRNN){
- backward_crnn_layer_gpu(l, state);
- } else if(l.type == COST){
- backward_cost_layer_gpu(l, state);
- } else if(l.type == ROUTE){
- backward_route_layer_gpu(l, net);
- } else if(l.type == SHORTCUT){
- backward_shortcut_layer_gpu(l, state);
- }
+ l.backward_gpu(l, state);
}
}
void update_network_gpu(network net)
{
+ cuda_set_device(net.gpu_index);
int i;
int update_batch = net.batch*net.subdivisions;
float rate = get_current_rate(net);
for(i = 0; i < net.n; ++i){
layer l = net.layers[i];
- if(l.type == CONVOLUTIONAL){
- update_convolutional_layer_gpu(l, update_batch, rate, net.momentum, net.decay);
- } else if(l.type == DECONVOLUTIONAL){
- update_deconvolutional_layer_gpu(l, rate, net.momentum, net.decay);
- } else if(l.type == CONNECTED){
- update_connected_layer_gpu(l, update_batch, rate, net.momentum, net.decay);
- } else if(l.type == GRU){
- update_gru_layer_gpu(l, update_batch, rate, net.momentum, net.decay);
- } else if(l.type == RNN){
- update_rnn_layer_gpu(l, update_batch, rate, net.momentum, net.decay);
- } else if(l.type == CRNN){
- update_crnn_layer_gpu(l, update_batch, rate, net.momentum, net.decay);
- } else if(l.type == LOCAL){
- update_local_layer_gpu(l, update_batch, rate, net.momentum, net.decay);
+ l.t = get_current_batch(net);
+ if(l.update_gpu){
+ l.update_gpu(l, update_batch, rate, net.momentum, net.decay);
}
}
}
@@ -203,7 +132,15 @@
state.delta = 0;
state.truth = *net.truth_gpu;
state.train = 1;
+#ifdef CUDNN_HALF
+ int i;
+ for (i = 0; i < net.n; ++i) {
+ layer l = net.layers[i];
+ cuda_convert_f32_to_f16(l.weights_gpu, l.c*l.n*l.size*l.size, l.weights_gpu16);
+ }
+#endif
forward_network_gpu(net, state);
+ //cudaStreamSynchronize(get_cuda_stream());
backward_network_gpu(net, state);
}
@@ -219,34 +156,32 @@
typedef struct {
network net;
- float *X;
- float *y;
+ data d;
+ float *err;
} train_args;
void *train_thread(void *ptr)
{
train_args args = *(train_args*)ptr;
-
- cuda_set_device(args.net.gpu_index);
- forward_backward_network_gpu(args.net, args.X, args.y);
free(ptr);
+ cuda_set_device(args.net.gpu_index);
+ *args.err = train_network(args.net, args.d);
return 0;
}
-pthread_t train_network_in_thread(network net, float *X, float *y)
+pthread_t train_network_in_thread(network net, data d, float *err)
{
pthread_t thread;
train_args *ptr = (train_args *)calloc(1, sizeof(train_args));
ptr->net = net;
- ptr->X = X;
- ptr->y = y;
+ ptr->d = d;
+ ptr->err = err;
if(pthread_create(&thread, 0, train_thread, ptr)) error("Thread creation failed");
return thread;
}
void pull_updates(layer l)
{
-#ifdef GPU
if(l.type == CONVOLUTIONAL){
cuda_pull_array(l.bias_updates_gpu, l.bias_updates, l.n);
cuda_pull_array(l.weight_updates_gpu, l.weight_updates, l.n*l.size*l.size*l.c);
@@ -255,12 +190,10 @@
cuda_pull_array(l.bias_updates_gpu, l.bias_updates, l.outputs);
cuda_pull_array(l.weight_updates_gpu, l.weight_updates, l.outputs*l.inputs);
}
-#endif
}
void push_updates(layer l)
{
-#ifdef GPU
if(l.type == CONVOLUTIONAL){
cuda_push_array(l.bias_updates_gpu, l.bias_updates, l.n);
cuda_push_array(l.weight_updates_gpu, l.weight_updates, l.n*l.size*l.size*l.c);
@@ -269,9 +202,84 @@
cuda_push_array(l.bias_updates_gpu, l.bias_updates, l.outputs);
cuda_push_array(l.weight_updates_gpu, l.weight_updates, l.outputs*l.inputs);
}
-#endif
}
+void update_layer(layer l, network net)
+{
+ int update_batch = net.batch*net.subdivisions;
+ float rate = get_current_rate(net);
+ l.t = get_current_batch(net);
+ if(l.update_gpu){
+ l.update_gpu(l, update_batch, rate, net.momentum, net.decay);
+ }
+}
+
+void merge_weights(layer l, layer base)
+{
+ if (l.type == CONVOLUTIONAL) {
+ axpy_cpu(l.n, 1, l.biases, 1, base.biases, 1);
+ axpy_cpu(l.n*l.size*l.size*l.c, 1, l.weights, 1, base.weights, 1);
+ if (l.scales) {
+ axpy_cpu(l.n, 1, l.scales, 1, base.scales, 1);
+ }
+ } else if(l.type == CONNECTED) {
+ axpy_cpu(l.outputs, 1, l.biases, 1, base.biases, 1);
+ axpy_cpu(l.outputs*l.inputs, 1, l.weights, 1, base.weights, 1);
+ }
+}
+
+void scale_weights(layer l, float s)
+{
+ if (l.type == CONVOLUTIONAL) {
+ scal_cpu(l.n, s, l.biases, 1);
+ scal_cpu(l.n*l.size*l.size*l.c, s, l.weights, 1);
+ if (l.scales) {
+ scal_cpu(l.n, s, l.scales, 1);
+ }
+ } else if(l.type == CONNECTED) {
+ scal_cpu(l.outputs, s, l.biases, 1);
+ scal_cpu(l.outputs*l.inputs, s, l.weights, 1);
+ }
+}
+
+
+void pull_weights(layer l)
+{
+ if(l.type == CONVOLUTIONAL){
+ cuda_pull_array(l.biases_gpu, l.biases, l.n);
+ cuda_pull_array(l.weights_gpu, l.weights, l.n*l.size*l.size*l.c);
+ if(l.scales) cuda_pull_array(l.scales_gpu, l.scales, l.n);
+ } else if(l.type == CONNECTED){
+ cuda_pull_array(l.biases_gpu, l.biases, l.outputs);
+ cuda_pull_array(l.weights_gpu, l.weights, l.outputs*l.inputs);
+ }
+}
+
+void push_weights(layer l)
+{
+ if(l.type == CONVOLUTIONAL){
+ cuda_push_array(l.biases_gpu, l.biases, l.n);
+ cuda_push_array(l.weights_gpu, l.weights, l.n*l.size*l.size*l.c);
+ if(l.scales) cuda_push_array(l.scales_gpu, l.scales, l.n);
+ } else if(l.type == CONNECTED){
+ cuda_push_array(l.biases_gpu, l.biases, l.outputs);
+ cuda_push_array(l.weights_gpu, l.weights, l.outputs*l.inputs);
+ }
+}
+
+void distribute_weights(layer l, layer base)
+{
+ if(l.type == CONVOLUTIONAL){
+ cuda_push_array(l.biases_gpu, base.biases, l.n);
+ cuda_push_array(l.weights_gpu, base.weights, l.n*l.size*l.size*l.c);
+ if(base.scales) cuda_push_array(l.scales_gpu, base.scales, l.n);
+ } else if(l.type == CONNECTED){
+ cuda_push_array(l.biases_gpu, base.biases, l.outputs);
+ cuda_push_array(l.weights_gpu, base.weights, l.outputs*l.inputs);
+ }
+}
+
+
void merge_updates(layer l, layer base)
{
if (l.type == CONVOLUTIONAL) {
@@ -288,85 +296,119 @@
void distribute_updates(layer l, layer base)
{
- if (l.type == CONVOLUTIONAL) {
- copy_cpu(l.n, base.bias_updates, 1, l.bias_updates, 1);
- copy_cpu(l.n*l.size*l.size*l.c, base.weight_updates, 1, l.weight_updates, 1);
- if (l.scale_updates) {
- copy_cpu(l.n, base.scale_updates, 1, l.scale_updates, 1);
- }
- } else if(l.type == CONNECTED) {
- copy_cpu(l.outputs, base.bias_updates, 1, l.bias_updates, 1);
- copy_cpu(l.outputs*l.inputs, base.weight_updates, 1, l.weight_updates, 1);
+ if(l.type == CONVOLUTIONAL){
+ cuda_push_array(l.bias_updates_gpu, base.bias_updates, l.n);
+ cuda_push_array(l.weight_updates_gpu, base.weight_updates, l.n*l.size*l.size*l.c);
+ if(base.scale_updates) cuda_push_array(l.scale_updates_gpu, base.scale_updates, l.n);
+ } else if(l.type == CONNECTED){
+ cuda_push_array(l.bias_updates_gpu, base.bias_updates, l.outputs);
+ cuda_push_array(l.weight_updates_gpu, base.weight_updates, l.outputs*l.inputs);
}
}
-void sync_updates(network *nets, int n)
+void sync_layer(network *nets, int n, int j)
{
- int i,j;
- int layers = nets[0].n;
+ //printf("Syncing layer %d\n", j);
+ int i;
network net = nets[0];
- for (j = 0; j < layers; ++j) {
- layer base = net.layers[j];
- cuda_set_device(net.gpu_index);
- pull_updates(base);
- for (i = 1; i < n; ++i) {
- cuda_set_device(nets[i].gpu_index);
- layer l = nets[i].layers[j];
- pull_updates(l);
- merge_updates(l, base);
- }
- for (i = 1; i < n; ++i) {
- cuda_set_device(nets[i].gpu_index);
- layer l = nets[i].layers[j];
- distribute_updates(l, base);
- push_updates(l);
- }
- cuda_set_device(net.gpu_index);
- push_updates(base);
+ layer base = net.layers[j];
+ cuda_set_device(net.gpu_index);
+ pull_weights(base);
+ for (i = 1; i < n; ++i) {
+ cuda_set_device(nets[i].gpu_index);
+ layer l = nets[i].layers[j];
+ pull_weights(l);
+ merge_weights(l, base);
}
+ scale_weights(base, 1./n);
for (i = 0; i < n; ++i) {
cuda_set_device(nets[i].gpu_index);
- if(i > 0) nets[i].momentum = 0;
- update_network_gpu(nets[i]);
+ layer l = nets[i].layers[j];
+ distribute_weights(l, base);
}
+ //printf("Done syncing layer %d\n", j);
}
-float train_networks(network *nets, int n, data d)
-{
- int batch = nets[0].batch;
- assert(batch * n == d.X.rows);
- assert(nets[0].subdivisions % n == 0);
- float **X = (float **) calloc(n, sizeof(float *));
- float **y = (float **) calloc(n, sizeof(float *));
- pthread_t *threads = (pthread_t *) calloc(n, sizeof(pthread_t));
+typedef struct{
+ network *nets;
+ int n;
+ int j;
+} sync_args;
+void *sync_layer_thread(void *ptr)
+{
+ sync_args args = *(sync_args*)ptr;
+ sync_layer(args.nets, args.n, args.j);
+ free(ptr);
+ return 0;
+}
+
+pthread_t sync_layer_in_thread(network *nets, int n, int j)
+{
+ pthread_t thread;
+ sync_args *ptr = (sync_args *)calloc(1, sizeof(sync_args));
+ ptr->nets = nets;
+ ptr->n = n;
+ ptr->j = j;
+ if(pthread_create(&thread, 0, sync_layer_thread, ptr)) error("Thread creation failed");
+ return thread;
+}
+
+void sync_nets(network *nets, int n, int interval)
+{
+ int j;
+ int layers = nets[0].n;
+ pthread_t *threads = (pthread_t *) calloc(layers, sizeof(pthread_t));
+
+ *nets[0].seen += interval * (n-1) * nets[0].batch * nets[0].subdivisions;
+ for (j = 0; j < n; ++j){
+ *nets[j].seen = *nets[0].seen;
+ }
+ for (j = 0; j < layers; ++j) {
+ threads[j] = sync_layer_in_thread(nets, n, j);
+ }
+ for (j = 0; j < layers; ++j) {
+ pthread_join(threads[j], 0);
+ }
+ free(threads);
+}
+
+float train_networks(network *nets, int n, data d, int interval)
+{
int i;
+ int batch = nets[0].batch;
+ int subdivisions = nets[0].subdivisions;
+ assert(batch * subdivisions * n == d.X.rows);
+ pthread_t *threads = (pthread_t *) calloc(n, sizeof(pthread_t));
+ float *errors = (float *) calloc(n, sizeof(float));
+
float sum = 0;
for(i = 0; i < n; ++i){
- X[i] = (float *) calloc(batch*d.X.cols, sizeof(float));
- y[i] = (float *) calloc(batch*d.y.cols, sizeof(float));
- get_next_batch(d, batch, i*batch, X[i], y[i]);
- threads[i] = train_network_in_thread(nets[i], X[i], y[i]);
+ data p = get_data_part(d, i, n);
+ threads[i] = train_network_in_thread(nets[i], p, errors + i);
}
for(i = 0; i < n; ++i){
pthread_join(threads[i], 0);
- *nets[i].seen += n*nets[i].batch;
- printf("%f\n", get_network_cost(nets[i]) / batch);
- sum += get_network_cost(nets[i]);
- free(X[i]);
- free(y[i]);
+ //printf("%f\n", errors[i]);
+ sum += errors[i];
}
- if (((*nets[0].seen) / nets[0].batch) % nets[0].subdivisions == 0) sync_updates(nets, n);
- free(X);
- free(y);
+ //cudaDeviceSynchronize();
+ if (get_current_batch(nets[0]) % interval == 0) {
+ printf("Syncing... ");
+ fflush(stdout);
+ sync_nets(nets, n, interval);
+ printf("Done!\n");
+ }
+ //cudaDeviceSynchronize();
free(threads);
- return (float)sum/(n*batch);
+ free(errors);
+ return (float)sum/(n);
}
float *get_network_output_layer_gpu(network net, int i)
{
layer l = net.layers[i];
- cuda_pull_array(l.output_gpu, l.output, l.outputs*l.batch);
+ if(l.type != REGION) cuda_pull_array(l.output_gpu, l.output, l.outputs*l.batch);
return l.output;
}
@@ -379,6 +421,8 @@
float *network_predict_gpu(network net, float *input)
{
+ if (net.gpu_index != cuda_get_device())
+ cuda_set_device(net.gpu_index);
int size = get_network_input_size(net) * net.batch;
network_state state;
state.index = 0;
--
Gitblit v1.10.0