From c6afc7ff1499fbbe64069e1843d7929bd7ae2eaa Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Thu, 17 Nov 2016 20:18:24 +0000
Subject: [PATCH] :fire: :fire: yolo v2 :fire: :fire:

---
 cfg/coco.data             |    8 
 cfg/yolov1/tiny-coco.cfg  |    0 
 cfg/yolov1/yolo.cfg       |  257 ++++++++++++++++
 cfg/yolov1/yolo-small.cfg |    0 
 cfg/yolov1/yolo.train.cfg |    0 
 src/blas.h                |    1 
 src/region_layer.c        |    2 
 cfg/voc.data              |    6 
 cfg/yolov1/yolo-coco.cfg  |    0 
 cfg/yolov1/yolo2.cfg      |    0 
 cfg/yolo_voc.cfg          |  244 +++++++++++++++
 cfg/yolov1/tiny-yolo.cfg  |    0 
 cfg/imagenet1k.data       |    0 
 src/image.c               |   10 
 cfg/yolov1/xyolo.test.cfg |    0 
 cfg/yolo.cfg              |  245 +++++++--------
 src/demo.c                |    1 
 src/route_layer.c         |    2 
 data/coco.names           |   80 +++++
 src/blas.c                |   22 +
 src/parser.c              |    3 
 src/reorg_layer.c         |   44 --
 src/detector.c            |    2 
 src/darknet.c             |    7 
 24 files changed, 761 insertions(+), 173 deletions(-)

diff --git a/cfg/coco.data b/cfg/coco.data
new file mode 100644
index 0000000..3003841
--- /dev/null
+++ b/cfg/coco.data
@@ -0,0 +1,8 @@
+classes= 80
+train  = /home/pjreddie/data/coco/trainvalno5k.txt
+valid  = coco_testdev
+#valid = data/coco_val_5k.list
+names = data/coco.names
+backup = /home/pjreddie/backup/
+eval=coco
+
diff --git a/cfg/imagenet1k.dataset b/cfg/imagenet1k.data
similarity index 100%
rename from cfg/imagenet1k.dataset
rename to cfg/imagenet1k.data
diff --git a/cfg/voc.data b/cfg/voc.data
new file mode 100644
index 0000000..8246b3a
--- /dev/null
+++ b/cfg/voc.data
@@ -0,0 +1,6 @@
+classes= 20
+train  = /home/pjreddie/data/voc/train.txt
+valid  = /home/pjreddie/data/voc/2007_test.txt
+names = data/pascal.names
+backup = /home/pjreddie/backup/
+
diff --git a/cfg/yolo.cfg b/cfg/yolo.cfg
index c4f415c..4bf904c 100644
--- a/cfg/yolo.cfg
+++ b/cfg/yolo.cfg
@@ -1,36 +1,25 @@
 [net]
-batch=1
-subdivisions=1
-height=448
-width=448
+batch=64
+subdivisions=8
+height=416
+width=416
 channels=3
 momentum=0.9
 decay=0.0005
-saturation=1.5
-exposure=1.5
+angle=0
+saturation = 1.5
+exposure = 1.5
 hue=.1
 
-learning_rate=0.0005
+learning_rate=0.001
+max_batches = 120000
 policy=steps
-steps=200,400,600,20000,30000
-scales=2.5,2,2,.1,.1
-max_batches = 40000
+steps=-1,100,80000,100000
+scales=.1,10,.1,.1
 
 [convolutional]
 batch_normalize=1
-filters=64
-size=7
-stride=2
-pad=1
-activation=leaky
-
-[maxpool]
-size=2
-stride=2
-
-[convolutional]
-batch_normalize=1
-filters=192
+filters=32
 size=3
 stride=1
 pad=1
@@ -42,6 +31,54 @@
 
 [convolutional]
 batch_normalize=1
+filters=64
+size=3
+stride=1
+pad=1
+activation=leaky
+
+[maxpool]
+size=2
+stride=2
+
+[convolutional]
+batch_normalize=1
+filters=128
+size=3
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=64
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=128
+size=3
+stride=1
+pad=1
+activation=leaky
+
+[maxpool]
+size=2
+stride=2
+
+[convolutional]
+batch_normalize=1
+filters=256
+size=3
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
 filters=128
 size=1
 stride=1
@@ -56,6 +93,34 @@
 pad=1
 activation=leaky
 
+[maxpool]
+size=2
+stride=2
+
+[convolutional]
+batch_normalize=1
+filters=512
+size=3
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=256
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=512
+size=3
+stride=1
+pad=1
+activation=leaky
+
 [convolutional]
 batch_normalize=1
 filters=256
@@ -78,88 +143,12 @@
 
 [convolutional]
 batch_normalize=1
-filters=256
-size=1
-stride=1
-pad=1
-activation=leaky
-
-[convolutional]
-batch_normalize=1
-filters=512
-size=3
-stride=1
-pad=1
-activation=leaky
-
-[convolutional]
-batch_normalize=1
-filters=256
-size=1
-stride=1
-pad=1
-activation=leaky
-
-[convolutional]
-batch_normalize=1
-filters=512
-size=3
-stride=1
-pad=1
-activation=leaky
-
-[convolutional]
-batch_normalize=1
-filters=256
-size=1
-stride=1
-pad=1
-activation=leaky
-
-[convolutional]
-batch_normalize=1
-filters=512
-size=3
-stride=1
-pad=1
-activation=leaky
-
-[convolutional]
-batch_normalize=1
-filters=256
-size=1
-stride=1
-pad=1
-activation=leaky
-
-[convolutional]
-batch_normalize=1
-filters=512
-size=3
-stride=1
-pad=1
-activation=leaky
-
-[convolutional]
-batch_normalize=1
-filters=512
-size=1
-stride=1
-pad=1
-activation=leaky
-
-[convolutional]
-batch_normalize=1
 filters=1024
 size=3
 stride=1
 pad=1
 activation=leaky
 
-[maxpool]
-size=2
-stride=2
-
 [convolutional]
 batch_normalize=1
 filters=512
@@ -192,6 +181,7 @@
 pad=1
 activation=leaky
 
+
 #######
 
 [convolutional]
@@ -205,10 +195,19 @@
 [convolutional]
 batch_normalize=1
 size=3
+stride=1
+pad=1
+filters=1024
+activation=leaky
+
+[route]
+layers=-9
+
+[reorg]
 stride=2
-pad=1
-filters=1024
-activation=leaky
+
+[route]
+layers=-1,-3
 
 [convolutional]
 batch_normalize=1
@@ -219,39 +218,27 @@
 activation=leaky
 
 [convolutional]
-batch_normalize=1
-size=3
+size=1
 stride=1
 pad=1
-filters=1024
-activation=leaky
-
-[local]
-size=3
-stride=1
-pad=1
-filters=256
-activation=leaky
-
-[dropout]
-probability=.5
-
-[connected]
-output= 1715
+filters=425
 activation=linear
 
-[detection]
-classes=20
+[region]
+anchors = 0.738768,0.874946,  2.42204,2.65704,  4.30971,7.04493,  10.246,4.59428,  12.6868,11.8741
+bias_match=1
+classes=80
 coords=4
-rescore=1
-side=7
-num=3
-softmax=0
-sqrt=1
+num=5
+softmax=1
 jitter=.2
+rescore=1
 
-object_scale=1
-noobject_scale=.5
+object_scale=5
+noobject_scale=1
 class_scale=1
-coord_scale=5
+coord_scale=1
 
+absolute=1
+thresh = .6
+random=0
diff --git a/cfg/yolo_voc.cfg b/cfg/yolo_voc.cfg
new file mode 100644
index 0000000..ceb3f2a
--- /dev/null
+++ b/cfg/yolo_voc.cfg
@@ -0,0 +1,244 @@
+[net]
+batch=64
+subdivisions=8
+height=416
+width=416
+channels=3
+momentum=0.9
+decay=0.0005
+angle=0
+saturation = 1.5
+exposure = 1.5
+hue=.1
+
+learning_rate=0.0001
+max_batches = 45000
+policy=steps
+steps=100,25000,35000
+scales=10,.1,.1
+
+[convolutional]
+batch_normalize=1
+filters=32
+size=3
+stride=1
+pad=1
+activation=leaky
+
+[maxpool]
+size=2
+stride=2
+
+[convolutional]
+batch_normalize=1
+filters=64
+size=3
+stride=1
+pad=1
+activation=leaky
+
+[maxpool]
+size=2
+stride=2
+
+[convolutional]
+batch_normalize=1
+filters=128
+size=3
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=64
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=128
+size=3
+stride=1
+pad=1
+activation=leaky
+
+[maxpool]
+size=2
+stride=2
+
+[convolutional]
+batch_normalize=1
+filters=256
+size=3
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=128
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=256
+size=3
+stride=1
+pad=1
+activation=leaky
+
+[maxpool]
+size=2
+stride=2
+
+[convolutional]
+batch_normalize=1
+filters=512
+size=3
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=256
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=512
+size=3
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=256
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=512
+size=3
+stride=1
+pad=1
+activation=leaky
+
+[maxpool]
+size=2
+stride=2
+
+[convolutional]
+batch_normalize=1
+filters=1024
+size=3
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=512
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=1024
+size=3
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=512
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=1024
+size=3
+stride=1
+pad=1
+activation=leaky
+
+
+#######
+
+[convolutional]
+batch_normalize=1
+size=3
+stride=1
+pad=1
+filters=1024
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+size=3
+stride=1
+pad=1
+filters=1024
+activation=leaky
+
+[route]
+layers=-9
+
+[reorg]
+stride=2
+
+[route]
+layers=-1,-3
+
+[convolutional]
+batch_normalize=1
+size=3
+stride=1
+pad=1
+filters=1024
+activation=leaky
+
+[convolutional]
+size=1
+stride=1
+pad=1
+filters=125
+activation=linear
+
+[region]
+anchors = 1.08,1.19,  3.42,4.41,  6.63,11.38,  9.42,5.11,  16.62,10.52
+bias_match=1
+classes=20
+coords=4
+num=5
+softmax=1
+jitter=.2
+rescore=1
+
+object_scale=5
+noobject_scale=1
+class_scale=1
+coord_scale=1
+
+absolute=1
+thresh = .6
+random=0
diff --git a/cfg/tiny-coco.cfg b/cfg/yolov1/tiny-coco.cfg
similarity index 100%
rename from cfg/tiny-coco.cfg
rename to cfg/yolov1/tiny-coco.cfg
diff --git a/cfg/tiny-yolo.cfg b/cfg/yolov1/tiny-yolo.cfg
similarity index 100%
rename from cfg/tiny-yolo.cfg
rename to cfg/yolov1/tiny-yolo.cfg
diff --git a/cfg/xyolo.test.cfg b/cfg/yolov1/xyolo.test.cfg
similarity index 100%
rename from cfg/xyolo.test.cfg
rename to cfg/yolov1/xyolo.test.cfg
diff --git a/cfg/yolo-coco.cfg b/cfg/yolov1/yolo-coco.cfg
similarity index 100%
rename from cfg/yolo-coco.cfg
rename to cfg/yolov1/yolo-coco.cfg
diff --git a/cfg/yolo-small.cfg b/cfg/yolov1/yolo-small.cfg
similarity index 100%
rename from cfg/yolo-small.cfg
rename to cfg/yolov1/yolo-small.cfg
diff --git a/cfg/yolov1/yolo.cfg b/cfg/yolov1/yolo.cfg
new file mode 100644
index 0000000..c4f415c
--- /dev/null
+++ b/cfg/yolov1/yolo.cfg
@@ -0,0 +1,257 @@
+[net]
+batch=1
+subdivisions=1
+height=448
+width=448
+channels=3
+momentum=0.9
+decay=0.0005
+saturation=1.5
+exposure=1.5
+hue=.1
+
+learning_rate=0.0005
+policy=steps
+steps=200,400,600,20000,30000
+scales=2.5,2,2,.1,.1
+max_batches = 40000
+
+[convolutional]
+batch_normalize=1
+filters=64
+size=7
+stride=2
+pad=1
+activation=leaky
+
+[maxpool]
+size=2
+stride=2
+
+[convolutional]
+batch_normalize=1
+filters=192
+size=3
+stride=1
+pad=1
+activation=leaky
+
+[maxpool]
+size=2
+stride=2
+
+[convolutional]
+batch_normalize=1
+filters=128
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=256
+size=3
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=256
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=512
+size=3
+stride=1
+pad=1
+activation=leaky
+
+[maxpool]
+size=2
+stride=2
+
+[convolutional]
+batch_normalize=1
+filters=256
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=512
+size=3
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=256
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=512
+size=3
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=256
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=512
+size=3
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=256
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=512
+size=3
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=512
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=1024
+size=3
+stride=1
+pad=1
+activation=leaky
+
+[maxpool]
+size=2
+stride=2
+
+[convolutional]
+batch_normalize=1
+filters=512
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=1024
+size=3
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=512
+size=1
+stride=1
+pad=1
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+filters=1024
+size=3
+stride=1
+pad=1
+activation=leaky
+
+#######
+
+[convolutional]
+batch_normalize=1
+size=3
+stride=1
+pad=1
+filters=1024
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+size=3
+stride=2
+pad=1
+filters=1024
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+size=3
+stride=1
+pad=1
+filters=1024
+activation=leaky
+
+[convolutional]
+batch_normalize=1
+size=3
+stride=1
+pad=1
+filters=1024
+activation=leaky
+
+[local]
+size=3
+stride=1
+pad=1
+filters=256
+activation=leaky
+
+[dropout]
+probability=.5
+
+[connected]
+output= 1715
+activation=linear
+
+[detection]
+classes=20
+coords=4
+rescore=1
+side=7
+num=3
+softmax=0
+sqrt=1
+jitter=.2
+
+object_scale=1
+noobject_scale=.5
+class_scale=1
+coord_scale=5
+
diff --git a/cfg/yolo.train.cfg b/cfg/yolov1/yolo.train.cfg
similarity index 100%
rename from cfg/yolo.train.cfg
rename to cfg/yolov1/yolo.train.cfg
diff --git a/cfg/yolo2.cfg b/cfg/yolov1/yolo2.cfg
similarity index 100%
rename from cfg/yolo2.cfg
rename to cfg/yolov1/yolo2.cfg
diff --git a/data/coco.names b/data/coco.names
new file mode 100644
index 0000000..ca76c80
--- /dev/null
+++ b/data/coco.names
@@ -0,0 +1,80 @@
+person
+bicycle
+car
+motorbike
+aeroplane
+bus
+train
+truck
+boat
+traffic light
+fire hydrant
+stop sign
+parking meter
+bench
+bird
+cat
+dog
+horse
+sheep
+cow
+elephant
+bear
+zebra
+giraffe
+backpack
+umbrella
+handbag
+tie
+suitcase
+frisbee
+skis
+snowboard
+sports ball
+kite
+baseball bat
+baseball glove
+skateboard
+surfboard
+tennis racket
+bottle
+wine glass
+cup
+fork
+knife
+spoon
+bowl
+banana
+apple
+sandwich
+orange
+broccoli
+carrot
+hot dog
+pizza
+donut
+cake
+chair
+sofa
+pottedplant
+bed
+diningtable
+toilet
+tvmonitor
+laptop
+mouse
+remote
+keyboard
+cell phone
+microwave
+oven
+toaster
+sink
+refrigerator
+book
+clock
+vase
+scissors
+teddy bear
+hair drier
+toothbrush
diff --git a/src/blas.c b/src/blas.c
index c6d59ea..31bd86b 100644
--- a/src/blas.c
+++ b/src/blas.c
@@ -5,6 +5,28 @@
 #include <stdio.h>
 #include <stdlib.h>
 #include <string.h>
+void reorg_cpu(float *x, int w, int h, int c, int batch, int stride, int forward, float *out)
+{
+    int b,i,j,k;
+    int out_c = c/(stride*stride);
+
+    for(b = 0; b < batch; ++b){
+        for(k = 0; k < c; ++k){
+            for(j = 0; j < h; ++j){
+                for(i = 0; i < w; ++i){
+                    int in_index  = i + w*(j + h*(k + c*b));
+                    int c2 = k % out_c;
+                    int offset = k / out_c;
+                    int w2 = i*stride + offset % stride;
+                    int h2 = j*stride + offset / stride;
+                    int out_index = w2 + w*stride*(h2 + h*stride*(c2 + out_c*b));
+                    if(forward) out[out_index] = x[in_index];
+                    else out[in_index] = x[out_index];
+                }
+            }
+        }
+    }
+}
 
 void flatten(float *x, int size, int layers, int batch, int forward)
 {
diff --git a/src/blas.h b/src/blas.h
index a942024..3d6ee7d 100644
--- a/src/blas.h
+++ b/src/blas.h
@@ -4,6 +4,7 @@
 void pm(int M, int N, float *A);
 float *random_matrix(int rows, int cols);
 void time_random_matrix(int TA, int TB, int m, int k, int n);
+void reorg_cpu(float *x, int w, int h, int c, int batch, int stride, int forward, float *out);
 
 void test_blas();
 
diff --git a/src/darknet.c b/src/darknet.c
index 1444756..776778a 100644
--- a/src/darknet.c
+++ b/src/darknet.c
@@ -13,6 +13,7 @@
 #endif
 
 extern void predict_classifier(char *datacfg, char *cfgfile, char *weightfile, char *filename, int top);
+extern void test_detector(char *datacfg, char *cfgfile, char *weightfile, char *filename, float thresh);
 extern void run_voxel(int argc, char **argv);
 extern void run_yolo(int argc, char **argv);
 extern void run_detector(int argc, char **argv);
@@ -379,6 +380,10 @@
         run_super(argc, argv);
     } else if (0 == strcmp(argv[1], "detector")){
         run_detector(argc, argv);
+    } else if (0 == strcmp(argv[1], "detect")){
+        float thresh = find_float_arg(argc, argv, "-thresh", .25);
+        char *filename = (argc > 4) ? argv[4]: 0;
+        test_detector("cfg/coco.data", argv[2], argv[3], filename, thresh);
     } else if (0 == strcmp(argv[1], "cifar")){
         run_cifar(argc, argv);
     } else if (0 == strcmp(argv[1], "go")){
@@ -390,7 +395,7 @@
     } else if (0 == strcmp(argv[1], "coco")){
         run_coco(argc, argv);
     } else if (0 == strcmp(argv[1], "classify")){
-        predict_classifier("cfg/imagenet1k.dataset", argv[2], argv[3], argv[4], 5);
+        predict_classifier("cfg/imagenet1k.data", argv[2], argv[3], argv[4], 5);
     } else if (0 == strcmp(argv[1], "classifier")){
         run_classifier(argc, argv);
     } else if (0 == strcmp(argv[1], "art")){
diff --git a/src/demo.c b/src/demo.c
index 91562c0..915d950 100644
--- a/src/demo.c
+++ b/src/demo.c
@@ -110,6 +110,7 @@
     srand(2222222);
 
     if(filename){
+        printf("video file: %s\n", filename);
         cap = cvCaptureFromFile(filename);
     }else{
         cap = cvCaptureFromCAM(cam_index);
diff --git a/src/detector.c b/src/detector.c
index a513816..50db65b 100644
--- a/src/detector.c
+++ b/src/detector.c
@@ -490,7 +490,7 @@
 void run_detector(int argc, char **argv)
 {
     char *prefix = find_char_arg(argc, argv, "-prefix", 0);
-    float thresh = find_float_arg(argc, argv, "-thresh", .2);
+    float thresh = find_float_arg(argc, argv, "-thresh", .25);
     int cam_index = find_int_arg(argc, argv, "-c", 0);
     int frame_skip = find_int_arg(argc, argv, "-s", 0);
     if(argc < 4){
diff --git a/src/image.c b/src/image.c
index ad7c2d6..e744782 100644
--- a/src/image.c
+++ b/src/image.c
@@ -185,10 +185,16 @@
         int class = max_index(probs[i], classes);
         float prob = probs[i][class];
         if(prob > thresh){
-            //int width = pow(prob, 1./2.)*30+1;
+
             int width = im.h * .012;
+
+            if(0){
+                width = pow(prob, 1./2.)*10+1;
+                alphabet = 0;
+            }
+
             printf("%s: %.0f%%\n", names[class], prob*100);
-            int offset = class*1 % classes;
+            int offset = class*123457 % classes;
             float red = get_color(2,offset,classes);
             float green = get_color(1,offset,classes);
             float blue = get_color(0,offset,classes);
diff --git a/src/parser.c b/src/parser.c
index cde06b4..84733d7 100644
--- a/src/parser.c
+++ b/src/parser.c
@@ -238,9 +238,6 @@
     int classes = option_find_int(options, "classes", 20);
     int num = option_find_int(options, "num", 1);
 
-    params.w = option_find_int(options, "side", params.w);
-    params.h = option_find_int(options, "side", params.h);
-
     layer l = make_region_layer(params.batch, params.w, params.h, num, classes, coords);
     assert(l.outputs == params.inputs);
 
diff --git a/src/region_layer.c b/src/region_layer.c
index 5e8387d..902778c 100644
--- a/src/region_layer.c
+++ b/src/region_layer.c
@@ -44,7 +44,7 @@
     l.delta_gpu = cuda_make_array(l.delta, batch*l.outputs);
 #endif
 
-    fprintf(stderr, "Region Layer\n");
+    fprintf(stderr, "detection\n");
     srand(0);
 
     return l;
diff --git a/src/reorg_layer.c b/src/reorg_layer.c
index d93dd97..9b68f03 100644
--- a/src/reorg_layer.c
+++ b/src/reorg_layer.c
@@ -23,7 +23,7 @@
         l.out_c = c*(stride*stride);
     }
     l.reverse = reverse;
-    fprintf(stderr, "Reorg Layer: %d x %d x %d image -> %d x %d x %d image, \n", w,h,c,l.out_w, l.out_h, l.out_c);
+    fprintf(stderr, "reorg              /%2d  %4d x%4d x%4d   ->  %4d x%4d x%4d\n",  stride, w, h, c, l.out_w, l.out_h, l.out_c);
     l.outputs = l.out_h * l.out_w * l.out_c;
     l.inputs = h*w*c;
     int output_size = l.out_h * l.out_w * l.out_c * batch;
@@ -77,45 +77,19 @@
 
 void forward_reorg_layer(const layer l, network_state state)
 {
-    int b,i,j,k;
-
-    for(b = 0; b < l.batch; ++b){
-        for(k = 0; k < l.c; ++k){
-            for(j = 0; j < l.h; ++j){
-                for(i = 0; i < l.w; ++i){
-                    int in_index  = i + l.w*(j + l.h*(k + l.c*b));
-
-                    int c2 = k % l.out_c;
-                    int offset = k / l.out_c;
-                    int w2 = i*l.stride + offset % l.stride;
-                    int h2 = j*l.stride + offset / l.stride;
-                    int out_index = w2 + l.out_w*(h2 + l.out_h*(c2 + l.out_c*b));
-                    l.output[out_index] = state.input[in_index];
-                }
-            }
-        }
+    if(l.reverse){
+        reorg_cpu(state.input, l.w, l.h, l.c, l.batch, l.stride, 1, l.output);
+    }else {
+        reorg_cpu(state.input, l.w, l.h, l.c, l.batch, l.stride, 0, l.output);
     }
 }
 
 void backward_reorg_layer(const layer l, network_state state)
 {
-    int b,i,j,k;
-
-    for(b = 0; b < l.batch; ++b){
-        for(k = 0; k < l.c; ++k){
-            for(j = 0; j < l.h; ++j){
-                for(i = 0; i < l.w; ++i){
-                    int in_index  = i + l.w*(j + l.h*(k + l.c*b));
-
-                    int c2 = k % l.out_c;
-                    int offset = k / l.out_c;
-                    int w2 = i*l.stride + offset % l.stride;
-                    int h2 = j*l.stride + offset / l.stride;
-                    int out_index = w2 + l.out_w*(h2 + l.out_h*(c2 + l.out_c*b));
-                    state.delta[in_index] = l.delta[out_index];
-                }
-            }
-        }
+    if(l.reverse){
+        reorg_cpu(l.delta, l.w, l.h, l.c, l.batch, l.stride, 0, state.delta);
+    }else{
+        reorg_cpu(l.delta, l.w, l.h, l.c, l.batch, l.stride, 1, state.delta);
     }
 }
 
diff --git a/src/route_layer.c b/src/route_layer.c
index d18427a..dce7118 100644
--- a/src/route_layer.c
+++ b/src/route_layer.c
@@ -5,7 +5,7 @@
 
 route_layer make_route_layer(int batch, int n, int *input_layers, int *input_sizes)
 {
-    fprintf(stderr,"Route Layer:");
+    fprintf(stderr,"route ");
     route_layer l = {0};
     l.type = ROUTE;
     l.batch = batch;

--
Gitblit v1.10.0