From 989ab8c38a02fa7ea9c25108151736c62e81c972 Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Fri, 24 Apr 2015 17:27:50 +0000
Subject: [PATCH] IOU loss function

---
 src/convolutional_kernels.cu |   55 ++++++++++++++++++++++---------------------------------
 1 files changed, 22 insertions(+), 33 deletions(-)

diff --git a/src/convolutional_kernels.cu b/src/convolutional_kernels.cu
index bcf307f..d260a95 100644
--- a/src/convolutional_kernels.cu
+++ b/src/convolutional_kernels.cu
@@ -17,16 +17,16 @@
     if(offset < size) output[(batch*n+filter)*size + offset] = biases[filter];
 }
 
-extern "C" void bias_output_gpu(float *output, float *biases, int batch, int n, int size)
+void bias_output_gpu(float *output, float *biases, int batch, int n, int size)
 {
-    dim3 dimBlock(BLOCK, 1, 1);
     dim3 dimGrid((size-1)/BLOCK + 1, n, batch);
+    dim3 dimBlock(BLOCK, 1, 1);
 
     bias_output_kernel<<<dimGrid, dimBlock>>>(output, biases, n, size);
     check_error(cudaPeekAtLastError());
 }
 
-__global__ void backward_bias_kernel(float *bias_updates, float *delta, int batch, int n, int size, float scale)
+__global__ void backward_bias_kernel(float *bias_updates, float *delta, int batch, int n, int size)
 {
     __shared__ float part[BLOCK];
     int i,b;
@@ -42,19 +42,17 @@
     part[p] = sum;
     __syncthreads();
     if(p == 0){
-        for(i = 0; i < BLOCK; ++i) bias_updates[filter] += scale * part[i];
+        for(i = 0; i < BLOCK; ++i) bias_updates[filter] += part[i];
     }
 }
 
-extern "C" void backward_bias_gpu(float *bias_updates, float *delta, int batch, int n, int size)
+void backward_bias_gpu(float *bias_updates, float *delta, int batch, int n, int size)
 {
-    float alpha = 1./batch;
-
-    backward_bias_kernel<<<n, BLOCK>>>(bias_updates, delta, batch, n, size, alpha);
+    backward_bias_kernel<<<n, BLOCK>>>(bias_updates, delta, batch, n, size);
     check_error(cudaPeekAtLastError());
 }
 
-extern "C" void forward_convolutional_layer_gpu(convolutional_layer layer, float *in)
+void forward_convolutional_layer_gpu(convolutional_layer layer, network_state state)
 {
     int i;
     int m = layer.n;
@@ -63,9 +61,8 @@
         convolutional_out_width(layer);
 
     bias_output_gpu(layer.output_gpu, layer.biases_gpu, layer.batch, layer.n, n);
-
     for(i = 0; i < layer.batch; ++i){
-        im2col_ongpu(in + i*layer.c*layer.h*layer.w, layer.c,  layer.h,  layer.w,  layer.size,  layer.stride, layer.pad, layer.col_image_gpu);
+        im2col_ongpu(state.input + i*layer.c*layer.h*layer.w, layer.c,  layer.h,  layer.w,  layer.size,  layer.stride, layer.pad, layer.col_image_gpu);
         float * a = layer.filters_gpu;
         float * b = layer.col_image_gpu;
         float * c = layer.output_gpu;
@@ -74,9 +71,8 @@
     activate_array_ongpu(layer.output_gpu, m*n*layer.batch, layer.activation);
 }
 
-extern "C" void backward_convolutional_layer_gpu(convolutional_layer layer, float *in, float *delta_gpu)
+void backward_convolutional_layer_gpu(convolutional_layer layer, network_state state)
 {
-    float alpha = 1./layer.batch;
     int i;
     int m = layer.n;
     int n = layer.size*layer.size*layer.c;
@@ -86,17 +82,17 @@
     gradient_array_ongpu(layer.output_gpu, m*k*layer.batch, layer.activation, layer.delta_gpu);
     backward_bias_gpu(layer.bias_updates_gpu, layer.delta_gpu, layer.batch, layer.n, k);
 
-    if(delta_gpu) scal_ongpu(layer.batch*layer.h*layer.w*layer.c, 0, delta_gpu, 1);
+    if(state.delta) scal_ongpu(layer.batch*layer.h*layer.w*layer.c, 0, state.delta, 1);
 
     for(i = 0; i < layer.batch; ++i){
         float * a = layer.delta_gpu;
         float * b = layer.col_image_gpu;
         float * c = layer.filter_updates_gpu;
 
-        im2col_ongpu(in + i*layer.c*layer.h*layer.w, layer.c,  layer.h,  layer.w,  layer.size,  layer.stride, layer.pad, layer.col_image_gpu);
-        gemm_ongpu(0,1,m,n,k,alpha,a + i*m*k,k,b,k,1,c,n);
+        im2col_ongpu(state.input + i*layer.c*layer.h*layer.w, layer.c,  layer.h,  layer.w,  layer.size,  layer.stride, layer.pad, layer.col_image_gpu);
+        gemm_ongpu(0,1,m,n,k,1,a + i*m*k,k,b,k,1,c,n);
 
-        if(delta_gpu){
+        if(state.delta){
 
             float * a = layer.filters_gpu;
             float * b = layer.delta_gpu;
@@ -104,12 +100,12 @@
 
             gemm_ongpu(1,0,n,k,m,1,a,n,b + i*k*m,k,0,c,k);
 
-            col2im_ongpu(layer.col_image_gpu, layer.c,  layer.h,  layer.w,  layer.size,  layer.stride, layer.pad, delta_gpu + i*layer.c*layer.h*layer.w);
+            col2im_ongpu(layer.col_image_gpu, layer.c,  layer.h,  layer.w,  layer.size,  layer.stride, layer.pad, state.delta + i*layer.c*layer.h*layer.w);
         }
     }
 }
 
-extern "C" void pull_convolutional_layer(convolutional_layer layer)
+void pull_convolutional_layer(convolutional_layer layer)
 {
     cuda_pull_array(layer.filters_gpu, layer.filters, layer.c*layer.n*layer.size*layer.size);
     cuda_pull_array(layer.biases_gpu, layer.biases, layer.n);
@@ -117,7 +113,7 @@
     cuda_pull_array(layer.bias_updates_gpu, layer.bias_updates, layer.n);
 }
 
-extern "C" void push_convolutional_layer(convolutional_layer layer)
+void push_convolutional_layer(convolutional_layer layer)
 {
     cuda_push_array(layer.filters_gpu, layer.filters, layer.c*layer.n*layer.size*layer.size);
     cuda_push_array(layer.biases_gpu, layer.biases, layer.n);
@@ -125,22 +121,15 @@
     cuda_push_array(layer.bias_updates_gpu, layer.bias_updates, layer.n);
 }
 
-extern "C" void update_convolutional_layer_gpu(convolutional_layer layer)
+void update_convolutional_layer_gpu(convolutional_layer layer, int batch, float learning_rate, float momentum, float decay)
 {
     int size = layer.size*layer.size*layer.c*layer.n;
 
-/*
-    cuda_pull_array(layer.filter_updates_gpu, layer.filter_updates, size);
-    cuda_pull_array(layer.filters_gpu, layer.filters, size);
-    printf("Filter: %f updates: %f\n", mag_array(layer.filters, size), layer.learning_rate*mag_array(layer.filter_updates, size));
-    */
+    axpy_ongpu(layer.n, learning_rate/batch, layer.bias_updates_gpu, 1, layer.biases_gpu, 1);
+    scal_ongpu(layer.n, momentum, layer.bias_updates_gpu, 1);
 
-    axpy_ongpu(layer.n, layer.learning_rate, layer.bias_updates_gpu, 1, layer.biases_gpu, 1);
-    scal_ongpu(layer.n,layer.momentum, layer.bias_updates_gpu, 1);
-
-    axpy_ongpu(size, -layer.decay, layer.filters_gpu, 1, layer.filter_updates_gpu, 1);
-    axpy_ongpu(size, layer.learning_rate, layer.filter_updates_gpu, 1, layer.filters_gpu, 1);
-    scal_ongpu(size, layer.momentum, layer.filter_updates_gpu, 1);
-    //pull_convolutional_layer(layer);
+    axpy_ongpu(size, -decay*batch, layer.filters_gpu, 1, layer.filter_updates_gpu, 1);
+    axpy_ongpu(size, learning_rate/batch, layer.filter_updates_gpu, 1, layer.filters_gpu, 1);
+    scal_ongpu(size, momentum, layer.filter_updates_gpu, 1);
 }
 

--
Gitblit v1.10.0