From dcb000b553d051429a49c8729dc5b1af632e8532 Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Thu, 12 Mar 2015 05:20:15 +0000
Subject: [PATCH] refactoring and added DARK ZONE
---
src/dropout_layer_kernels.cu | 18 +++++++++---------
1 files changed, 9 insertions(+), 9 deletions(-)
diff --git a/src/dropout_layer_kernels.cu b/src/dropout_layer_kernels.cu
index 371f0dc..94f61ab 100644
--- a/src/dropout_layer_kernels.cu
+++ b/src/dropout_layer_kernels.cu
@@ -2,32 +2,32 @@
#include "dropout_layer.h"
#include "cuda.h"
#include "utils.h"
+#include "params.h"
}
-__global__ void yoloswag420blazeit360noscope(float *input, int size, float *rand, float prob, float scale, float *output)
+__global__ void yoloswag420blazeit360noscope(float *input, int size, float *rand, float prob, float scale)
{
int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
- if(id < size) output[id] = (rand[id] < prob) ? 0 : input[id]*scale;
+ if(id < size) input[id] = (rand[id] < prob) ? 0 : input[id]*scale;
}
-extern "C" void forward_dropout_layer_gpu(dropout_layer layer, float * input)
+extern "C" void forward_dropout_layer_gpu(dropout_layer layer, network_state state)
{
+ if (!state.train) return;
int j;
int size = layer.inputs*layer.batch;
for(j = 0; j < size; ++j) layer.rand[j] = rand_uniform();
cuda_push_array(layer.rand_gpu, layer.rand, layer.inputs*layer.batch);
- yoloswag420blazeit360noscope<<<cuda_gridsize(size), BLOCK>>>(input, size, layer.rand_gpu, layer.probability,
- layer.scale, layer.output_gpu);
+ yoloswag420blazeit360noscope<<<cuda_gridsize(size), BLOCK>>>(state.input, size, layer.rand_gpu, layer.probability, layer.scale);
check_error(cudaPeekAtLastError());
}
-extern "C" void backward_dropout_layer_gpu(dropout_layer layer, float *delta)
+extern "C" void backward_dropout_layer_gpu(dropout_layer layer, network_state state)
{
- if(!delta) return;
+ if(!state.delta) return;
int size = layer.inputs*layer.batch;
- yoloswag420blazeit360noscope<<<cuda_gridsize(size), BLOCK>>>(delta, size, layer.rand_gpu, layer.probability,
- layer.scale, delta);
+ yoloswag420blazeit360noscope<<<cuda_gridsize(size), BLOCK>>>(state.delta, size, layer.rand_gpu, layer.probability, layer.scale);
check_error(cudaPeekAtLastError());
}
--
Gitblit v1.10.0