From db0397cfaaf488364e3d2e1669dfefae2ee6ea73 Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Mon, 14 Dec 2015 19:57:10 +0000
Subject: [PATCH] shortcut layers, msr networks
---
src/convolutional_kernels.cu | 205 ++++++++++++++++++++++++++++++++++++--------------
1 files changed, 146 insertions(+), 59 deletions(-)
diff --git a/src/convolutional_kernels.cu b/src/convolutional_kernels.cu
index bcf307f..a64a499 100644
--- a/src/convolutional_kernels.cu
+++ b/src/convolutional_kernels.cu
@@ -1,3 +1,7 @@
+#include "cuda_runtime.h"
+#include "curand.h"
+#include "cublas_v2.h"
+
extern "C" {
#include "convolutional_layer.h"
#include "gemm.h"
@@ -8,25 +12,69 @@
#include "cuda.h"
}
-__global__ void bias_output_kernel(float *output, float *biases, int n, int size)
+__global__ void scale_bias_kernel(float *output, float *biases, int n, int size)
{
int offset = blockIdx.x * blockDim.x + threadIdx.x;
int filter = blockIdx.y;
int batch = blockIdx.z;
- if(offset < size) output[(batch*n+filter)*size + offset] = biases[filter];
+ if(offset < size) output[(batch*n+filter)*size + offset] *= biases[filter];
}
-extern "C" void bias_output_gpu(float *output, float *biases, int batch, int n, int size)
+void scale_bias_gpu(float *output, float *biases, int batch, int n, int size)
{
- dim3 dimBlock(BLOCK, 1, 1);
dim3 dimGrid((size-1)/BLOCK + 1, n, batch);
+ dim3 dimBlock(BLOCK, 1, 1);
- bias_output_kernel<<<dimGrid, dimBlock>>>(output, biases, n, size);
+ scale_bias_kernel<<<dimGrid, dimBlock>>>(output, biases, n, size);
check_error(cudaPeekAtLastError());
}
-__global__ void backward_bias_kernel(float *bias_updates, float *delta, int batch, int n, int size, float scale)
+__global__ void backward_scale_kernel(float *x_norm, float *delta, int batch, int n, int size, float *scale_updates)
+{
+ __shared__ float part[BLOCK];
+ int i,b;
+ int filter = blockIdx.x;
+ int p = threadIdx.x;
+ float sum = 0;
+ for(b = 0; b < batch; ++b){
+ for(i = 0; i < size; i += BLOCK){
+ int index = p + i + size*(filter + n*b);
+ sum += (p+i < size) ? delta[index]*x_norm[index] : 0;
+ }
+ }
+ part[p] = sum;
+ __syncthreads();
+ if (p == 0) {
+ for(i = 0; i < BLOCK; ++i) scale_updates[filter] += part[i];
+ }
+}
+
+void backward_scale_gpu(float *x_norm, float *delta, int batch, int n, int size, float *scale_updates)
+{
+ backward_scale_kernel<<<n, BLOCK>>>(x_norm, delta, batch, n, size, scale_updates);
+ check_error(cudaPeekAtLastError());
+}
+
+__global__ void add_bias_kernel(float *output, float *biases, int n, int size)
+{
+ int offset = blockIdx.x * blockDim.x + threadIdx.x;
+ int filter = blockIdx.y;
+ int batch = blockIdx.z;
+
+ if(offset < size) output[(batch*n+filter)*size + offset] += biases[filter];
+}
+
+void add_bias_gpu(float *output, float *biases, int batch, int n, int size)
+{
+ dim3 dimGrid((size-1)/BLOCK + 1, n, batch);
+ dim3 dimBlock(BLOCK, 1, 1);
+
+ add_bias_kernel<<<dimGrid, dimBlock>>>(output, biases, n, size);
+ check_error(cudaPeekAtLastError());
+}
+
+__global__ void backward_bias_kernel(float *bias_updates, float *delta, int batch, int n, int size)
{
__shared__ float part[BLOCK];
int i,b;
@@ -41,106 +89,145 @@
}
part[p] = sum;
__syncthreads();
- if(p == 0){
- for(i = 0; i < BLOCK; ++i) bias_updates[filter] += scale * part[i];
+ if (p == 0) {
+ for(i = 0; i < BLOCK; ++i) bias_updates[filter] += part[i];
}
}
-extern "C" void backward_bias_gpu(float *bias_updates, float *delta, int batch, int n, int size)
+void backward_bias_gpu(float *bias_updates, float *delta, int batch, int n, int size)
{
- float alpha = 1./batch;
-
- backward_bias_kernel<<<n, BLOCK>>>(bias_updates, delta, batch, n, size, alpha);
+ backward_bias_kernel<<<n, BLOCK>>>(bias_updates, delta, batch, n, size);
check_error(cudaPeekAtLastError());
}
-extern "C" void forward_convolutional_layer_gpu(convolutional_layer layer, float *in)
+void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
{
int i;
- int m = layer.n;
- int k = layer.size*layer.size*layer.c;
- int n = convolutional_out_height(layer)*
- convolutional_out_width(layer);
+ int m = l.n;
+ int k = l.size*l.size*l.c;
+ int n = convolutional_out_height(l)*
+ convolutional_out_width(l);
- bias_output_gpu(layer.output_gpu, layer.biases_gpu, layer.batch, layer.n, n);
-
- for(i = 0; i < layer.batch; ++i){
- im2col_ongpu(in + i*layer.c*layer.h*layer.w, layer.c, layer.h, layer.w, layer.size, layer.stride, layer.pad, layer.col_image_gpu);
- float * a = layer.filters_gpu;
- float * b = layer.col_image_gpu;
- float * c = layer.output_gpu;
+ fill_ongpu(l.outputs*l.batch, 0, l.output_gpu, 1);
+ for(i = 0; i < l.batch; ++i){
+ im2col_ongpu(state.input + i*l.c*l.h*l.w, l.c, l.h, l.w, l.size, l.stride, l.pad, l.col_image_gpu);
+ float * a = l.filters_gpu;
+ float * b = l.col_image_gpu;
+ float * c = l.output_gpu;
gemm_ongpu(0,0,m,n,k,1.,a,k,b,n,1.,c+i*m*n,n);
}
- activate_array_ongpu(layer.output_gpu, m*n*layer.batch, layer.activation);
+
+ if(l.batch_normalize){
+ if(state.train){
+ fast_mean_gpu(l.output_gpu, l.batch, l.n, l.out_h*l.out_w, l.mean_gpu);
+ fast_variance_gpu(l.output_gpu, l.mean_gpu, l.batch, l.n, l.out_h*l.out_w, l.variance_gpu);
+
+ /*
+ cuda_pull_array(l.variance_gpu, l.mean, 1);
+ printf("%f\n", l.mean[0]);
+ */
+
+
+ scal_ongpu(l.n, .95, l.rolling_mean_gpu, 1);
+ axpy_ongpu(l.n, .05, l.mean_gpu, 1, l.rolling_mean_gpu, 1);
+ scal_ongpu(l.n, .95, l.rolling_variance_gpu, 1);
+ axpy_ongpu(l.n, .05, l.variance_gpu, 1, l.rolling_variance_gpu, 1);
+
+ copy_ongpu(l.outputs*l.batch, l.output_gpu, 1, l.x_gpu, 1);
+ normalize_gpu(l.output_gpu, l.mean_gpu, l.variance_gpu, l.batch, l.n, l.out_h*l.out_w);
+ copy_ongpu(l.outputs*l.batch, l.output_gpu, 1, l.x_norm_gpu, 1);
+ } else {
+ normalize_gpu(l.output_gpu, l.rolling_mean_gpu, l.rolling_variance_gpu, l.batch, l.n, l.out_h*l.out_w);
+ }
+
+ scale_bias_gpu(l.output_gpu, l.scales_gpu, l.batch, l.n, l.out_h*l.out_w);
+ }
+ add_bias_gpu(l.output_gpu, l.biases_gpu, l.batch, l.n, n);
+
+ activate_array_ongpu(l.output_gpu, m*n*l.batch, l.activation);
}
-extern "C" void backward_convolutional_layer_gpu(convolutional_layer layer, float *in, float *delta_gpu)
+void backward_convolutional_layer_gpu(convolutional_layer l, network_state state)
{
- float alpha = 1./layer.batch;
int i;
- int m = layer.n;
- int n = layer.size*layer.size*layer.c;
- int k = convolutional_out_height(layer)*
- convolutional_out_width(layer);
+ int m = l.n;
+ int n = l.size*l.size*l.c;
+ int k = convolutional_out_height(l)*
+ convolutional_out_width(l);
- gradient_array_ongpu(layer.output_gpu, m*k*layer.batch, layer.activation, layer.delta_gpu);
- backward_bias_gpu(layer.bias_updates_gpu, layer.delta_gpu, layer.batch, layer.n, k);
+ gradient_array_ongpu(l.output_gpu, m*k*l.batch, l.activation, l.delta_gpu);
- if(delta_gpu) scal_ongpu(layer.batch*layer.h*layer.w*layer.c, 0, delta_gpu, 1);
+ backward_bias_gpu(l.bias_updates_gpu, l.delta_gpu, l.batch, l.n, k);
- for(i = 0; i < layer.batch; ++i){
- float * a = layer.delta_gpu;
- float * b = layer.col_image_gpu;
- float * c = layer.filter_updates_gpu;
+ if(l.batch_normalize){
+ backward_scale_gpu(l.x_norm_gpu, l.delta_gpu, l.batch, l.n, l.out_w*l.out_h, l.scale_updates_gpu);
- im2col_ongpu(in + i*layer.c*layer.h*layer.w, layer.c, layer.h, layer.w, layer.size, layer.stride, layer.pad, layer.col_image_gpu);
- gemm_ongpu(0,1,m,n,k,alpha,a + i*m*k,k,b,k,1,c,n);
+ scale_bias_gpu(l.delta_gpu, l.scales_gpu, l.batch, l.n, l.out_h*l.out_w);
- if(delta_gpu){
+ fast_mean_delta_gpu(l.delta_gpu, l.variance_gpu, l.batch, l.n, l.out_w*l.out_h, l.mean_delta_gpu);
+ fast_variance_delta_gpu(l.x_gpu, l.delta_gpu, l.mean_gpu, l.variance_gpu, l.batch, l.n, l.out_w*l.out_h, l.variance_delta_gpu);
+ normalize_delta_gpu(l.x_gpu, l.mean_gpu, l.variance_gpu, l.mean_delta_gpu, l.variance_delta_gpu, l.batch, l.n, l.out_w*l.out_h, l.delta_gpu);
+ }
- float * a = layer.filters_gpu;
- float * b = layer.delta_gpu;
- float * c = layer.col_image_gpu;
+ for(i = 0; i < l.batch; ++i){
+ float * a = l.delta_gpu;
+ float * b = l.col_image_gpu;
+ float * c = l.filter_updates_gpu;
+
+ im2col_ongpu(state.input + i*l.c*l.h*l.w, l.c, l.h, l.w, l.size, l.stride, l.pad, l.col_image_gpu);
+ gemm_ongpu(0,1,m,n,k,1,a + i*m*k,k,b,k,1,c,n);
+
+ if(state.delta){
+ float * a = l.filters_gpu;
+ float * b = l.delta_gpu;
+ float * c = l.col_image_gpu;
gemm_ongpu(1,0,n,k,m,1,a,n,b + i*k*m,k,0,c,k);
- col2im_ongpu(layer.col_image_gpu, layer.c, layer.h, layer.w, layer.size, layer.stride, layer.pad, delta_gpu + i*layer.c*layer.h*layer.w);
+ col2im_ongpu(l.col_image_gpu, l.c, l.h, l.w, l.size, l.stride, l.pad, state.delta + i*l.c*l.h*l.w);
}
}
}
-extern "C" void pull_convolutional_layer(convolutional_layer layer)
+void pull_convolutional_layer(convolutional_layer layer)
{
cuda_pull_array(layer.filters_gpu, layer.filters, layer.c*layer.n*layer.size*layer.size);
cuda_pull_array(layer.biases_gpu, layer.biases, layer.n);
cuda_pull_array(layer.filter_updates_gpu, layer.filter_updates, layer.c*layer.n*layer.size*layer.size);
cuda_pull_array(layer.bias_updates_gpu, layer.bias_updates, layer.n);
+ if (layer.batch_normalize){
+ cuda_pull_array(layer.scales_gpu, layer.scales, layer.n);
+ cuda_pull_array(layer.rolling_mean_gpu, layer.rolling_mean, layer.n);
+ cuda_pull_array(layer.rolling_variance_gpu, layer.rolling_variance, layer.n);
+ }
}
-extern "C" void push_convolutional_layer(convolutional_layer layer)
+void push_convolutional_layer(convolutional_layer layer)
{
cuda_push_array(layer.filters_gpu, layer.filters, layer.c*layer.n*layer.size*layer.size);
cuda_push_array(layer.biases_gpu, layer.biases, layer.n);
cuda_push_array(layer.filter_updates_gpu, layer.filter_updates, layer.c*layer.n*layer.size*layer.size);
cuda_push_array(layer.bias_updates_gpu, layer.bias_updates, layer.n);
+ if (layer.batch_normalize){
+ cuda_push_array(layer.scales_gpu, layer.scales, layer.n);
+ cuda_push_array(layer.rolling_mean_gpu, layer.rolling_mean, layer.n);
+ cuda_push_array(layer.rolling_variance_gpu, layer.rolling_variance, layer.n);
+ }
}
-extern "C" void update_convolutional_layer_gpu(convolutional_layer layer)
+void update_convolutional_layer_gpu(convolutional_layer layer, int batch, float learning_rate, float momentum, float decay)
{
int size = layer.size*layer.size*layer.c*layer.n;
-/*
- cuda_pull_array(layer.filter_updates_gpu, layer.filter_updates, size);
- cuda_pull_array(layer.filters_gpu, layer.filters, size);
- printf("Filter: %f updates: %f\n", mag_array(layer.filters, size), layer.learning_rate*mag_array(layer.filter_updates, size));
- */
+ axpy_ongpu(layer.n, learning_rate/batch, layer.bias_updates_gpu, 1, layer.biases_gpu, 1);
+ scal_ongpu(layer.n, momentum, layer.bias_updates_gpu, 1);
- axpy_ongpu(layer.n, layer.learning_rate, layer.bias_updates_gpu, 1, layer.biases_gpu, 1);
- scal_ongpu(layer.n,layer.momentum, layer.bias_updates_gpu, 1);
+ axpy_ongpu(layer.n, learning_rate/batch, layer.scale_updates_gpu, 1, layer.scales_gpu, 1);
+ scal_ongpu(layer.n, momentum, layer.scale_updates_gpu, 1);
- axpy_ongpu(size, -layer.decay, layer.filters_gpu, 1, layer.filter_updates_gpu, 1);
- axpy_ongpu(size, layer.learning_rate, layer.filter_updates_gpu, 1, layer.filters_gpu, 1);
- scal_ongpu(size, layer.momentum, layer.filter_updates_gpu, 1);
- //pull_convolutional_layer(layer);
+ axpy_ongpu(size, -decay*batch, layer.filters_gpu, 1, layer.filter_updates_gpu, 1);
+ axpy_ongpu(size, learning_rate/batch, layer.filter_updates_gpu, 1, layer.filters_gpu, 1);
+ scal_ongpu(size, momentum, layer.filter_updates_gpu, 1);
}
+
--
Gitblit v1.10.0