From f199fd3b6464e644566d76676c0b5f1824d26c4e Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Fri, 17 Apr 2015 19:32:54 +0000
Subject: [PATCH] per image randomness in crop layer
---
src/image.c | 6 +-
src/crop_layer.c | 5 ++
src/detection.c | 14 +++---
Makefile | 2
src/crop_layer.h | 5 ++
src/parser.c | 4 +
src/crop_layer_kernels.cu | 79 ++++++++++++++++++++++++---------------
src/darknet.c | 2
8 files changed, 72 insertions(+), 45 deletions(-)
diff --git a/Makefile b/Makefile
index 0e185f7..63cb621 100644
--- a/Makefile
+++ b/Makefile
@@ -1,4 +1,4 @@
-GPU=0
+GPU=1
DEBUG=0
ARCH= -arch=sm_52
diff --git a/src/crop_layer.c b/src/crop_layer.c
index e83aea2..7ae4aa5 100644
--- a/src/crop_layer.c
+++ b/src/crop_layer.c
@@ -10,7 +10,7 @@
return float_to_image(w,h,c,layer.output);
}
-crop_layer *make_crop_layer(int batch, int h, int w, int c, int crop_height, int crop_width, int flip, float angle)
+crop_layer *make_crop_layer(int batch, int h, int w, int c, int crop_height, int crop_width, int flip, float angle, float saturation, float exposure)
{
fprintf(stderr, "Crop Layer: %d x %d -> %d x %d x %d image\n", h,w,crop_height,crop_width,c);
crop_layer *layer = calloc(1, sizeof(crop_layer));
@@ -20,11 +20,14 @@
layer->c = c;
layer->flip = flip;
layer->angle = angle;
+ layer->saturation = saturation;
+ layer->exposure = exposure;
layer->crop_width = crop_width;
layer->crop_height = crop_height;
layer->output = calloc(crop_width*crop_height * c*batch, sizeof(float));
#ifdef GPU
layer->output_gpu = cuda_make_array(layer->output, crop_width*crop_height*c*batch);
+ layer->rand_gpu = cuda_make_array(0, layer->batch*8);
#endif
return layer;
}
diff --git a/src/crop_layer.h b/src/crop_layer.h
index a320f0e..0033339 100644
--- a/src/crop_layer.h
+++ b/src/crop_layer.h
@@ -11,14 +11,17 @@
int crop_height;
int flip;
float angle;
+ float saturation;
+ float exposure;
float *output;
#ifdef GPU
float *output_gpu;
+ float *rand_gpu;
#endif
} crop_layer;
image get_crop_image(crop_layer layer);
-crop_layer *make_crop_layer(int batch, int h, int w, int c, int crop_height, int crop_width, int flip, float angle);
+crop_layer *make_crop_layer(int batch, int h, int w, int c, int crop_height, int crop_width, int flip, float angle, float saturation, float exposure);
void forward_crop_layer(const crop_layer layer, network_state state);
#ifdef GPU
diff --git a/src/crop_layer_kernels.cu b/src/crop_layer_kernels.cu
index 3e7ee95..98d1ef4 100644
--- a/src/crop_layer_kernels.cu
+++ b/src/crop_layer_kernels.cu
@@ -93,7 +93,7 @@
return val;
}
-__global__ void levels_image_kernel(float *image, int batch, int w, int h, float saturation, float exposure, float translate, float scale)
+__global__ void levels_image_kernel(float *image, float *rand, int batch, int w, int h, int train, float saturation, float exposure, float translate, float scale)
{
int size = batch * w * h;
int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
@@ -102,22 +102,34 @@
id /= w;
int y = id % h;
id /= h;
+ float r0 = rand[8*id + 0];
+ float r1 = rand[8*id + 1];
+ float r2 = rand[8*id + 2];
+ float r3 = rand[8*id + 3];
+
+ saturation = r0*(saturation - 1) + 1;
+ saturation = (r1 > .5) ? 1./saturation : saturation;
+ exposure = r2*(exposure - 1) + 1;
+ exposure = (r3 > .5) ? 1./exposure : exposure;
+
size_t offset = id * h * w * 3;
image += offset;
float r = image[x + w*(y + h*2)];
float g = image[x + w*(y + h*1)];
float b = image[x + w*(y + h*0)];
float3 rgb = make_float3(r,g,b);
- float3 hsv = rgb_to_hsv_kernel(rgb);
- hsv.y *= saturation;
- hsv.z *= exposure;
- rgb = hsv_to_rgb_kernel(hsv);
+ if(train){
+ float3 hsv = rgb_to_hsv_kernel(rgb);
+ hsv.y *= saturation;
+ hsv.z *= exposure;
+ rgb = hsv_to_rgb_kernel(hsv);
+ }
image[x + w*(y + h*2)] = rgb.x*scale + translate;
image[x + w*(y + h*1)] = rgb.y*scale + translate;
image[x + w*(y + h*0)] = rgb.z*scale + translate;
}
-__global__ void forward_crop_layer_kernel(float *input, int size, int c, int h, int w, int crop_height, int crop_width, int dh, int dw, int flip, float angle, float *output)
+__global__ void forward_crop_layer_kernel(float *input, float *rand, int size, int c, int h, int w, int crop_height, int crop_width, int train, int flip, float angle, float *output)
{
int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
if(id >= size) return;
@@ -134,10 +146,26 @@
id /= c;
int b = id;
+ float r4 = rand[8*b + 4];
+ float r5 = rand[8*b + 5];
+ float r6 = rand[8*b + 6];
+ float r7 = rand[8*b + 7];
+
+ float dw = (w - crop_width)*r4;
+ float dh = (h - crop_height)*r5;
+ flip = (flip && (r6 > .5));
+ angle = 2*angle*r7 - angle;
+ if(!train){
+ dw = (w - crop_width)/2.;
+ dh = (h - crop_height)/2.;
+ flip = 0;
+ angle = 0;
+ }
+
input += w*h*c*b;
- int x = (flip) ? w - dw - j - 1 : j + dw;
- int y = i + dh;
+ float x = (flip) ? w - dw - j - 1 : j + dw;
+ float y = i + dh;
float rx = cos(angle)*(x-cx) - sin(angle)*(y-cy) + cx;
float ry = sin(angle)*(x-cx) + cos(angle)*(y-cy) + cy;
@@ -147,38 +175,21 @@
extern "C" void forward_crop_layer_gpu(crop_layer layer, network_state state)
{
- int flip = (layer.flip && rand()%2);
- int dh = rand()%(layer.h - layer.crop_height + 1);
- int dw = rand()%(layer.w - layer.crop_width + 1);
- float radians = layer.angle*3.14159/180.;
- float angle = 2*radians*rand_uniform() - radians;
+ cuda_random(layer.rand_gpu, layer.batch*8);
- float saturation = rand_uniform() + 1;
- if(rand_uniform() > .5) saturation = 1./saturation;
- float exposure = rand_uniform() + 1;
- if(rand_uniform() > .5) exposure = 1./exposure;
+ float radians = layer.angle*3.14159/180.;
float scale = 2;
float translate = -1;
- if(!state.train){
- angle = 0;
- flip = 0;
- dh = (layer.h - layer.crop_height)/2;
- dw = (layer.w - layer.crop_width)/2;
- saturation = 1;
- exposure = 1;
- }
-
int size = layer.batch * layer.w * layer.h;
- levels_image_kernel<<<cuda_gridsize(size), BLOCK>>>(state.input, layer.batch, layer.w, layer.h, saturation, exposure, translate, scale);
+ levels_image_kernel<<<cuda_gridsize(size), BLOCK>>>(state.input, layer.rand_gpu, layer.batch, layer.w, layer.h, state.train, layer.saturation, layer.exposure, translate, scale);
check_error(cudaPeekAtLastError());
-
+
size = layer.batch*layer.c*layer.crop_width*layer.crop_height;
- forward_crop_layer_kernel<<<cuda_gridsize(size), BLOCK>>>(state.input, size, layer.c, layer.h, layer.w,
- layer.crop_height, layer.crop_width, dh, dw, flip, angle, layer.output_gpu);
+ forward_crop_layer_kernel<<<cuda_gridsize(size), BLOCK>>>(state.input, layer.rand_gpu, size, layer.c, layer.h, layer.w, layer.crop_height, layer.crop_width, state.train, layer.flip, radians, layer.output_gpu);
check_error(cudaPeekAtLastError());
/*
@@ -186,6 +197,14 @@
image im = float_to_image(layer.crop_width, layer.crop_height, layer.c, layer.output + 0*(size/layer.batch));
image im2 = float_to_image(layer.crop_width, layer.crop_height, layer.c, layer.output + 1*(size/layer.batch));
image im3 = float_to_image(layer.crop_width, layer.crop_height, layer.c, layer.output + 2*(size/layer.batch));
+
+ translate_image(im, -translate);
+ scale_image(im, 1/scale);
+ translate_image(im2, -translate);
+ scale_image(im2, 1/scale);
+ translate_image(im3, -translate);
+ scale_image(im3, 1/scale);
+
show_image(im, "cropped");
show_image(im2, "cropped2");
show_image(im3, "cropped3");
diff --git a/src/darknet.c b/src/darknet.c
index cca5473..46a8c82 100644
--- a/src/darknet.c
+++ b/src/darknet.c
@@ -68,7 +68,7 @@
if(weightfile){
load_weights_upto(&net, weightfile, max);
}
- //net.seen = 0;
+ net.seen = 0;
save_weights(net, outfile);
}
diff --git a/src/detection.c b/src/detection.c
index 024c0e9..cba3d18 100644
--- a/src/detection.c
+++ b/src/detection.c
@@ -82,6 +82,8 @@
plist = get_paths("/home/pjreddie/data/imagenet/det.train.list");
}else{
plist = get_paths("/home/pjreddie/data/voc/trainall.txt");
+ //plist = get_paths("/home/pjreddie/data/coco/trainval.txt");
+ //plist = get_paths("/home/pjreddie/data/voc/all2007-2012.txt");
}
paths = (char **)list_to_array(plist);
pthread_t load_thread = load_data_detection_thread(imgs, paths, plist->size, classes, net.w, net.h, side, side, background, &buffer);
@@ -94,13 +96,11 @@
load_thread = load_data_detection_thread(imgs, paths, plist->size, classes, net.w, net.h, side, side, background, &buffer);
/*
- image im = float_to_image(net.w, net.h, 3, train.X.vals[114]);
- image copy = copy_image(im);
- translate_image(copy, 1);
- scale_image(copy, .5);
- draw_detection(copy, train.y.vals[114], 7);
- free_image(copy);
- */
+ image im = float_to_image(net.w, net.h, 3, train.X.vals[114]);
+ image copy = copy_image(im);
+ draw_detection(copy, train.y.vals[114], 7);
+ free_image(copy);
+ */
printf("Loaded: %lf seconds\n", sec(clock()-time));
time=clock();
diff --git a/src/image.c b/src/image.c
index bf6ce6a..1daea27 100644
--- a/src/image.c
+++ b/src/image.c
@@ -182,8 +182,8 @@
}
}
free_image(copy);
- if(disp->height < 500 || disp->width < 500 || disp->height > 1000){
- int w = 500;
+ if(disp->height < 448 || disp->width < 448 || disp->height > 1000){
+ int w = 448;
int h = w*p.h/p.w;
if(h > 1000){
h = 1000;
@@ -191,7 +191,7 @@
}
IplImage *buffer = disp;
disp = cvCreateImage(cvSize(w, h), buffer->depth, buffer->nChannels);
- cvResize(buffer, disp, CV_INTER_NN);
+ cvResize(buffer, disp, CV_INTER_LINEAR);
cvReleaseImage(&buffer);
}
cvShowImage(buff, disp);
diff --git a/src/parser.c b/src/parser.c
index 08e0ea1..ca60ef7 100644
--- a/src/parser.c
+++ b/src/parser.c
@@ -187,6 +187,8 @@
int crop_width = option_find_int(options, "crop_width",1);
int flip = option_find_int(options, "flip",0);
float angle = option_find_float(options, "angle",0);
+ float saturation = option_find_float(options, "saturation",1);
+ float exposure = option_find_float(options, "exposure",1);
int batch,h,w,c;
h = params.h;
@@ -195,7 +197,7 @@
batch=params.batch;
if(!(h && w && c)) error("Layer before crop layer must output image.");
- crop_layer *layer = make_crop_layer(batch,h,w,c,crop_height,crop_width,flip, angle);
+ crop_layer *layer = make_crop_layer(batch,h,w,c,crop_height,crop_width,flip, angle, saturation, exposure);
option_unused(options);
return layer;
}
--
Gitblit v1.10.0