From d7d7da2653ff4f79a275529b0ac3fec438880083 Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Fri, 27 Mar 2015 02:13:59 +0000
Subject: [PATCH] Fixed im2col mistake >< face#palm
---
src/dropout_layer_kernels.cu | 5 +++++
src/col2im_kernels.cu | 6 +++---
src/detection.c | 9 +++++----
Makefile | 2 +-
src/network_kernels.cu | 1 +
src/convolutional_kernels.cu | 6 +++---
src/detection_layer.c | 18 ++++++++++++++++--
7 files changed, 34 insertions(+), 13 deletions(-)
diff --git a/Makefile b/Makefile
index bdbc73d..f474fce 100644
--- a/Makefile
+++ b/Makefile
@@ -1,6 +1,6 @@
GPU=1
DEBUG=0
-ARCH= -arch=sm_35
+ARCH= -arch=sm_50
VPATH=./src/
EXEC=darknet
diff --git a/src/col2im_kernels.cu b/src/col2im_kernels.cu
index 76a86e6..67c0b03 100644
--- a/src/col2im_kernels.cu
+++ b/src/col2im_kernels.cu
@@ -37,9 +37,9 @@
}
}
-void col2im_ongpu(float *im,
+void col2im_ongpu(float *data_col,
int channels, int height, int width,
- int ksize, int stride, int pad, float *data_col){
+ int ksize, int stride, int pad, float *data_im){
// We are going to launch channels * height_col * width_col kernels, each
// kernel responsible for copying a single-channel grid.
pad = pad ? ksize/2 : 0;
@@ -50,7 +50,7 @@
BLOCK>>>(
num_kernels, data_col, height, width, ksize, pad,
stride, height_col,
- width_col, im);
+ width_col, data_im);
}
/*
diff --git a/src/convolutional_kernels.cu b/src/convolutional_kernels.cu
index 5b49091..c28731f 100644
--- a/src/convolutional_kernels.cu
+++ b/src/convolutional_kernels.cu
@@ -26,7 +26,7 @@
check_error(cudaPeekAtLastError());
}
-__global__ void backward_bias_kernel(float *bias_updates, float *delta, int batch, int n, int size, float scale)
+__global__ void backward_bias_kernel(float *bias_updates, float *delta, int batch, int n, int size)
{
__shared__ float part[BLOCK];
int i,b;
@@ -42,13 +42,13 @@
part[p] = sum;
__syncthreads();
if(p == 0){
- for(i = 0; i < BLOCK; ++i) bias_updates[filter] += scale * part[i];
+ for(i = 0; i < BLOCK; ++i) bias_updates[filter] += part[i];
}
}
void backward_bias_gpu(float *bias_updates, float *delta, int batch, int n, int size)
{
- backward_bias_kernel<<<n, BLOCK>>>(bias_updates, delta, batch, n, size, 1);
+ backward_bias_kernel<<<n, BLOCK>>>(bias_updates, delta, batch, n, size);
check_error(cudaPeekAtLastError());
}
diff --git a/src/detection.c b/src/detection.c
index 1800ca6..69202aa 100644
--- a/src/detection.c
+++ b/src/detection.c
@@ -45,7 +45,7 @@
{
char *base = basecfg(cfgfile);
printf("%s\n", base);
- float avg_loss = 1;
+ float avg_loss = -1;
network net = parse_network_cfg(cfgfile);
if(weightfile){
load_weights(&net, weightfile);
@@ -84,6 +84,7 @@
time=clock();
float loss = train_network(net, train);
net.seen += imgs;
+ if (avg_loss < 0) avg_loss = loss;
avg_loss = avg_loss*.9 + loss*.1;
printf("%d: %f, %f avg, %lf seconds, %d images\n", i, loss, avg_loss, sec(clock()-time), i*imgs);
if(i%100==0){
@@ -109,8 +110,8 @@
char **paths = (char **)list_to_array(plist);
int im_size = 448;
int classes = 20;
- int background = 1;
- int nuisance = 0;
+ int background = 0;
+ int nuisance = 1;
int num_output = 7*7*(4+classes+background+nuisance);
int m = plist->size;
@@ -137,7 +138,7 @@
for(j = 0; j < pred.rows; ++j){
for(k = 0; k < pred.cols; k += classes+4+background+nuisance){
float scale = 1.;
- if(nuisance) scale = pred.vals[j][k];
+ if(nuisance) scale = 1.-pred.vals[j][k];
for(class = 0; class < classes; ++class){
int index = (k)/(classes+4+background+nuisance);
int r = index/7;
diff --git a/src/detection_layer.c b/src/detection_layer.c
index 27a4daf..73b2862 100644
--- a/src/detection_layer.c
+++ b/src/detection_layer.c
@@ -93,6 +93,19 @@
}
}
/*
+ int count = 0;
+ for(i = 0; i < layer.batch*locations; ++i){
+ for(j = 0; j < layer.classes+layer.background; ++j){
+ printf("%f, ", layer.output[count++]);
+ }
+ printf("\n");
+ for(j = 0; j < layer.coords; ++j){
+ printf("%f, ", layer.output[count++]);
+ }
+ printf("\n");
+ }
+ */
+ /*
if(layer.background || 1){
for(i = 0; i < layer.batch*locations; ++i){
int index = i*(layer.classes+layer.coords+layer.background);
@@ -123,8 +136,9 @@
state.delta[in_i++] = scale*layer.delta[out_i++];
}
- if (layer.nuisance) ;
- else if (layer.background) gradient_array(layer.output + out_i, layer.coords, LOGISTIC, layer.delta + out_i);
+ if (layer.nuisance) {
+
+ }else if (layer.background) gradient_array(layer.output + out_i, layer.coords, LOGISTIC, layer.delta + out_i);
for(j = 0; j < layer.coords; ++j){
state.delta[in_i++] = layer.delta[out_i++];
}
diff --git a/src/dropout_layer_kernels.cu b/src/dropout_layer_kernels.cu
index 2638ac5..231497d 100644
--- a/src/dropout_layer_kernels.cu
+++ b/src/dropout_layer_kernels.cu
@@ -16,6 +16,11 @@
if (!state.train) return;
int size = layer.inputs*layer.batch;
cuda_random(layer.rand_gpu, size);
+ int i;
+ for(i = 0; i < size; ++i){
+ layer.rand[i] = rand_uniform();
+ }
+ cuda_push_array(layer.rand_gpu, layer.rand, size);
yoloswag420blazeit360noscope<<<cuda_gridsize(size), BLOCK>>>(state.input, size, layer.rand_gpu, layer.probability, layer.scale);
check_error(cudaPeekAtLastError());
diff --git a/src/network_kernels.cu b/src/network_kernels.cu
index 019f40d..0b2bb97 100644
--- a/src/network_kernels.cu
+++ b/src/network_kernels.cu
@@ -71,6 +71,7 @@
state.input = get_network_output_gpu_layer(net, i-1);
state.delta = get_network_delta_gpu_layer(net, i-1);
}
+
if(net.types[i] == CONVOLUTIONAL){
backward_convolutional_layer_gpu(*(convolutional_layer *)net.layers[i], state);
}
--
Gitblit v1.10.0