From b8e6e80c6d411d05a9e09f1e3676eb9a7f3ea0e8 Mon Sep 17 00:00:00 2001
From: AlexeyAB <alexeyab84@gmail.com>
Date: Fri, 03 Aug 2018 11:35:03 +0000
Subject: [PATCH] Added spatial Yolo v3 yolov3-spp.cfg

---
 src/cuda.c                           |    8 
 build/darknet/x64/cfg/yolov3-spp.cfg |  822 +++++++++++++++++++++++++++++
 cfg/yolov3-spp.cfg                   |  822 +++++++++++++++++++++++++++++
 src/darknet.c                        |    4 
 src/maxpool_layer.c                  |    8 
 src/maxpool_layer_kernels.cu         |    8 
 6 files changed, 1,658 insertions(+), 14 deletions(-)

diff --git a/build/darknet/x64/cfg/yolov3-spp.cfg b/build/darknet/x64/cfg/yolov3-spp.cfg
new file mode 100644
index 0000000..4ad2a05
--- /dev/null
+++ b/build/darknet/x64/cfg/yolov3-spp.cfg
@@ -0,0 +1,822 @@
+[net]
+# Testing
+batch=1
+subdivisions=1
+# Training
+# batch=64
+# subdivisions=16
+width=608
+height=608
+channels=3
+momentum=0.9
+decay=0.0005
+angle=0
+saturation = 1.5
+exposure = 1.5
+hue=.1
+
+learning_rate=0.001
+burn_in=1000
+max_batches = 500200
+policy=steps
+steps=400000,450000
+scales=.1,.1
+
+[convolutional]
+batch_normalize=1
+filters=32
+size=3
+stride=1
+pad=1
+activation=leaky
+
+# Downsample
+
+[convolutional]
+batch_normalize=1
+filters=64
+size=3
+stride=2
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=32
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=64
+size=3
+stride=1
+pad=1
+activation=leaky
+
+[shortcut]
+from=-3
+activation=linear
+
+# Downsample
+
+[convolutional]
+batch_normalize=1
+filters=128
+size=3
+stride=2
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=64
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=128
+size=3
+stride=1
+pad=1
+activation=leaky
+
+[shortcut]
+from=-3
+activation=linear
+
+[convolutional]
+batch_normalize=1
+filters=64
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=128
+size=3
+stride=1
+pad=1
+activation=leaky
+
+[shortcut]
+from=-3
+activation=linear
+
+# Downsample
+
+[convolutional]
+batch_normalize=1
+filters=256
+size=3
+stride=2
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=128
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=256
+size=3
+stride=1
+pad=1
+activation=leaky
+
+[shortcut]
+from=-3
+activation=linear
+
+[convolutional]
+batch_normalize=1
+filters=128
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=256
+size=3
+stride=1
+pad=1
+activation=leaky
+
+[shortcut]
+from=-3
+activation=linear
+
+[convolutional]
+batch_normalize=1
+filters=128
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=256
+size=3
+stride=1
+pad=1
+activation=leaky
+
+[shortcut]
+from=-3
+activation=linear
+
+[convolutional]
+batch_normalize=1
+filters=128
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=256
+size=3
+stride=1
+pad=1
+activation=leaky
+
+[shortcut]
+from=-3
+activation=linear
+
+
+[convolutional]
+batch_normalize=1
+filters=128
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=256
+size=3
+stride=1
+pad=1
+activation=leaky
+
+[shortcut]
+from=-3
+activation=linear
+
+[convolutional]
+batch_normalize=1
+filters=128
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=256
+size=3
+stride=1
+pad=1
+activation=leaky
+
+[shortcut]
+from=-3
+activation=linear
+
+[convolutional]
+batch_normalize=1
+filters=128
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=256
+size=3
+stride=1
+pad=1
+activation=leaky
+
+[shortcut]
+from=-3
+activation=linear
+
+[convolutional]
+batch_normalize=1
+filters=128
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=256
+size=3
+stride=1
+pad=1
+activation=leaky
+
+[shortcut]
+from=-3
+activation=linear
+
+# Downsample
+
+[convolutional]
+batch_normalize=1
+filters=512
+size=3
+stride=2
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=256
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=512
+size=3
+stride=1
+pad=1
+activation=leaky
+
+[shortcut]
+from=-3
+activation=linear
+
+
+[convolutional]
+batch_normalize=1
+filters=256
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=512
+size=3
+stride=1
+pad=1
+activation=leaky
+
+[shortcut]
+from=-3
+activation=linear
+
+
+[convolutional]
+batch_normalize=1
+filters=256
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=512
+size=3
+stride=1
+pad=1
+activation=leaky
+
+[shortcut]
+from=-3
+activation=linear
+
+
+[convolutional]
+batch_normalize=1
+filters=256
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=512
+size=3
+stride=1
+pad=1
+activation=leaky
+
+[shortcut]
+from=-3
+activation=linear
+
+[convolutional]
+batch_normalize=1
+filters=256
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=512
+size=3
+stride=1
+pad=1
+activation=leaky
+
+[shortcut]
+from=-3
+activation=linear
+
+
+[convolutional]
+batch_normalize=1
+filters=256
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=512
+size=3
+stride=1
+pad=1
+activation=leaky
+
+[shortcut]
+from=-3
+activation=linear
+
+
+[convolutional]
+batch_normalize=1
+filters=256
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=512
+size=3
+stride=1
+pad=1
+activation=leaky
+
+[shortcut]
+from=-3
+activation=linear
+
+[convolutional]
+batch_normalize=1
+filters=256
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=512
+size=3
+stride=1
+pad=1
+activation=leaky
+
+[shortcut]
+from=-3
+activation=linear
+
+# Downsample
+
+[convolutional]
+batch_normalize=1
+filters=1024
+size=3
+stride=2
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=512
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=1024
+size=3
+stride=1
+pad=1
+activation=leaky
+
+[shortcut]
+from=-3
+activation=linear
+
+[convolutional]
+batch_normalize=1
+filters=512
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=1024
+size=3
+stride=1
+pad=1
+activation=leaky
+
+[shortcut]
+from=-3
+activation=linear
+
+[convolutional]
+batch_normalize=1
+filters=512
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=1024
+size=3
+stride=1
+pad=1
+activation=leaky
+
+[shortcut]
+from=-3
+activation=linear
+
+[convolutional]
+batch_normalize=1
+filters=512
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=1024
+size=3
+stride=1
+pad=1
+activation=leaky
+
+[shortcut]
+from=-3
+activation=linear
+
+######################
+
+[convolutional]
+batch_normalize=1
+filters=512
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+size=3
+stride=1
+pad=1
+filters=1024
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=512
+size=1
+stride=1
+pad=1
+activation=leaky
+
+### SPP ###
+[maxpool]
+stride=1
+size=5
+
+[route]
+layers=-2
+
+[maxpool]
+stride=1
+size=9
+
+[route]
+layers=-4
+
+[maxpool]
+stride=1
+size=13
+
+[route]
+layers=-1,-3,-5,-6
+
+### End SPP ###
+
+[convolutional]
+batch_normalize=1
+filters=512
+size=1
+stride=1
+pad=1
+activation=leaky
+
+
+[convolutional]
+batch_normalize=1
+size=3
+stride=1
+pad=1
+filters=1024
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=512
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+size=3
+stride=1
+pad=1
+filters=1024
+activation=leaky
+
+[convolutional]
+size=1
+stride=1
+pad=1
+filters=255
+activation=linear
+
+
+[yolo]
+mask = 6,7,8
+anchors = 10,13,  16,30,  33,23,  30,61,  62,45,  59,119,  116,90,  156,198,  373,326
+classes=80
+num=9
+jitter=.3
+ignore_thresh = .7
+truth_thresh = 1
+random=1
+
+
+[route]
+layers = -4
+
+[convolutional]
+batch_normalize=1
+filters=256
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[upsample]
+stride=2
+
+[route]
+layers = -1, 61
+
+
+
+[convolutional]
+batch_normalize=1
+filters=256
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+size=3
+stride=1
+pad=1
+filters=512
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=256
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+size=3
+stride=1
+pad=1
+filters=512
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=256
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+size=3
+stride=1
+pad=1
+filters=512
+activation=leaky
+
+[convolutional]
+size=1
+stride=1
+pad=1
+filters=255
+activation=linear
+
+
+[yolo]
+mask = 3,4,5
+anchors = 10,13,  16,30,  33,23,  30,61,  62,45,  59,119,  116,90,  156,198,  373,326
+classes=80
+num=9
+jitter=.3
+ignore_thresh = .7
+truth_thresh = 1
+random=1
+
+
+
+[route]
+layers = -4
+
+[convolutional]
+batch_normalize=1
+filters=128
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[upsample]
+stride=2
+
+[route]
+layers = -1, 36
+
+
+
+[convolutional]
+batch_normalize=1
+filters=128
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+size=3
+stride=1
+pad=1
+filters=256
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=128
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+size=3
+stride=1
+pad=1
+filters=256
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=128
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+size=3
+stride=1
+pad=1
+filters=256
+activation=leaky
+
+[convolutional]
+size=1
+stride=1
+pad=1
+filters=255
+activation=linear
+
+
+[yolo]
+mask = 0,1,2
+anchors = 10,13,  16,30,  33,23,  30,61,  62,45,  59,119,  116,90,  156,198,  373,326
+classes=80
+num=9
+jitter=.3
+ignore_thresh = .7
+truth_thresh = 1
+random=1
+
diff --git a/cfg/yolov3-spp.cfg b/cfg/yolov3-spp.cfg
new file mode 100644
index 0000000..4ad2a05
--- /dev/null
+++ b/cfg/yolov3-spp.cfg
@@ -0,0 +1,822 @@
+[net]
+# Testing
+batch=1
+subdivisions=1
+# Training
+# batch=64
+# subdivisions=16
+width=608
+height=608
+channels=3
+momentum=0.9
+decay=0.0005
+angle=0
+saturation = 1.5
+exposure = 1.5
+hue=.1
+
+learning_rate=0.001
+burn_in=1000
+max_batches = 500200
+policy=steps
+steps=400000,450000
+scales=.1,.1
+
+[convolutional]
+batch_normalize=1
+filters=32
+size=3
+stride=1
+pad=1
+activation=leaky
+
+# Downsample
+
+[convolutional]
+batch_normalize=1
+filters=64
+size=3
+stride=2
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=32
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=64
+size=3
+stride=1
+pad=1
+activation=leaky
+
+[shortcut]
+from=-3
+activation=linear
+
+# Downsample
+
+[convolutional]
+batch_normalize=1
+filters=128
+size=3
+stride=2
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=64
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=128
+size=3
+stride=1
+pad=1
+activation=leaky
+
+[shortcut]
+from=-3
+activation=linear
+
+[convolutional]
+batch_normalize=1
+filters=64
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=128
+size=3
+stride=1
+pad=1
+activation=leaky
+
+[shortcut]
+from=-3
+activation=linear
+
+# Downsample
+
+[convolutional]
+batch_normalize=1
+filters=256
+size=3
+stride=2
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=128
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=256
+size=3
+stride=1
+pad=1
+activation=leaky
+
+[shortcut]
+from=-3
+activation=linear
+
+[convolutional]
+batch_normalize=1
+filters=128
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=256
+size=3
+stride=1
+pad=1
+activation=leaky
+
+[shortcut]
+from=-3
+activation=linear
+
+[convolutional]
+batch_normalize=1
+filters=128
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=256
+size=3
+stride=1
+pad=1
+activation=leaky
+
+[shortcut]
+from=-3
+activation=linear
+
+[convolutional]
+batch_normalize=1
+filters=128
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=256
+size=3
+stride=1
+pad=1
+activation=leaky
+
+[shortcut]
+from=-3
+activation=linear
+
+
+[convolutional]
+batch_normalize=1
+filters=128
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=256
+size=3
+stride=1
+pad=1
+activation=leaky
+
+[shortcut]
+from=-3
+activation=linear
+
+[convolutional]
+batch_normalize=1
+filters=128
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=256
+size=3
+stride=1
+pad=1
+activation=leaky
+
+[shortcut]
+from=-3
+activation=linear
+
+[convolutional]
+batch_normalize=1
+filters=128
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=256
+size=3
+stride=1
+pad=1
+activation=leaky
+
+[shortcut]
+from=-3
+activation=linear
+
+[convolutional]
+batch_normalize=1
+filters=128
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=256
+size=3
+stride=1
+pad=1
+activation=leaky
+
+[shortcut]
+from=-3
+activation=linear
+
+# Downsample
+
+[convolutional]
+batch_normalize=1
+filters=512
+size=3
+stride=2
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=256
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=512
+size=3
+stride=1
+pad=1
+activation=leaky
+
+[shortcut]
+from=-3
+activation=linear
+
+
+[convolutional]
+batch_normalize=1
+filters=256
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=512
+size=3
+stride=1
+pad=1
+activation=leaky
+
+[shortcut]
+from=-3
+activation=linear
+
+
+[convolutional]
+batch_normalize=1
+filters=256
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=512
+size=3
+stride=1
+pad=1
+activation=leaky
+
+[shortcut]
+from=-3
+activation=linear
+
+
+[convolutional]
+batch_normalize=1
+filters=256
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=512
+size=3
+stride=1
+pad=1
+activation=leaky
+
+[shortcut]
+from=-3
+activation=linear
+
+[convolutional]
+batch_normalize=1
+filters=256
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=512
+size=3
+stride=1
+pad=1
+activation=leaky
+
+[shortcut]
+from=-3
+activation=linear
+
+
+[convolutional]
+batch_normalize=1
+filters=256
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=512
+size=3
+stride=1
+pad=1
+activation=leaky
+
+[shortcut]
+from=-3
+activation=linear
+
+
+[convolutional]
+batch_normalize=1
+filters=256
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=512
+size=3
+stride=1
+pad=1
+activation=leaky
+
+[shortcut]
+from=-3
+activation=linear
+
+[convolutional]
+batch_normalize=1
+filters=256
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=512
+size=3
+stride=1
+pad=1
+activation=leaky
+
+[shortcut]
+from=-3
+activation=linear
+
+# Downsample
+
+[convolutional]
+batch_normalize=1
+filters=1024
+size=3
+stride=2
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=512
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=1024
+size=3
+stride=1
+pad=1
+activation=leaky
+
+[shortcut]
+from=-3
+activation=linear
+
+[convolutional]
+batch_normalize=1
+filters=512
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=1024
+size=3
+stride=1
+pad=1
+activation=leaky
+
+[shortcut]
+from=-3
+activation=linear
+
+[convolutional]
+batch_normalize=1
+filters=512
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=1024
+size=3
+stride=1
+pad=1
+activation=leaky
+
+[shortcut]
+from=-3
+activation=linear
+
+[convolutional]
+batch_normalize=1
+filters=512
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=1024
+size=3
+stride=1
+pad=1
+activation=leaky
+
+[shortcut]
+from=-3
+activation=linear
+
+######################
+
+[convolutional]
+batch_normalize=1
+filters=512
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+size=3
+stride=1
+pad=1
+filters=1024
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=512
+size=1
+stride=1
+pad=1
+activation=leaky
+
+### SPP ###
+[maxpool]
+stride=1
+size=5
+
+[route]
+layers=-2
+
+[maxpool]
+stride=1
+size=9
+
+[route]
+layers=-4
+
+[maxpool]
+stride=1
+size=13
+
+[route]
+layers=-1,-3,-5,-6
+
+### End SPP ###
+
+[convolutional]
+batch_normalize=1
+filters=512
+size=1
+stride=1
+pad=1
+activation=leaky
+
+
+[convolutional]
+batch_normalize=1
+size=3
+stride=1
+pad=1
+filters=1024
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=512
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+size=3
+stride=1
+pad=1
+filters=1024
+activation=leaky
+
+[convolutional]
+size=1
+stride=1
+pad=1
+filters=255
+activation=linear
+
+
+[yolo]
+mask = 6,7,8
+anchors = 10,13,  16,30,  33,23,  30,61,  62,45,  59,119,  116,90,  156,198,  373,326
+classes=80
+num=9
+jitter=.3
+ignore_thresh = .7
+truth_thresh = 1
+random=1
+
+
+[route]
+layers = -4
+
+[convolutional]
+batch_normalize=1
+filters=256
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[upsample]
+stride=2
+
+[route]
+layers = -1, 61
+
+
+
+[convolutional]
+batch_normalize=1
+filters=256
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+size=3
+stride=1
+pad=1
+filters=512
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=256
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+size=3
+stride=1
+pad=1
+filters=512
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=256
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+size=3
+stride=1
+pad=1
+filters=512
+activation=leaky
+
+[convolutional]
+size=1
+stride=1
+pad=1
+filters=255
+activation=linear
+
+
+[yolo]
+mask = 3,4,5
+anchors = 10,13,  16,30,  33,23,  30,61,  62,45,  59,119,  116,90,  156,198,  373,326
+classes=80
+num=9
+jitter=.3
+ignore_thresh = .7
+truth_thresh = 1
+random=1
+
+
+
+[route]
+layers = -4
+
+[convolutional]
+batch_normalize=1
+filters=128
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[upsample]
+stride=2
+
+[route]
+layers = -1, 36
+
+
+
+[convolutional]
+batch_normalize=1
+filters=128
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+size=3
+stride=1
+pad=1
+filters=256
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=128
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+size=3
+stride=1
+pad=1
+filters=256
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=128
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+size=3
+stride=1
+pad=1
+filters=256
+activation=leaky
+
+[convolutional]
+size=1
+stride=1
+pad=1
+filters=255
+activation=linear
+
+
+[yolo]
+mask = 0,1,2
+anchors = 10,13,  16,30,  33,23,  30,61,  62,45,  59,119,  116,90,  156,198,  373,326
+classes=80
+num=9
+jitter=.3
+ignore_thresh = .7
+truth_thresh = 1
+random=1
+
diff --git a/src/cuda.c b/src/cuda.c
index 688e1c5..2284dad 100644
--- a/src/cuda.c
+++ b/src/cuda.c
@@ -29,23 +29,23 @@
     //cudaDeviceSynchronize();
     cudaError_t status2 = cudaGetLastError();
     if (status != cudaSuccess)
-    {   
+    {
         const char *s = cudaGetErrorString(status);
         char buffer[256];
         printf("CUDA Error: %s\n", s);
         assert(0);
         snprintf(buffer, 256, "CUDA Error: %s", s);
         error(buffer);
-    } 
+    }
     if (status2 != cudaSuccess)
-    {   
+    {
         const char *s = cudaGetErrorString(status);
         char buffer[256];
         printf("CUDA Error Prev: %s\n", s);
         assert(0);
         snprintf(buffer, 256, "CUDA Error Prev: %s", s);
         error(buffer);
-    } 
+    }
 }
 
 dim3 cuda_gridsize(size_t n){
diff --git a/src/darknet.c b/src/darknet.c
index f6d61c0..068c9e0 100644
--- a/src/darknet.c
+++ b/src/darknet.c
@@ -40,13 +40,13 @@
     network net = parse_network_cfg(cfgfile);
     network sum = parse_network_cfg(cfgfile);
 
-    char *weightfile = argv[4];   
+    char *weightfile = argv[4];
     load_weights(&sum, weightfile);
 
     int i, j;
     int n = argc - 5;
     for(i = 0; i < n; ++i){
-        weightfile = argv[i+5];   
+        weightfile = argv[i+5];
         load_weights(&net, weightfile);
         for(j = 0; j < net.n; ++j){
             layer l = net.layers[j];
diff --git a/src/maxpool_layer.c b/src/maxpool_layer.c
index 41f7a79..0eeb467 100644
--- a/src/maxpool_layer.c
+++ b/src/maxpool_layer.c
@@ -27,8 +27,8 @@
     l.w = w;
     l.c = c;
     l.pad = padding;
-    l.out_w = (w + 2*padding)/stride;
-    l.out_h = (h + 2*padding)/stride;
+    l.out_w = (w + 2 * padding - size) / stride + 1;
+    l.out_h = (h + 2 * padding - size) / stride + 1;
     l.out_c = c;
     l.outputs = l.out_h * l.out_w * l.out_c;
     l.inputs = h*w*c;
@@ -58,8 +58,8 @@
     l->w = w;
     l->inputs = h*w*l->c;
 
-    l->out_w = (w + 2*l->pad)/l->stride;
-    l->out_h = (h + 2*l->pad)/l->stride;
+    l->out_w = (w + 2 * l->pad - l->size) / l->stride + 1;
+    l->out_h = (h + 2 * l->pad - l->size) / l->stride + 1;
     l->outputs = l->out_w * l->out_h * l->c;
     int output_size = l->outputs * l->batch;
 
diff --git a/src/maxpool_layer_kernels.cu b/src/maxpool_layer_kernels.cu
index d40d3c0..05b1f9a 100644
--- a/src/maxpool_layer_kernels.cu
+++ b/src/maxpool_layer_kernels.cu
@@ -9,8 +9,8 @@
 
 __global__ void forward_maxpool_layer_kernel(int n, int in_h, int in_w, int in_c, int stride, int size, int pad, float *input, float *output, int *indexes)
 {
-    int h = (in_h + 2*pad)/stride;
-    int w = (in_w + 2*pad)/stride;
+    int h = (in_h + 2 * pad - size) / stride + 1;
+    int w = (in_w + 2 * pad - size) / stride + 1;
     int c = in_c;
 
     int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
@@ -49,8 +49,8 @@
 
 __global__ void backward_maxpool_layer_kernel(int n, int in_h, int in_w, int in_c, int stride, int size, int pad, float *delta, float *prev_delta, int *indexes)
 {
-    int h = (in_h + 2*pad)/stride;
-    int w = (in_w + 2*pad)/stride;
+    int h = (in_h + 2 * pad - size) / stride + 1;
+    int w = (in_w + 2 * pad - size) / stride + 1;
     int c = in_c;
     int area = (size-1)/stride;
 

--
Gitblit v1.10.0