From c7b10ceadb1a78e7480d281444a31ae2a7dc1b05 Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Fri, 06 May 2016 23:25:16 +0000
Subject: [PATCH] so much need to commit
---
src/convolutional_kernels.cu | 209 +++++++++++++---------------------------------------
1 files changed, 53 insertions(+), 156 deletions(-)
diff --git a/src/convolutional_kernels.cu b/src/convolutional_kernels.cu
index 3dc125f..62d6079 100644
--- a/src/convolutional_kernels.cu
+++ b/src/convolutional_kernels.cu
@@ -4,6 +4,7 @@
extern "C" {
#include "convolutional_layer.h"
+#include "batchnorm_layer.h"
#include "gemm.h"
#include "blas.h"
#include "im2col.h"
@@ -12,6 +13,41 @@
#include "cuda.h"
}
+__global__ void binarize_kernel(float *x, int n, float *binary)
+{
+ int i = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+ if (i >= n) return;
+ binary[i] = (x[i] > 0) ? 1 : -1;
+}
+
+void binarize_gpu(float *x, int n, float *binary)
+{
+ binarize_kernel<<<cuda_gridsize(n), BLOCK>>>(x, n, binary);
+ check_error(cudaPeekAtLastError());
+}
+
+__global__ void binarize_input_kernel(float *input, int n, int size, float *binary)
+{
+ int s = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
+ if (s >= size) return;
+ int i = 0;
+ float mean = 0;
+ for(i = 0; i < n; ++i){
+ mean += abs(input[i*size + s]);
+ }
+ mean = mean / n;
+ for(i = 0; i < n; ++i){
+ binary[i*size + s] = (input[i*size + s] > 0) ? mean : -mean;
+ }
+}
+
+void binarize_input_gpu(float *input, int n, int size, float *binary)
+{
+ binarize_input_kernel<<<cuda_gridsize(size), BLOCK>>>(input, n, size, binary);
+ check_error(cudaPeekAtLastError());
+}
+
+
__global__ void binarize_filters_kernel(float *filters, int n, int size, float *binary)
{
int f = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
@@ -27,140 +63,12 @@
}
}
-__global__ void scale_bias_kernel(float *output, float *biases, int n, int size)
-{
- int offset = blockIdx.x * blockDim.x + threadIdx.x;
- int filter = blockIdx.y;
- int batch = blockIdx.z;
-
- if(offset < size) output[(batch*n+filter)*size + offset] *= biases[filter];
-}
-
-void scale_bias_gpu(float *output, float *biases, int batch, int n, int size)
-{
- dim3 dimGrid((size-1)/BLOCK + 1, n, batch);
- dim3 dimBlock(BLOCK, 1, 1);
-
- scale_bias_kernel<<<dimGrid, dimBlock>>>(output, biases, n, size);
- check_error(cudaPeekAtLastError());
-}
-
-__global__ void backward_scale_kernel(float *x_norm, float *delta, int batch, int n, int size, float *scale_updates)
-{
- __shared__ float part[BLOCK];
- int i,b;
- int filter = blockIdx.x;
- int p = threadIdx.x;
- float sum = 0;
- for(b = 0; b < batch; ++b){
- for(i = 0; i < size; i += BLOCK){
- int index = p + i + size*(filter + n*b);
- sum += (p+i < size) ? delta[index]*x_norm[index] : 0;
- }
- }
- part[p] = sum;
- __syncthreads();
- if (p == 0) {
- for(i = 0; i < BLOCK; ++i) scale_updates[filter] += part[i];
- }
-}
-
void binarize_filters_gpu(float *filters, int n, int size, float *binary)
{
binarize_filters_kernel<<<cuda_gridsize(n), BLOCK>>>(filters, n, size, binary);
check_error(cudaPeekAtLastError());
}
-void backward_scale_gpu(float *x_norm, float *delta, int batch, int n, int size, float *scale_updates)
-{
- backward_scale_kernel<<<n, BLOCK>>>(x_norm, delta, batch, n, size, scale_updates);
- check_error(cudaPeekAtLastError());
-}
-
-__global__ void add_bias_kernel(float *output, float *biases, int n, int size)
-{
- int offset = blockIdx.x * blockDim.x + threadIdx.x;
- int filter = blockIdx.y;
- int batch = blockIdx.z;
-
- if(offset < size) output[(batch*n+filter)*size + offset] += biases[filter];
-}
-
-void add_bias_gpu(float *output, float *biases, int batch, int n, int size)
-{
- dim3 dimGrid((size-1)/BLOCK + 1, n, batch);
- dim3 dimBlock(BLOCK, 1, 1);
-
- add_bias_kernel<<<dimGrid, dimBlock>>>(output, biases, n, size);
- check_error(cudaPeekAtLastError());
-}
-
-__global__ void backward_bias_kernel(float *bias_updates, float *delta, int batch, int n, int size)
-{
- __shared__ float part[BLOCK];
- int i,b;
- int filter = blockIdx.x;
- int p = threadIdx.x;
- float sum = 0;
- for(b = 0; b < batch; ++b){
- for(i = 0; i < size; i += BLOCK){
- int index = p + i + size*(filter + n*b);
- sum += (p+i < size) ? delta[index] : 0;
- }
- }
- part[p] = sum;
- __syncthreads();
- if (p == 0) {
- for(i = 0; i < BLOCK; ++i) bias_updates[filter] += part[i];
- }
-}
-
-__global__ void dot_kernel(float *output, float scale, int batch, int n, int size, float *delta)
-{
- int index = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
- int f1 = index / n;
- int f2 = index % n;
- if (f2 <= f1) return;
-
- float sum = 0;
- float norm1 = 0;
- float norm2 = 0;
- int b, i;
- for(b = 0; b < batch; ++b){
- for(i = 0; i < size; ++i){
- int i1 = b * size * n + f1 * size + i;
- int i2 = b * size * n + f2 * size + i;
- sum += output[i1] * output[i2];
- norm1 += output[i1] * output[i1];
- norm2 += output[i2] * output[i2];
- }
- }
- norm1 = sqrt(norm1);
- norm2 = sqrt(norm2);
- float norm = norm1 * norm2;
- sum = sum / norm;
- for(b = 0; b < batch; ++b){
- for(i = 0; i < size; ++i){
- int i1 = b * size * n + f1 * size + i;
- int i2 = b * size * n + f2 * size + i;
- delta[i1] += - scale * sum * output[i2] / norm;
- delta[i2] += - scale * sum * output[i1] / norm;
- }
- }
-}
-
-void dot_error_gpu(layer l)
-{
- dot_kernel<<<cuda_gridsize(l.n*l.n), BLOCK>>>(l.output_gpu, l.dot, l.batch, l.n, l.out_w * l.out_h, l.delta_gpu);
- check_error(cudaPeekAtLastError());
-}
-
-void backward_bias_gpu(float *bias_updates, float *delta, int batch, int n, int size)
-{
- backward_bias_kernel<<<n, BLOCK>>>(bias_updates, delta, batch, n, size);
- check_error(cudaPeekAtLastError());
-}
-
void forward_convolutional_layer_gpu(convolutional_layer l, network_state state)
{
int i;
@@ -175,6 +83,16 @@
swap_binary(&l);
}
+ if(l.xnor){
+ binarize_filters_gpu(l.filters_gpu, l.n, l.c*l.size*l.size, l.binary_filters_gpu);
+ //binarize_gpu(l.filters_gpu, l.n*l.c*l.size*l.size, l.binary_filters_gpu);
+ swap_binary(&l);
+ for(i = 0; i < l.batch; ++i){
+ binarize_input_gpu(state.input + i*l.inputs, l.c, l.h*l.w, l.binary_input_gpu + i*l.inputs);
+ }
+ state.input = l.binary_input_gpu;
+ }
+
for(i = 0; i < l.batch; ++i){
im2col_ongpu(state.input + i*l.c*l.h*l.w, l.c, l.h, l.w, l.size, l.stride, l.pad, l.col_image_gpu);
float * a = l.filters_gpu;
@@ -184,29 +102,13 @@
}
if (l.batch_normalize) {
- if (state.train) {
- fast_mean_gpu(l.output_gpu, l.batch, l.n, l.out_h*l.out_w, l.mean_gpu);
- fast_variance_gpu(l.output_gpu, l.mean_gpu, l.batch, l.n, l.out_h*l.out_w, l.variance_gpu);
-
- scal_ongpu(l.n, .95, l.rolling_mean_gpu, 1);
- axpy_ongpu(l.n, .05, l.mean_gpu, 1, l.rolling_mean_gpu, 1);
- scal_ongpu(l.n, .95, l.rolling_variance_gpu, 1);
- axpy_ongpu(l.n, .05, l.variance_gpu, 1, l.rolling_variance_gpu, 1);
-
- copy_ongpu(l.outputs*l.batch, l.output_gpu, 1, l.x_gpu, 1);
- normalize_gpu(l.output_gpu, l.mean_gpu, l.variance_gpu, l.batch, l.n, l.out_h*l.out_w);
- copy_ongpu(l.outputs*l.batch, l.output_gpu, 1, l.x_norm_gpu, 1);
- } else {
- normalize_gpu(l.output_gpu, l.rolling_mean_gpu, l.rolling_variance_gpu, l.batch, l.n, l.out_h*l.out_w);
- }
-
- scale_bias_gpu(l.output_gpu, l.scales_gpu, l.batch, l.n, l.out_h*l.out_w);
+ forward_batchnorm_layer_gpu(l, state);
}
add_bias_gpu(l.output_gpu, l.biases_gpu, l.batch, l.n, n);
activate_array_ongpu(l.output_gpu, m*n*l.batch, l.activation);
- if(l.dot > 0) dot_error_gpu(l);
- if(l.binary) swap_binary(&l);
+ //if(l.dot > 0) dot_error_gpu(l);
+ if(l.binary || l.xnor) swap_binary(&l);
}
void backward_convolutional_layer_gpu(convolutional_layer l, network_state state)
@@ -222,15 +124,10 @@
backward_bias_gpu(l.bias_updates_gpu, l.delta_gpu, l.batch, l.n, k);
if(l.batch_normalize){
- backward_scale_gpu(l.x_norm_gpu, l.delta_gpu, l.batch, l.n, l.out_w*l.out_h, l.scale_updates_gpu);
-
- scale_bias_gpu(l.delta_gpu, l.scales_gpu, l.batch, l.n, l.out_h*l.out_w);
-
- fast_mean_delta_gpu(l.delta_gpu, l.variance_gpu, l.batch, l.n, l.out_w*l.out_h, l.mean_delta_gpu);
- fast_variance_delta_gpu(l.x_gpu, l.delta_gpu, l.mean_gpu, l.variance_gpu, l.batch, l.n, l.out_w*l.out_h, l.variance_delta_gpu);
- normalize_delta_gpu(l.x_gpu, l.mean_gpu, l.variance_gpu, l.mean_delta_gpu, l.variance_delta_gpu, l.batch, l.n, l.out_w*l.out_h, l.delta_gpu);
+ backward_batchnorm_layer_gpu(l, state);
}
+ if(l.xnor) state.input = l.binary_input_gpu;
for(i = 0; i < l.batch; ++i){
float * a = l.delta_gpu;
float * b = l.col_image_gpu;
@@ -240,7 +137,7 @@
gemm_ongpu(0,1,m,n,k,1,a + i*m*k,k,b,k,1,c,n);
if(state.delta){
- if(l.binary) swap_binary(&l);
+ if(l.binary || l.xnor) swap_binary(&l);
float * a = l.filters_gpu;
float * b = l.delta_gpu;
float * c = l.col_image_gpu;
@@ -248,7 +145,7 @@
gemm_ongpu(1,0,n,k,m,1,a,n,b + i*k*m,k,0,c,k);
col2im_ongpu(l.col_image_gpu, l.c, l.h, l.w, l.size, l.stride, l.pad, state.delta + i*l.c*l.h*l.w);
- if(l.binary) swap_binary(&l);
+ if(l.binary || l.xnor) swap_binary(&l);
}
}
}
--
Gitblit v1.10.0