From 8c5364f58569eaeb5582a4915b36b24fc5570c76 Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Mon, 09 Nov 2015 19:31:39 +0000
Subject: [PATCH] New YOLO
---
src/network.c | 11
src/yolo.c | 352 ++++++++++++-----
src/network.h | 1
Makefile | 5
src/network_kernels.cu | 7
src/detection_layer.h | 8
src/detection_layer.c | 296 +++++++-------
src/yolo_kernels.cu | 15
data/scream.jpg | 0
/dev/null | 460 -----------------------
src/coco.c | 4
cfg/yolo.cfg | 6
src/parser.c | 28 -
src/darknet.c | 3
src/layer.h | 4
15 files changed, 412 insertions(+), 788 deletions(-)
diff --git a/Makefile b/Makefile
index 2180094..412e826 100644
--- a/Makefile
+++ b/Makefile
@@ -3,7 +3,6 @@
DEBUG=0
ARCH= --gpu-architecture=compute_20 --gpu-code=compute_20
-ARCH= -arch=sm_52 --use_fast_math
VPATH=./src/
EXEC=darknet
@@ -35,9 +34,9 @@
LDFLAGS+= -L/usr/local/cuda/lib64 -lcuda -lcudart -lcublas -lcurand
endif
-OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o darknet.o detection_layer.o imagenet.o captcha.o route_layer.o writing.o box.o nightmare.o normalization_layer.o avgpool_layer.o coco.o dice.o yolo.o region_layer.o layer.o compare.o swag.o classifier.o
+OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o darknet.o detection_layer.o imagenet.o captcha.o route_layer.o writing.o box.o nightmare.o normalization_layer.o avgpool_layer.o coco.o dice.o yolo.o layer.o compare.o classifier.o
ifeq ($(GPU), 1)
-OBJ+=convolutional_kernels.o deconvolutional_kernels.o activation_kernels.o im2col_kernels.o col2im_kernels.o blas_kernels.o crop_layer_kernels.o dropout_layer_kernels.o maxpool_layer_kernels.o softmax_layer_kernels.o network_kernels.o avgpool_layer_kernels.o swag_kernels.o
+OBJ+=convolutional_kernels.o deconvolutional_kernels.o activation_kernels.o im2col_kernels.o col2im_kernels.o blas_kernels.o crop_layer_kernels.o dropout_layer_kernels.o maxpool_layer_kernels.o softmax_layer_kernels.o network_kernels.o avgpool_layer_kernels.o yolo_kernels.o
endif
OBJS = $(addprefix $(OBJDIR), $(OBJ))
diff --git a/cfg/yolo.cfg b/cfg/yolo.cfg
index ee16726..a652f46 100644
--- a/cfg/yolo.cfg
+++ b/cfg/yolo.cfg
@@ -9,8 +9,8 @@
learning_rate=0.001
policy=steps
-steps=100,200,300,400,500,600,700,20000,30000
-scales=2,2,1.25,1.25,1.25,1.25,1.03,.1,.1
+steps=200,400,600,20000,30000
+scales=2.5,2,2,.1,.1
max_batches = 40000
[crop]
@@ -218,7 +218,7 @@
output= 1470
activation=linear
-[region]
+[detection]
classes=20
coords=4
rescore=1
diff --git a/data/scream.jpg b/data/scream.jpg
index 40aea94..43f2c36 100644
--- a/data/scream.jpg
+++ b/data/scream.jpg
Binary files differ
diff --git a/src/coco.c b/src/coco.c
index e30eeb7..aadf09d 100644
--- a/src/coco.c
+++ b/src/coco.c
@@ -1,7 +1,7 @@
#include <stdio.h>
#include "network.h"
-#include "region_layer.h"
+#include "detection_layer.h"
#include "cost_layer.h"
#include "utils.h"
#include "parser.h"
@@ -366,7 +366,7 @@
if(weightfile){
load_weights(&net, weightfile);
}
- region_layer l = net.layers[net.n-1];
+ detection_layer l = net.layers[net.n-1];
set_batch_network(&net, 1);
srand(2222222);
clock_t time;
diff --git a/src/darknet.c b/src/darknet.c
index 7814611..c2a6596 100644
--- a/src/darknet.c
+++ b/src/darknet.c
@@ -13,7 +13,6 @@
extern void run_imagenet(int argc, char **argv);
extern void run_yolo(int argc, char **argv);
-extern void run_swag(int argc, char **argv);
extern void run_coco(int argc, char **argv);
extern void run_writing(int argc, char **argv);
extern void run_captcha(int argc, char **argv);
@@ -221,8 +220,6 @@
average(argc, argv);
} else if (0 == strcmp(argv[1], "yolo")){
run_yolo(argc, argv);
- } else if (0 == strcmp(argv[1], "swag")){
- run_swag(argc, argv);
} else if (0 == strcmp(argv[1], "coco")){
run_coco(argc, argv);
} else if (0 == strcmp(argv[1], "classifier")){
diff --git a/src/detection_layer.c b/src/detection_layer.c
index daeee04..33f4f0b 100644
--- a/src/detection_layer.c
+++ b/src/detection_layer.c
@@ -6,42 +6,32 @@
#include "cuda.h"
#include "utils.h"
#include <stdio.h>
+#include <assert.h>
#include <string.h>
#include <stdlib.h>
-int get_detection_layer_locations(detection_layer l)
-{
- return l.inputs / (l.classes+l.coords+l.joint+(l.background || l.objectness));
-}
-
-int get_detection_layer_output_size(detection_layer l)
-{
- return get_detection_layer_locations(l)*((l.background || l.objectness) + l.classes + l.coords);
-}
-
-detection_layer make_detection_layer(int batch, int inputs, int classes, int coords, int joint, int rescore, int background, int objectness)
+detection_layer make_detection_layer(int batch, int inputs, int n, int side, int classes, int coords, int rescore)
{
detection_layer l = {0};
l.type = DETECTION;
-
+
+ l.n = n;
l.batch = batch;
l.inputs = inputs;
l.classes = classes;
l.coords = coords;
l.rescore = rescore;
- l.objectness = objectness;
- l.background = background;
- l.joint = joint;
+ l.side = side;
+ assert(side*side*((1 + l.coords)*l.n + l.classes) == inputs);
l.cost = calloc(1, sizeof(float));
- l.does_cost=1;
- int outputs = get_detection_layer_output_size(l);
- l.outputs = outputs;
- l.output = calloc(batch*outputs, sizeof(float));
- l.delta = calloc(batch*outputs, sizeof(float));
- #ifdef GPU
- l.output_gpu = cuda_make_array(l.output, batch*outputs);
- l.delta_gpu = cuda_make_array(l.delta, batch*outputs);
- #endif
+ l.outputs = l.inputs;
+ l.truths = l.side*l.side*(1+l.coords+l.classes);
+ l.output = calloc(batch*l.outputs, sizeof(float));
+ l.delta = calloc(batch*l.outputs, sizeof(float));
+#ifdef GPU
+ l.output_gpu = cuda_make_array(l.output, batch*l.outputs);
+ l.delta_gpu = cuda_make_array(l.delta, batch*l.outputs);
+#endif
fprintf(stderr, "Detection Layer\n");
srand(0);
@@ -51,124 +41,164 @@
void forward_detection_layer(const detection_layer l, network_state state)
{
- int in_i = 0;
- int out_i = 0;
- int locations = get_detection_layer_locations(l);
+ int locations = l.side*l.side;
int i,j;
- for(i = 0; i < l.batch*locations; ++i){
- int mask = (!state.truth || state.truth[out_i + (l.background || l.objectness) + l.classes + 2]);
- float scale = 1;
- if(l.joint) scale = state.input[in_i++];
- else if(l.objectness){
- l.output[out_i++] = 1-state.input[in_i++];
- scale = mask;
- }
- else if(l.background) l.output[out_i++] = scale*state.input[in_i++];
-
- for(j = 0; j < l.classes; ++j){
- l.output[out_i++] = scale*state.input[in_i++];
- }
- if(l.objectness){
-
- }else if(l.background){
- softmax_array(l.output + out_i - l.classes-l.background, l.classes+l.background, l.output + out_i - l.classes-l.background);
- activate_array(state.input+in_i, l.coords, LOGISTIC);
- }
- for(j = 0; j < l.coords; ++j){
- l.output[out_i++] = mask*state.input[in_i++];
+ memcpy(l.output, state.input, l.outputs*l.batch*sizeof(float));
+ int b;
+ if (l.softmax){
+ for(b = 0; b < l.batch; ++b){
+ int index = b*l.inputs;
+ for (i = 0; i < locations; ++i) {
+ int offset = i*l.classes;
+ softmax_array(l.output + index + offset, l.classes,
+ l.output + index + offset);
+ }
+ int offset = locations*l.classes;
+ activate_array(l.output + index + offset, locations*l.n*(1+l.coords), LOGISTIC);
}
}
- float avg_iou = 0;
- int count = 0;
- if(l.does_cost && state.train){
+ if(state.train){
+ float avg_iou = 0;
+ float avg_cat = 0;
+ float avg_allcat = 0;
+ float avg_obj = 0;
+ float avg_anyobj = 0;
+ int count = 0;
*(l.cost) = 0;
- int size = get_detection_layer_output_size(l) * l.batch;
+ int size = l.inputs * l.batch;
memset(l.delta, 0, size * sizeof(float));
- for (i = 0; i < l.batch*locations; ++i) {
- int classes = (l.objectness || l.background)+l.classes;
- int offset = i*(classes+l.coords);
- for (j = offset; j < offset+classes; ++j) {
- *(l.cost) += pow(state.truth[j] - l.output[j], 2);
- l.delta[j] = state.truth[j] - l.output[j];
- if(l.background && j == offset) l.delta[j] *= .1;
- }
-
- box truth;
- truth.x = state.truth[j+0]/7;
- truth.y = state.truth[j+1]/7;
- truth.w = pow(state.truth[j+2], 2);
- truth.h = pow(state.truth[j+3], 2);
-
- box out;
- out.x = l.output[j+0]/7;
- out.y = l.output[j+1]/7;
- out.w = pow(l.output[j+2], 2);
- out.h = pow(l.output[j+3], 2);
-
- if(!(truth.w*truth.h)) continue;
- float iou = box_iou(out, truth);
- avg_iou += iou;
- ++count;
-
- *(l.cost) += pow((1-iou), 2);
- l.delta[j+0] = 4 * (state.truth[j+0] - l.output[j+0]);
- l.delta[j+1] = 4 * (state.truth[j+1] - l.output[j+1]);
- l.delta[j+2] = 4 * (state.truth[j+2] - l.output[j+2]);
- l.delta[j+3] = 4 * (state.truth[j+3] - l.output[j+3]);
- if(l.rescore){
- if(l.objectness){
- state.truth[offset] = iou;
- l.delta[offset] = state.truth[offset] - l.output[offset];
+ for (b = 0; b < l.batch; ++b){
+ int index = b*l.inputs;
+ for (i = 0; i < locations; ++i) {
+ int truth_index = (b*locations + i)*(1+l.coords+l.classes);
+ int is_obj = state.truth[truth_index];
+ for (j = 0; j < l.n; ++j) {
+ int p_index = index + locations*l.classes + i*l.n + j;
+ l.delta[p_index] = l.noobject_scale*(0 - l.output[p_index]);
+ *(l.cost) += l.noobject_scale*pow(l.output[p_index], 2);
+ avg_anyobj += l.output[p_index];
}
- else{
- for (j = offset; j < offset+classes; ++j) {
- if(state.truth[j]) state.truth[j] = iou;
- l.delta[j] = state.truth[j] - l.output[j];
+
+ int best_index = -1;
+ float best_iou = 0;
+ float best_rmse = 20;
+
+ if (!is_obj){
+ continue;
+ }
+
+ int class_index = index + i*l.classes;
+ for(j = 0; j < l.classes; ++j) {
+ l.delta[class_index+j] = l.class_scale * (state.truth[truth_index+1+j] - l.output[class_index+j]);
+ *(l.cost) += l.class_scale * pow(state.truth[truth_index+1+j] - l.output[class_index+j], 2);
+ if(state.truth[truth_index + 1 + j]) avg_cat += l.output[class_index+j];
+ avg_allcat += l.output[class_index+j];
+ }
+
+ box truth = float_to_box(state.truth + truth_index + 1 + l.classes);
+ truth.x /= l.side;
+ truth.y /= l.side;
+
+ for(j = 0; j < l.n; ++j){
+ int box_index = index + locations*(l.classes + l.n) + (i*l.n + j) * l.coords;
+ box out = float_to_box(l.output + box_index);
+ out.x /= l.side;
+ out.y /= l.side;
+
+ if (l.sqrt){
+ out.w = out.w*out.w;
+ out.h = out.h*out.h;
+ }
+
+ float iou = box_iou(out, truth);
+ //iou = 0;
+ float rmse = box_rmse(out, truth);
+ if(best_iou > 0 || iou > 0){
+ if(iou > best_iou){
+ best_iou = iou;
+ best_index = j;
+ }
+ }else{
+ if(rmse < best_rmse){
+ best_rmse = rmse;
+ best_index = j;
+ }
}
}
+
+ if(l.forced){
+ if(truth.w*truth.h < .1){
+ best_index = 1;
+ }else{
+ best_index = 0;
+ }
+ }
+
+ int box_index = index + locations*(l.classes + l.n) + (i*l.n + best_index) * l.coords;
+ int tbox_index = truth_index + 1 + l.classes;
+
+ box out = float_to_box(l.output + box_index);
+ out.x /= l.side;
+ out.y /= l.side;
+ if (l.sqrt) {
+ out.w = out.w*out.w;
+ out.h = out.h*out.h;
+ }
+ float iou = box_iou(out, truth);
+
+ //printf("%d", best_index);
+ int p_index = index + locations*l.classes + i*l.n + best_index;
+ *(l.cost) -= l.noobject_scale * pow(l.output[p_index], 2);
+ *(l.cost) += l.object_scale * pow(1-l.output[p_index], 2);
+ avg_obj += l.output[p_index];
+ l.delta[p_index] = l.object_scale * (1.-l.output[p_index]);
+
+ if(l.rescore){
+ l.delta[p_index] = l.object_scale * (iou - l.output[p_index]);
+ }
+
+ l.delta[box_index+0] = l.coord_scale*(state.truth[tbox_index + 0] - l.output[box_index + 0]);
+ l.delta[box_index+1] = l.coord_scale*(state.truth[tbox_index + 1] - l.output[box_index + 1]);
+ l.delta[box_index+2] = l.coord_scale*(state.truth[tbox_index + 2] - l.output[box_index + 2]);
+ l.delta[box_index+3] = l.coord_scale*(state.truth[tbox_index + 3] - l.output[box_index + 3]);
+ if(l.sqrt){
+ l.delta[box_index+2] = l.coord_scale*(sqrt(state.truth[tbox_index + 2]) - l.output[box_index + 2]);
+ l.delta[box_index+3] = l.coord_scale*(sqrt(state.truth[tbox_index + 3]) - l.output[box_index + 3]);
+ }
+
+ *(l.cost) += pow(1-iou, 2);
+ avg_iou += iou;
+ ++count;
+ }
+ if(l.softmax){
+ gradient_array(l.output + index + locations*l.classes, locations*l.n*(1+l.coords),
+ LOGISTIC, l.delta + index + locations*l.classes);
}
}
- printf("Avg IOU: %f\n", avg_iou/count);
+ printf("Detection Avg IOU: %f, Pos Cat: %f, All Cat: %f, Pos Obj: %f, Any Obj: %f, count: %d\n", avg_iou/count, avg_cat/count, avg_allcat/(count*l.classes), avg_obj/count, avg_anyobj/(l.batch*locations*l.n), count);
}
}
void backward_detection_layer(const detection_layer l, network_state state)
{
- int locations = get_detection_layer_locations(l);
- int i,j;
- int in_i = 0;
- int out_i = 0;
- for(i = 0; i < l.batch*locations; ++i){
- float scale = 1;
- float latent_delta = 0;
- if(l.joint) scale = state.input[in_i++];
- else if (l.objectness) state.delta[in_i++] += -l.delta[out_i++];
- else if (l.background) state.delta[in_i++] += scale*l.delta[out_i++];
- for(j = 0; j < l.classes; ++j){
- latent_delta += state.input[in_i]*l.delta[out_i];
- state.delta[in_i++] += scale*l.delta[out_i++];
- }
-
- if (l.objectness) {
-
- }else if (l.background) gradient_array(l.output + out_i, l.coords, LOGISTIC, l.delta + out_i);
- for (j = 0; j < l.coords; ++j){
- state.delta[in_i++] += l.delta[out_i++];
- }
- if(l.joint) state.delta[in_i-l.coords-l.classes-l.joint] += latent_delta;
- }
+ axpy_cpu(l.batch*l.inputs, 1, l.delta, 1, state.delta, 1);
}
#ifdef GPU
void forward_detection_layer_gpu(const detection_layer l, network_state state)
{
- int outputs = get_detection_layer_output_size(l);
+ if(!state.train){
+ copy_ongpu(l.batch*l.inputs, state.input, 1, l.output_gpu, 1);
+ return;
+ }
+
float *in_cpu = calloc(l.batch*l.inputs, sizeof(float));
float *truth_cpu = 0;
if(state.truth){
- truth_cpu = calloc(l.batch*outputs, sizeof(float));
- cuda_pull_array(state.truth, truth_cpu, l.batch*outputs);
+ int num_truth = l.batch*l.side*l.side*(1+l.coords+l.classes);
+ truth_cpu = calloc(num_truth, sizeof(float));
+ cuda_pull_array(state.truth, truth_cpu, num_truth);
}
cuda_pull_array(state.input, in_cpu, l.batch*l.inputs);
network_state cpu_state;
@@ -176,38 +206,16 @@
cpu_state.truth = truth_cpu;
cpu_state.input = in_cpu;
forward_detection_layer(l, cpu_state);
- cuda_push_array(l.output_gpu, l.output, l.batch*outputs);
- cuda_push_array(l.delta_gpu, l.delta, l.batch*outputs);
+ cuda_push_array(l.output_gpu, l.output, l.batch*l.outputs);
+ cuda_push_array(l.delta_gpu, l.delta, l.batch*l.inputs);
free(cpu_state.input);
if(cpu_state.truth) free(cpu_state.truth);
}
void backward_detection_layer_gpu(detection_layer l, network_state state)
{
- int outputs = get_detection_layer_output_size(l);
-
- float *in_cpu = calloc(l.batch*l.inputs, sizeof(float));
- float *delta_cpu = calloc(l.batch*l.inputs, sizeof(float));
- float *truth_cpu = 0;
- if(state.truth){
- truth_cpu = calloc(l.batch*outputs, sizeof(float));
- cuda_pull_array(state.truth, truth_cpu, l.batch*outputs);
- }
- network_state cpu_state;
- cpu_state.train = state.train;
- cpu_state.input = in_cpu;
- cpu_state.truth = truth_cpu;
- cpu_state.delta = delta_cpu;
-
- cuda_pull_array(state.input, in_cpu, l.batch*l.inputs);
- cuda_pull_array(state.delta, delta_cpu, l.batch*l.inputs);
- cuda_pull_array(l.delta_gpu, l.delta, l.batch*outputs);
- backward_detection_layer(l, cpu_state);
- cuda_push_array(state.delta, delta_cpu, l.batch*l.inputs);
-
- if (truth_cpu) free(truth_cpu);
- free(in_cpu);
- free(delta_cpu);
+ axpy_ongpu(l.batch*l.inputs, 1, l.delta_gpu, 1, state.delta, 1);
+ //copy_ongpu(l.batch*l.inputs, l.delta_gpu, 1, state.delta, 1);
}
#endif
diff --git a/src/detection_layer.h b/src/detection_layer.h
index 38d96ee..5e34ac7 100644
--- a/src/detection_layer.h
+++ b/src/detection_layer.h
@@ -1,16 +1,14 @@
-#ifndef DETECTION_LAYER_H
-#define DETECTION_LAYER_H
+#ifndef REGION_LAYER_H
+#define REGION_LAYER_H
#include "params.h"
#include "layer.h"
typedef layer detection_layer;
-detection_layer make_detection_layer(int batch, int inputs, int classes, int coords, int joint, int rescore, int background, int objectness);
+detection_layer make_detection_layer(int batch, int inputs, int n, int size, int classes, int coords, int rescore);
void forward_detection_layer(const detection_layer l, network_state state);
void backward_detection_layer(const detection_layer l, network_state state);
-int get_detection_layer_output_size(detection_layer l);
-int get_detection_layer_locations(detection_layer l);
#ifdef GPU
void forward_detection_layer_gpu(const detection_layer l, network_state state);
diff --git a/src/layer.h b/src/layer.h
index 2b136a0..0137c8e 100644
--- a/src/layer.h
+++ b/src/layer.h
@@ -15,7 +15,6 @@
ROUTE,
COST,
NORMALIZATION,
- REGION,
AVGPOOL
} LAYER_TYPE;
@@ -30,9 +29,6 @@
int batch_normalize;
int batch;
int forced;
- int object_logistic;
- int class_logistic;
- int coord_logistic;
int inputs;
int outputs;
int truths;
diff --git a/src/network.c b/src/network.c
index 063a1bb..9bcb264 100644
--- a/src/network.c
+++ b/src/network.c
@@ -11,7 +11,6 @@
#include "convolutional_layer.h"
#include "deconvolutional_layer.h"
#include "detection_layer.h"
-#include "region_layer.h"
#include "normalization_layer.h"
#include "maxpool_layer.h"
#include "avgpool_layer.h"
@@ -72,8 +71,6 @@
return "softmax";
case DETECTION:
return "detection";
- case REGION:
- return "region";
case DROPOUT:
return "dropout";
case CROP:
@@ -119,8 +116,6 @@
forward_normalization_layer(l, state);
} else if(l.type == DETECTION){
forward_detection_layer(l, state);
- } else if(l.type == REGION){
- forward_region_layer(l, state);
} else if(l.type == CONNECTED){
forward_connected_layer(l, state);
} else if(l.type == CROP){
@@ -180,10 +175,6 @@
sum += net.layers[i].cost[0];
++count;
}
- if(net.layers[i].type == REGION){
- sum += net.layers[i].cost[0];
- ++count;
- }
}
return sum/count;
}
@@ -224,8 +215,6 @@
backward_dropout_layer(l, state);
} else if(l.type == DETECTION){
backward_detection_layer(l, state);
- } else if(l.type == REGION){
- backward_region_layer(l, state);
} else if(l.type == SOFTMAX){
if(i != 0) backward_softmax_layer(l, state);
} else if(l.type == CONNECTED){
diff --git a/src/network.h b/src/network.h
index 1caf838..0ad16ff 100644
--- a/src/network.h
+++ b/src/network.h
@@ -89,7 +89,6 @@
void set_batch_network(network *net, int b);
int get_network_input_size(network net);
float get_network_cost(network net);
-detection_layer get_network_detection_layer(network net);
int get_network_nuisance(network net);
int get_network_background(network net);
diff --git a/src/network_kernels.cu b/src/network_kernels.cu
index d2c8bf9..8561372 100644
--- a/src/network_kernels.cu
+++ b/src/network_kernels.cu
@@ -13,7 +13,6 @@
#include "crop_layer.h"
#include "connected_layer.h"
#include "detection_layer.h"
-#include "region_layer.h"
#include "convolutional_layer.h"
#include "deconvolutional_layer.h"
#include "maxpool_layer.h"
@@ -44,8 +43,6 @@
forward_deconvolutional_layer_gpu(l, state);
} else if(l.type == DETECTION){
forward_detection_layer_gpu(l, state);
- } else if(l.type == REGION){
- forward_region_layer_gpu(l, state);
} else if(l.type == CONNECTED){
forward_connected_layer_gpu(l, state);
} else if(l.type == CROP){
@@ -96,8 +93,6 @@
backward_dropout_layer_gpu(l, state);
} else if(l.type == DETECTION){
backward_detection_layer_gpu(l, state);
- } else if(l.type == REGION){
- backward_region_layer_gpu(l, state);
} else if(l.type == NORMALIZATION){
backward_normalization_layer_gpu(l, state);
} else if(l.type == SOFTMAX){
@@ -134,7 +129,7 @@
network_state state;
int x_size = get_network_input_size(net)*net.batch;
int y_size = get_network_output_size(net)*net.batch;
- if(net.layers[net.n-1].type == REGION) y_size = net.layers[net.n-1].truths*net.batch;
+ if(net.layers[net.n-1].type == DETECTION) y_size = net.layers[net.n-1].truths*net.batch;
if(!*net.input_gpu){
*net.input_gpu = cuda_make_array(x, x_size);
*net.truth_gpu = cuda_make_array(y, y_size);
diff --git a/src/old.c b/src/old.c
deleted file mode 100644
index 52b87fb..0000000
--- a/src/old.c
+++ /dev/null
@@ -1,607 +0,0 @@
-void save_network(network net, char *filename)
-{
- FILE *fp = fopen(filename, "w");
- if(!fp) file_error(filename);
- int i;
- for(i = 0; i < net.n; ++i)
- {
- if(net.types[i] == CONVOLUTIONAL)
- print_convolutional_cfg(fp, (convolutional_layer *)net.layers[i], net, i);
- else if(net.types[i] == DECONVOLUTIONAL)
- print_deconvolutional_cfg(fp, (deconvolutional_layer *)net.layers[i], net, i);
- else if(net.types[i] == CONNECTED)
- print_connected_cfg(fp, (connected_layer *)net.layers[i], net, i);
- else if(net.types[i] == CROP)
- print_crop_cfg(fp, (crop_layer *)net.layers[i], net, i);
- else if(net.types[i] == MAXPOOL)
- print_maxpool_cfg(fp, (maxpool_layer *)net.layers[i], net, i);
- else if(net.types[i] == DROPOUT)
- print_dropout_cfg(fp, (dropout_layer *)net.layers[i], net, i);
- else if(net.types[i] == SOFTMAX)
- print_softmax_cfg(fp, (softmax_layer *)net.layers[i], net, i);
- else if(net.types[i] == DETECTION)
- print_detection_cfg(fp, (detection_layer *)net.layers[i], net, i);
- else if(net.types[i] == COST)
- print_cost_cfg(fp, (cost_layer *)net.layers[i], net, i);
- }
- fclose(fp);
-}
-
-void print_convolutional_cfg(FILE *fp, convolutional_layer *l, network net, int count)
-{
-#ifdef GPU
- if(gpu_index >= 0) pull_convolutional_layer(*l);
-#endif
- int i;
- fprintf(fp, "[convolutional]\n");
- fprintf(fp, "filters=%d\n"
- "size=%d\n"
- "stride=%d\n"
- "pad=%d\n"
- "activation=%s\n",
- l->n, l->size, l->stride, l->pad,
- get_activation_string(l->activation));
- fprintf(fp, "biases=");
- for(i = 0; i < l->n; ++i) fprintf(fp, "%g,", l->biases[i]);
- fprintf(fp, "\n");
- fprintf(fp, "weights=");
- for(i = 0; i < l->n*l->c*l->size*l->size; ++i) fprintf(fp, "%g,", l->filters[i]);
- fprintf(fp, "\n\n");
-}
-
-void print_deconvolutional_cfg(FILE *fp, deconvolutional_layer *l, network net, int count)
-{
-#ifdef GPU
- if(gpu_index >= 0) pull_deconvolutional_layer(*l);
-#endif
- int i;
- fprintf(fp, "[deconvolutional]\n");
- fprintf(fp, "filters=%d\n"
- "size=%d\n"
- "stride=%d\n"
- "activation=%s\n",
- l->n, l->size, l->stride,
- get_activation_string(l->activation));
- fprintf(fp, "biases=");
- for(i = 0; i < l->n; ++i) fprintf(fp, "%g,", l->biases[i]);
- fprintf(fp, "\n");
- fprintf(fp, "weights=");
- for(i = 0; i < l->n*l->c*l->size*l->size; ++i) fprintf(fp, "%g,", l->filters[i]);
- fprintf(fp, "\n\n");
-}
-
-void print_dropout_cfg(FILE *fp, dropout_layer *l, network net, int count)
-{
- fprintf(fp, "[dropout]\n");
- fprintf(fp, "probability=%g\n\n", l->probability);
-}
-
-void print_connected_cfg(FILE *fp, connected_layer *l, network net, int count)
-{
-#ifdef GPU
- if(gpu_index >= 0) pull_connected_layer(*l);
-#endif
- int i;
- fprintf(fp, "[connected]\n");
- fprintf(fp, "output=%d\n"
- "activation=%s\n",
- l->outputs,
- get_activation_string(l->activation));
- fprintf(fp, "biases=");
- for(i = 0; i < l->outputs; ++i) fprintf(fp, "%g,", l->biases[i]);
- fprintf(fp, "\n");
- fprintf(fp, "weights=");
- for(i = 0; i < l->outputs*l->inputs; ++i) fprintf(fp, "%g,", l->weights[i]);
- fprintf(fp, "\n\n");
-}
-
-void print_crop_cfg(FILE *fp, crop_layer *l, network net, int count)
-{
- fprintf(fp, "[crop]\n");
- fprintf(fp, "crop_height=%d\ncrop_width=%d\nflip=%d\n\n", l->crop_height, l->crop_width, l->flip);
-}
-
-void print_maxpool_cfg(FILE *fp, maxpool_layer *l, network net, int count)
-{
- fprintf(fp, "[maxpool]\n");
- fprintf(fp, "size=%d\nstride=%d\n\n", l->size, l->stride);
-}
-
-void print_softmax_cfg(FILE *fp, softmax_layer *l, network net, int count)
-{
- fprintf(fp, "[softmax]\n");
- fprintf(fp, "\n");
-}
-
-void print_detection_cfg(FILE *fp, detection_layer *l, network net, int count)
-{
- fprintf(fp, "[detection]\n");
- fprintf(fp, "classes=%d\ncoords=%d\nrescore=%d\nnuisance=%d\n", l->classes, l->coords, l->rescore, l->nuisance);
- fprintf(fp, "\n");
-}
-
-void print_cost_cfg(FILE *fp, cost_layer *l, network net, int count)
-{
- fprintf(fp, "[cost]\ntype=%s\n", get_cost_string(l->type));
- fprintf(fp, "\n");
-}
-
-
-#ifndef NORMALIZATION_LAYER_H
-#define NORMALIZATION_LAYER_H
-
-#include "image.h"
-#include "params.h"
-
-typedef struct {
- int batch;
- int h,w,c;
- int size;
- float alpha;
- float beta;
- float kappa;
- float *delta;
- float *output;
- float *sums;
-} normalization_layer;
-
-image get_normalization_image(normalization_layer layer);
-normalization_layer *make_normalization_layer(int batch, int h, int w, int c, int size, float alpha, float beta, float kappa);
-void resize_normalization_layer(normalization_layer *layer, int h, int w);
-void forward_normalization_layer(const normalization_layer layer, network_state state);
-void backward_normalization_layer(const normalization_layer layer, network_state state);
-void visualize_normalization_layer(normalization_layer layer, char *window);
-
-#endif
-#include "normalization_layer.h"
-#include <stdio.h>
-
-image get_normalization_image(normalization_layer layer)
-{
- int h = layer.h;
- int w = layer.w;
- int c = layer.c;
- return float_to_image(w,h,c,layer.output);
-}
-
-image get_normalization_delta(normalization_layer layer)
-{
- int h = layer.h;
- int w = layer.w;
- int c = layer.c;
- return float_to_image(w,h,c,layer.delta);
-}
-
-normalization_layer *make_normalization_layer(int batch, int h, int w, int c, int size, float alpha, float beta, float kappa)
-{
- fprintf(stderr, "Local Response Normalization Layer: %d x %d x %d image, %d size\n", h,w,c,size);
- normalization_layer *layer = calloc(1, sizeof(normalization_layer));
- layer->batch = batch;
- layer->h = h;
- layer->w = w;
- layer->c = c;
- layer->kappa = kappa;
- layer->size = size;
- layer->alpha = alpha;
- layer->beta = beta;
- layer->output = calloc(h * w * c * batch, sizeof(float));
- layer->delta = calloc(h * w * c * batch, sizeof(float));
- layer->sums = calloc(h*w, sizeof(float));
- return layer;
-}
-
-void resize_normalization_layer(normalization_layer *layer, int h, int w)
-{
- layer->h = h;
- layer->w = w;
- layer->output = realloc(layer->output, h * w * layer->c * layer->batch * sizeof(float));
- layer->delta = realloc(layer->delta, h * w * layer->c * layer->batch * sizeof(float));
- layer->sums = realloc(layer->sums, h*w * sizeof(float));
-}
-
-void add_square_array(float *src, float *dest, int n)
-{
- int i;
- for(i = 0; i < n; ++i){
- dest[i] += src[i]*src[i];
- }
-}
-void sub_square_array(float *src, float *dest, int n)
-{
- int i;
- for(i = 0; i < n; ++i){
- dest[i] -= src[i]*src[i];
- }
-}
-
-void forward_normalization_layer(const normalization_layer layer, network_state state)
-{
- int i,j,k;
- memset(layer.sums, 0, layer.h*layer.w*sizeof(float));
- int imsize = layer.h*layer.w;
- for(j = 0; j < layer.size/2; ++j){
- if(j < layer.c) add_square_array(state.input+j*imsize, layer.sums, imsize);
- }
- for(k = 0; k < layer.c; ++k){
- int next = k+layer.size/2;
- int prev = k-layer.size/2-1;
- if(next < layer.c) add_square_array(state.input+next*imsize, layer.sums, imsize);
- if(prev > 0) sub_square_array(state.input+prev*imsize, layer.sums, imsize);
- for(i = 0; i < imsize; ++i){
- layer.output[k*imsize + i] = state.input[k*imsize+i] / pow(layer.kappa + layer.alpha * layer.sums[i], layer.beta);
- }
- }
-}
-
-void backward_normalization_layer(const normalization_layer layer, network_state state)
-{
- // TODO!
- // OR NOT TODO!!
-}
-
-void visualize_normalization_layer(normalization_layer layer, char *window)
-{
- image delta = get_normalization_image(layer);
- image dc = collapse_image_layers(delta, 1);
- char buff[256];
- sprintf(buff, "%s: Output", window);
- show_image(dc, buff);
- save_image(dc, buff);
- free_image(dc);
-}
-
-void test_load()
-{
- image dog = load_image("dog.jpg", 300, 400);
- show_image(dog, "Test Load");
- show_image_layers(dog, "Test Load");
-}
-
-void test_parser()
-{
- network net = parse_network_cfg("cfg/trained_imagenet.cfg");
- save_network(net, "cfg/trained_imagenet_smaller.cfg");
-}
-
-void test_init(char *cfgfile)
-{
- gpu_index = -1;
- network net = parse_network_cfg(cfgfile);
- set_batch_network(&net, 1);
- srand(2222222);
- int i = 0;
- char *filename = "data/test.jpg";
-
- image im = load_image_color(filename, 256, 256);
- //z_normalize_image(im);
- translate_image(im, -128);
- scale_image(im, 1/128.);
- float *X = im.data;
- forward_network(net, X, 0, 1);
- for(i = 0; i < net.n; ++i){
- if(net.types[i] == CONVOLUTIONAL){
- convolutional_layer layer = *(convolutional_layer *)net.layers[i];
- image output = get_convolutional_image(layer);
- int size = output.h*output.w*output.c;
- float v = variance_array(layer.output, size);
- float m = mean_array(layer.output, size);
- printf("%d: Convolutional, mean: %f, variance %f\n", i, m, v);
- }
- else if(net.types[i] == CONNECTED){
- connected_layer layer = *(connected_layer *)net.layers[i];
- int size = layer.outputs;
- float v = variance_array(layer.output, size);
- float m = mean_array(layer.output, size);
- printf("%d: Connected, mean: %f, variance %f\n", i, m, v);
- }
- }
- free_image(im);
-}
-void test_dog(char *cfgfile)
-{
- image im = load_image_color("data/dog.jpg", 256, 256);
- translate_image(im, -128);
- print_image(im);
- float *X = im.data;
- network net = parse_network_cfg(cfgfile);
- set_batch_network(&net, 1);
- network_predict(net, X);
- image crop = get_network_image_layer(net, 0);
- show_image(crop, "cropped");
- print_image(crop);
- show_image(im, "orig");
- float * inter = get_network_output(net);
- pm(1000, 1, inter);
- cvWaitKey(0);
-}
-
-void test_voc_segment(char *cfgfile, char *weightfile)
-{
- network net = parse_network_cfg(cfgfile);
- if(weightfile){
- load_weights(&net, weightfile);
- }
- set_batch_network(&net, 1);
- while(1){
- char filename[256];
- fgets(filename, 256, stdin);
- strtok(filename, "\n");
- image im = load_image_color(filename, 500, 500);
- //resize_network(net, im.h, im.w, im.c);
- translate_image(im, -128);
- scale_image(im, 1/128.);
- //float *predictions = network_predict(net, im.data);
- network_predict(net, im.data);
- free_image(im);
- image output = get_network_image_layer(net, net.n-2);
- show_image(output, "Segment Output");
- cvWaitKey(0);
- }
-}
-void test_visualize(char *filename)
-{
- network net = parse_network_cfg(filename);
- visualize_network(net);
- cvWaitKey(0);
-}
-
-void test_cifar10(char *cfgfile)
-{
- network net = parse_network_cfg(cfgfile);
- data test = load_cifar10_data("data/cifar10/test_batch.bin");
- clock_t start = clock(), end;
- float test_acc = network_accuracy_multi(net, test, 10);
- end = clock();
- printf("%f in %f Sec\n", test_acc, sec(end-start));
- //visualize_network(net);
- //cvWaitKey(0);
-}
-
-void train_cifar10(char *cfgfile)
-{
- srand(555555);
- srand(time(0));
- network net = parse_network_cfg(cfgfile);
- data test = load_cifar10_data("data/cifar10/test_batch.bin");
- int count = 0;
- int iters = 50000/net.batch;
- data train = load_all_cifar10();
- while(++count <= 10000){
- clock_t time = clock();
- float loss = train_network_sgd(net, train, iters);
-
- if(count%10 == 0){
- float test_acc = network_accuracy(net, test);
- printf("%d: Loss: %f, Test Acc: %f, Time: %lf seconds\n", count, loss, test_acc,sec(clock()-time));
- char buff[256];
- sprintf(buff, "/home/pjreddie/imagenet_backup/cifar10_%d.cfg", count);
- save_network(net, buff);
- }else{
- printf("%d: Loss: %f, Time: %lf seconds\n", count, loss, sec(clock()-time));
- }
-
- }
- free_data(train);
-}
-
-void compare_nist(char *p1,char *p2)
-{
- srand(222222);
- network n1 = parse_network_cfg(p1);
- network n2 = parse_network_cfg(p2);
- data test = load_categorical_data_csv("data/mnist/mnist_test.csv",0,10);
- normalize_data_rows(test);
- compare_networks(n1, n2, test);
-}
-
-void test_nist(char *path)
-{
- srand(222222);
- network net = parse_network_cfg(path);
- data test = load_categorical_data_csv("data/mnist/mnist_test.csv",0,10);
- normalize_data_rows(test);
- clock_t start = clock(), end;
- float test_acc = network_accuracy(net, test);
- end = clock();
- printf("Accuracy: %f, Time: %lf seconds\n", test_acc,(float)(end-start)/CLOCKS_PER_SEC);
-}
-
-void train_nist(char *cfgfile)
-{
- srand(222222);
- // srand(time(0));
- data train = load_categorical_data_csv("data/mnist/mnist_train.csv", 0, 10);
- data test = load_categorical_data_csv("data/mnist/mnist_test.csv",0,10);
- network net = parse_network_cfg(cfgfile);
- int count = 0;
- int iters = 6000/net.batch + 1;
- while(++count <= 100){
- clock_t start = clock(), end;
- normalize_data_rows(train);
- normalize_data_rows(test);
- float loss = train_network_sgd(net, train, iters);
- float test_acc = 0;
- if(count%1 == 0) test_acc = network_accuracy(net, test);
- end = clock();
- printf("%d: Loss: %f, Test Acc: %f, Time: %lf seconds\n", count, loss, test_acc,(float)(end-start)/CLOCKS_PER_SEC);
- }
- free_data(train);
- free_data(test);
- char buff[256];
- sprintf(buff, "%s.trained", cfgfile);
- save_network(net, buff);
-}
-
-/*
- void train_nist_distributed(char *address)
- {
- srand(time(0));
- network net = parse_network_cfg("cfg/nist.client");
- data train = load_categorical_data_csv("data/mnist/mnist_train.csv", 0, 10);
-//data test = load_categorical_data_csv("data/mnist/mnist_test.csv",0,10);
-normalize_data_rows(train);
-//normalize_data_rows(test);
-int count = 0;
-int iters = 50000/net.batch;
-iters = 1000/net.batch + 1;
-while(++count <= 2000){
-clock_t start = clock(), end;
-float loss = train_network_sgd(net, train, iters);
-client_update(net, address);
-end = clock();
-//float test_acc = network_accuracy_gpu(net, test);
-//float test_acc = 0;
-printf("%d: Loss: %f, Time: %lf seconds\n", count, loss, (float)(end-start)/CLOCKS_PER_SEC);
-}
-}
- */
-
-void test_ensemble()
-{
- int i;
- srand(888888);
- data d = load_categorical_data_csv("mnist/mnist_train.csv", 0, 10);
- normalize_data_rows(d);
- data test = load_categorical_data_csv("mnist/mnist_test.csv", 0,10);
- normalize_data_rows(test);
- data train = d;
- // data *split = split_data(d, 1, 10);
- // data train = split[0];
- // data test = split[1];
- matrix prediction = make_matrix(test.y.rows, test.y.cols);
- int n = 30;
- for(i = 0; i < n; ++i){
- int count = 0;
- float lr = .0005;
- float momentum = .9;
- float decay = .01;
- network net = parse_network_cfg("nist.cfg");
- while(++count <= 15){
- float acc = train_network_sgd(net, train, train.X.rows);
- printf("Training Accuracy: %lf Learning Rate: %f Momentum: %f Decay: %f\n", acc, lr, momentum, decay );
- lr /= 2;
- }
- matrix partial = network_predict_data(net, test);
- float acc = matrix_topk_accuracy(test.y, partial,1);
- printf("Model Accuracy: %lf\n", acc);
- matrix_add_matrix(partial, prediction);
- acc = matrix_topk_accuracy(test.y, prediction,1);
- printf("Current Ensemble Accuracy: %lf\n", acc);
- free_matrix(partial);
- }
- float acc = matrix_topk_accuracy(test.y, prediction,1);
- printf("Full Ensemble Accuracy: %lf\n", acc);
-}
-
-void visualize_cat()
-{
- network net = parse_network_cfg("cfg/voc_imagenet.cfg");
- image im = load_image_color("data/cat.png", 0, 0);
- printf("Processing %dx%d image\n", im.h, im.w);
- resize_network(net, im.h, im.w, im.c);
- forward_network(net, im.data, 0, 0);
-
- visualize_network(net);
- cvWaitKey(0);
-}
-
-void test_correct_nist()
-{
- network net = parse_network_cfg("cfg/nist_conv.cfg");
- srand(222222);
- net = parse_network_cfg("cfg/nist_conv.cfg");
- data train = load_categorical_data_csv("data/mnist/mnist_train.csv", 0, 10);
- data test = load_categorical_data_csv("data/mnist/mnist_test.csv",0,10);
- normalize_data_rows(train);
- normalize_data_rows(test);
- int count = 0;
- int iters = 1000/net.batch;
-
- while(++count <= 5){
- clock_t start = clock(), end;
- float loss = train_network_sgd(net, train, iters);
- end = clock();
- float test_acc = network_accuracy(net, test);
- printf("%d: Loss: %f, Test Acc: %f, Time: %lf seconds, LR: %f, Momentum: %f, Decay: %f\n", count, loss, test_acc,(float)(end-start)/CLOCKS_PER_SEC, net.learning_rate, net.momentum, net.decay);
- }
- save_network(net, "cfg/nist_gpu.cfg");
-
- gpu_index = -1;
- count = 0;
- srand(222222);
- net = parse_network_cfg("cfg/nist_conv.cfg");
- while(++count <= 5){
- clock_t start = clock(), end;
- float loss = train_network_sgd(net, train, iters);
- end = clock();
- float test_acc = network_accuracy(net, test);
- printf("%d: Loss: %f, Test Acc: %f, Time: %lf seconds, LR: %f, Momentum: %f, Decay: %f\n", count, loss, test_acc,(float)(end-start)/CLOCKS_PER_SEC, net.learning_rate, net.momentum, net.decay);
- }
- save_network(net, "cfg/nist_cpu.cfg");
-}
-
-void test_correct_alexnet()
-{
- char **labels = get_labels("/home/pjreddie/data/imagenet/cls.labels.list");
- list *plist = get_paths("/data/imagenet/cls.train.list");
- char **paths = (char **)list_to_array(plist);
- printf("%d\n", plist->size);
- clock_t time;
- int count = 0;
- network net;
-
- srand(222222);
- net = parse_network_cfg("cfg/net.cfg");
- int imgs = net.batch;
-
- count = 0;
- while(++count <= 5){
- time=clock();
- data train = load_data(paths, imgs, plist->size, labels, 1000, 256, 256);
- normalize_data_rows(train);
- printf("Loaded: %lf seconds\n", sec(clock()-time));
- time=clock();
- float loss = train_network(net, train);
- printf("%d: %f, %lf seconds, %d images\n", count, loss, sec(clock()-time), imgs*net.batch);
- free_data(train);
- }
-
- gpu_index = -1;
- count = 0;
- srand(222222);
- net = parse_network_cfg("cfg/net.cfg");
- printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
- while(++count <= 5){
- time=clock();
- data train = load_data(paths, imgs, plist->size, labels, 1000, 256,256);
- normalize_data_rows(train);
- printf("Loaded: %lf seconds\n", sec(clock()-time));
- time=clock();
- float loss = train_network(net, train);
- printf("%d: %f, %lf seconds, %d images\n", count, loss, sec(clock()-time), imgs*net.batch);
- free_data(train);
- }
-}
-
-/*
- void run_server()
- {
- srand(time(0));
- network net = parse_network_cfg("cfg/net.cfg");
- set_batch_network(&net, 1);
- server_update(net);
- }
-
- void test_client()
- {
- network net = parse_network_cfg("cfg/alexnet.client");
- clock_t time=clock();
- client_update(net, "localhost");
- printf("1\n");
- client_update(net, "localhost");
- printf("2\n");
- client_update(net, "localhost");
- printf("3\n");
- printf("Transfered: %lf seconds\n", sec(clock()-time));
- }
- */
diff --git a/src/parser.c b/src/parser.c
index 254da5c..b095294 100644
--- a/src/parser.c
+++ b/src/parser.c
@@ -14,7 +14,6 @@
#include "softmax_layer.h"
#include "dropout_layer.h"
#include "detection_layer.h"
-#include "region_layer.h"
#include "avgpool_layer.h"
#include "route_layer.h"
#include "list.h"
@@ -38,7 +37,6 @@
int is_crop(section *s);
int is_cost(section *s);
int is_detection(section *s);
-int is_region(section *s);
int is_route(section *s);
list *read_cfg(char *filename);
@@ -168,35 +166,19 @@
int coords = option_find_int(options, "coords", 1);
int classes = option_find_int(options, "classes", 1);
int rescore = option_find_int(options, "rescore", 0);
- int joint = option_find_int(options, "joint", 0);
- int objectness = option_find_int(options, "objectness", 0);
- int background = option_find_int(options, "background", 0);
- detection_layer layer = make_detection_layer(params.batch, params.inputs, classes, coords, joint, rescore, background, objectness);
- return layer;
-}
-
-region_layer parse_region(list *options, size_params params)
-{
- int coords = option_find_int(options, "coords", 1);
- int classes = option_find_int(options, "classes", 1);
- int rescore = option_find_int(options, "rescore", 0);
int num = option_find_int(options, "num", 1);
int side = option_find_int(options, "side", 7);
- region_layer layer = make_region_layer(params.batch, params.inputs, num, side, classes, coords, rescore);
+ detection_layer layer = make_detection_layer(params.batch, params.inputs, num, side, classes, coords, rescore);
layer.softmax = option_find_int(options, "softmax", 0);
layer.sqrt = option_find_int(options, "sqrt", 0);
- layer.object_logistic = option_find_int(options, "object_logistic", 0);
- layer.class_logistic = option_find_int(options, "class_logistic", 0);
- layer.coord_logistic = option_find_int(options, "coord_logistic", 0);
-
layer.coord_scale = option_find_float(options, "coord_scale", 1);
layer.forced = option_find_int(options, "forced", 0);
layer.object_scale = option_find_float(options, "object_scale", 1);
layer.noobject_scale = option_find_float(options, "noobject_scale", 1);
layer.class_scale = option_find_float(options, "class_scale", 1);
- layer.jitter = option_find_float(options, "jitter", .1);
+ layer.jitter = option_find_float(options, "jitter", .2);
return layer;
}
@@ -430,8 +412,6 @@
l = parse_cost(options, params);
}else if(is_detection(s)){
l = parse_detection(options, params);
- }else if(is_region(s)){
- l = parse_region(options, params);
}else if(is_softmax(s)){
l = parse_softmax(options, params);
}else if(is_normalization(s)){
@@ -485,10 +465,6 @@
{
return (strcmp(s->type, "[detection]")==0);
}
-int is_region(section *s)
-{
- return (strcmp(s->type, "[region]")==0);
-}
int is_deconvolutional(section *s)
{
return (strcmp(s->type, "[deconv]")==0
diff --git a/src/region_layer.c b/src/region_layer.c
deleted file mode 100644
index 3fff22b..0000000
--- a/src/region_layer.c
+++ /dev/null
@@ -1,259 +0,0 @@
-#include "region_layer.h"
-#include "activations.h"
-#include "softmax_layer.h"
-#include "blas.h"
-#include "box.h"
-#include "cuda.h"
-#include "utils.h"
-#include <stdio.h>
-#include <assert.h>
-#include <string.h>
-#include <stdlib.h>
-
-region_layer make_region_layer(int batch, int inputs, int n, int side, int classes, int coords, int rescore)
-{
- region_layer l = {0};
- l.type = REGION;
-
- l.n = n;
- l.batch = batch;
- l.inputs = inputs;
- l.classes = classes;
- l.coords = coords;
- l.rescore = rescore;
- l.side = side;
- assert(side*side*((1 + l.coords)*l.n + l.classes) == inputs);
- l.cost = calloc(1, sizeof(float));
- l.outputs = l.inputs;
- l.truths = l.side*l.side*(1+l.coords+l.classes);
- l.output = calloc(batch*l.outputs, sizeof(float));
- l.delta = calloc(batch*l.outputs, sizeof(float));
-#ifdef GPU
- l.output_gpu = cuda_make_array(l.output, batch*l.outputs);
- l.delta_gpu = cuda_make_array(l.delta, batch*l.outputs);
-#endif
-
- fprintf(stderr, "Region Layer\n");
- srand(0);
-
- return l;
-}
-
-void forward_region_layer(const region_layer l, network_state state)
-{
- int locations = l.side*l.side;
- int i,j;
- memcpy(l.output, state.input, l.outputs*l.batch*sizeof(float));
- int b;
- if (l.softmax){
- for(b = 0; b < l.batch; ++b){
- int index = b*l.inputs;
- for (i = 0; i < locations; ++i) {
- int offset = i*l.classes;
- softmax_array(l.output + index + offset, l.classes,
- l.output + index + offset);
- }
- int offset = locations*l.classes;
- activate_array(l.output + index + offset, locations*l.n*(1+l.coords), LOGISTIC);
- }
- }
- if (l.object_logistic) {
- for(b = 0; b < l.batch; ++b){
- int index = b*l.inputs;
- int p_index = index + locations*l.classes;
- activate_array(l.output + p_index, locations*l.n, LOGISTIC);
- }
- }
-
- if (l.coord_logistic) {
- for(b = 0; b < l.batch; ++b){
- int index = b*l.inputs;
- int coord_index = index + locations*(l.classes + l.n);
- activate_array(l.output + coord_index, locations*l.n*l.coords, LOGISTIC);
- }
- }
-
- if (l.class_logistic) {
- for(b = 0; b < l.batch; ++b){
- int class_index = b*l.inputs;
- activate_array(l.output + class_index, locations*l.classes, LOGISTIC);
- }
- }
-
- if(state.train){
- float avg_iou = 0;
- float avg_cat = 0;
- float avg_allcat = 0;
- float avg_obj = 0;
- float avg_anyobj = 0;
- int count = 0;
- *(l.cost) = 0;
- int size = l.inputs * l.batch;
- memset(l.delta, 0, size * sizeof(float));
- for (b = 0; b < l.batch; ++b){
- int index = b*l.inputs;
- for (i = 0; i < locations; ++i) {
- int truth_index = (b*locations + i)*(1+l.coords+l.classes);
- int is_obj = state.truth[truth_index];
- for (j = 0; j < l.n; ++j) {
- int p_index = index + locations*l.classes + i*l.n + j;
- l.delta[p_index] = l.noobject_scale*(0 - l.output[p_index]);
- *(l.cost) += l.noobject_scale*pow(l.output[p_index], 2);
- avg_anyobj += l.output[p_index];
- }
-
- int best_index = -1;
- float best_iou = 0;
- float best_rmse = 20;
-
- if (!is_obj){
- continue;
- }
-
- int class_index = index + i*l.classes;
- for(j = 0; j < l.classes; ++j) {
- l.delta[class_index+j] = l.class_scale * (state.truth[truth_index+1+j] - l.output[class_index+j]);
- *(l.cost) += l.class_scale * pow(state.truth[truth_index+1+j] - l.output[class_index+j], 2);
- if(state.truth[truth_index + 1 + j]) avg_cat += l.output[class_index+j];
- avg_allcat += l.output[class_index+j];
- }
-
- box truth = float_to_box(state.truth + truth_index + 1 + l.classes);
- truth.x /= l.side;
- truth.y /= l.side;
-
- for(j = 0; j < l.n; ++j){
- int box_index = index + locations*(l.classes + l.n) + (i*l.n + j) * l.coords;
- box out = float_to_box(l.output + box_index);
- out.x /= l.side;
- out.y /= l.side;
-
- if (l.sqrt){
- out.w = out.w*out.w;
- out.h = out.h*out.h;
- }
-
- float iou = box_iou(out, truth);
- //iou = 0;
- float rmse = box_rmse(out, truth);
- if(best_iou > 0 || iou > 0){
- if(iou > best_iou){
- best_iou = iou;
- best_index = j;
- }
- }else{
- if(rmse < best_rmse){
- best_rmse = rmse;
- best_index = j;
- }
- }
- }
-
- if(l.forced){
- if(truth.w*truth.h < .1){
- best_index = 1;
- }else{
- best_index = 0;
- }
- }
-
- int box_index = index + locations*(l.classes + l.n) + (i*l.n + best_index) * l.coords;
- int tbox_index = truth_index + 1 + l.classes;
-
- box out = float_to_box(l.output + box_index);
- out.x /= l.side;
- out.y /= l.side;
- if (l.sqrt) {
- out.w = out.w*out.w;
- out.h = out.h*out.h;
- }
- float iou = box_iou(out, truth);
-
- //printf("%d", best_index);
- int p_index = index + locations*l.classes + i*l.n + best_index;
- *(l.cost) -= l.noobject_scale * pow(l.output[p_index], 2);
- *(l.cost) += l.object_scale * pow(1-l.output[p_index], 2);
- avg_obj += l.output[p_index];
- l.delta[p_index] = l.object_scale * (1.-l.output[p_index]);
-
- if(l.rescore){
- l.delta[p_index] = l.object_scale * (iou - l.output[p_index]);
- }
-
- l.delta[box_index+0] = l.coord_scale*(state.truth[tbox_index + 0] - l.output[box_index + 0]);
- l.delta[box_index+1] = l.coord_scale*(state.truth[tbox_index + 1] - l.output[box_index + 1]);
- l.delta[box_index+2] = l.coord_scale*(state.truth[tbox_index + 2] - l.output[box_index + 2]);
- l.delta[box_index+3] = l.coord_scale*(state.truth[tbox_index + 3] - l.output[box_index + 3]);
- if(l.sqrt){
- l.delta[box_index+2] = l.coord_scale*(sqrt(state.truth[tbox_index + 2]) - l.output[box_index + 2]);
- l.delta[box_index+3] = l.coord_scale*(sqrt(state.truth[tbox_index + 3]) - l.output[box_index + 3]);
- }
-
- *(l.cost) += pow(1-iou, 2);
- avg_iou += iou;
- ++count;
- }
- if(l.softmax){
- gradient_array(l.output + index + locations*l.classes, locations*l.n*(1+l.coords),
- LOGISTIC, l.delta + index + locations*l.classes);
- }
- if (l.object_logistic) {
- int p_index = index + locations*l.classes;
- gradient_array(l.output + p_index, locations*l.n, LOGISTIC, l.delta + p_index);
- }
-
- if (l.class_logistic) {
- int class_index = index;
- gradient_array(l.output + class_index, locations*l.classes, LOGISTIC, l.delta + class_index);
- }
-
- if (l.coord_logistic) {
- int coord_index = index + locations*(l.classes + l.n);
- gradient_array(l.output + coord_index, locations*l.n*l.coords, LOGISTIC, l.delta + coord_index);
- }
- //printf("\n");
- }
- printf("Region Avg IOU: %f, Pos Cat: %f, All Cat: %f, Pos Obj: %f, Any Obj: %f, count: %d\n", avg_iou/count, avg_cat/count, avg_allcat/(count*l.classes), avg_obj/count, avg_anyobj/(l.batch*locations*l.n), count);
- }
-}
-
-void backward_region_layer(const region_layer l, network_state state)
-{
- axpy_cpu(l.batch*l.inputs, 1, l.delta, 1, state.delta, 1);
-}
-
-#ifdef GPU
-
-void forward_region_layer_gpu(const region_layer l, network_state state)
-{
- if(!state.train){
- copy_ongpu(l.batch*l.inputs, state.input, 1, l.output_gpu, 1);
- return;
- }
-
- float *in_cpu = calloc(l.batch*l.inputs, sizeof(float));
- float *truth_cpu = 0;
- if(state.truth){
- int num_truth = l.batch*l.side*l.side*(1+l.coords+l.classes);
- truth_cpu = calloc(num_truth, sizeof(float));
- cuda_pull_array(state.truth, truth_cpu, num_truth);
- }
- cuda_pull_array(state.input, in_cpu, l.batch*l.inputs);
- network_state cpu_state;
- cpu_state.train = state.train;
- cpu_state.truth = truth_cpu;
- cpu_state.input = in_cpu;
- forward_region_layer(l, cpu_state);
- cuda_push_array(l.output_gpu, l.output, l.batch*l.outputs);
- cuda_push_array(l.delta_gpu, l.delta, l.batch*l.inputs);
- free(cpu_state.input);
- if(cpu_state.truth) free(cpu_state.truth);
-}
-
-void backward_region_layer_gpu(region_layer l, network_state state)
-{
- axpy_ongpu(l.batch*l.inputs, 1, l.delta_gpu, 1, state.delta, 1);
- //copy_ongpu(l.batch*l.inputs, l.delta_gpu, 1, state.delta, 1);
-}
-#endif
-
diff --git a/src/region_layer.h b/src/region_layer.h
deleted file mode 100644
index 95f8e91..0000000
--- a/src/region_layer.h
+++ /dev/null
@@ -1,18 +0,0 @@
-#ifndef REGION_LAYER_H
-#define REGION_LAYER_H
-
-#include "params.h"
-#include "layer.h"
-
-typedef layer region_layer;
-
-region_layer make_region_layer(int batch, int inputs, int n, int size, int classes, int coords, int rescore);
-void forward_region_layer(const region_layer l, network_state state);
-void backward_region_layer(const region_layer l, network_state state);
-
-#ifdef GPU
-void forward_region_layer_gpu(const region_layer l, network_state state);
-void backward_region_layer_gpu(region_layer l, network_state state);
-#endif
-
-#endif
diff --git a/src/swag.c b/src/swag.c
deleted file mode 100644
index 4dc6bf9..0000000
--- a/src/swag.c
+++ /dev/null
@@ -1,460 +0,0 @@
-#include "network.h"
-#include "region_layer.h"
-#include "detection_layer.h"
-#include "cost_layer.h"
-#include "utils.h"
-#include "parser.h"
-#include "box.h"
-
-#ifdef OPENCV
-#include "opencv2/highgui/highgui_c.h"
-#endif
-
-char *voc_names[] = {"aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"};
-
-void draw_swag(image im, int num, float thresh, box *boxes, float **probs, char *label)
-{
- int classes = 20;
- int i;
-
- for(i = 0; i < num; ++i){
- int class = max_index(probs[i], classes);
- float prob = probs[i][class];
- if(prob > thresh){
- int width = pow(prob, 1./3.)*10 + 1;
- printf("%f %s\n", prob, voc_names[class]);
- float red = get_color(0,class,classes);
- float green = get_color(1,class,classes);
- float blue = get_color(2,class,classes);
- //red = green = blue = 0;
- box b = boxes[i];
-
- int left = (b.x-b.w/2.)*im.w;
- int right = (b.x+b.w/2.)*im.w;
- int top = (b.y-b.h/2.)*im.h;
- int bot = (b.y+b.h/2.)*im.h;
- draw_box_width(im, left, top, right, bot, width, red, green, blue);
- }
- }
- show_image(im, label);
-}
-
-void train_swag(char *cfgfile, char *weightfile)
-{
- //char *train_images = "/home/pjreddie/data/voc/person_detection/2010_person.txt";
- //char *train_images = "/home/pjreddie/data/people-art/train.txt";
- //char *train_images = "/home/pjreddie/data/voc/test/2012_trainval.txt";
- char *train_images = "/home/pjreddie/data/voc/test/train.txt";
- //char *train_images = "/home/pjreddie/data/voc/test/train_all.txt";
- //char *train_images = "/home/pjreddie/data/voc/test/2007_trainval.txt";
- char *backup_directory = "/home/pjreddie/backup/";
- srand(time(0));
- data_seed = time(0);
- char *base = basecfg(cfgfile);
- printf("%s\n", base);
- float avg_loss = -1;
- network net = parse_network_cfg(cfgfile);
- if(weightfile){
- load_weights(&net, weightfile);
- }
- printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
- int imgs = net.batch*net.subdivisions;
- int i = *net.seen/imgs;
- data train, buffer;
-
-
- layer l = net.layers[net.n - 1];
-
- int side = l.side;
- int classes = l.classes;
- float jitter = l.jitter;
-
- list *plist = get_paths(train_images);
- //int N = plist->size;
- char **paths = (char **)list_to_array(plist);
-
- load_args args = {0};
- args.w = net.w;
- args.h = net.h;
- args.paths = paths;
- args.n = imgs;
- args.m = plist->size;
- args.classes = classes;
- args.jitter = jitter;
- args.num_boxes = side;
- args.d = &buffer;
- args.type = REGION_DATA;
-
- pthread_t load_thread = load_data_in_thread(args);
- clock_t time;
- //while(i*imgs < N*120){
- while(get_current_batch(net) < net.max_batches){
- i += 1;
- time=clock();
- pthread_join(load_thread, 0);
- train = buffer;
- load_thread = load_data_in_thread(args);
-
- printf("Loaded: %lf seconds\n", sec(clock()-time));
-
- /*
- image im = float_to_image(net.w, net.h, 3, train.X.vals[113]);
- image copy = copy_image(im);
- draw_swag(copy, train.y.vals[113], 7, "truth");
- cvWaitKey(0);
- free_image(copy);
- */
-
- time=clock();
- float loss = train_network(net, train);
- if (avg_loss < 0) avg_loss = loss;
- avg_loss = avg_loss*.9 + loss*.1;
-
- printf("%d: %f, %f avg, %f rate, %lf seconds, %d images\n", i, loss, avg_loss, get_current_rate(net), sec(clock()-time), i*imgs);
- if(i%1000==0){
- char buff[256];
- sprintf(buff, "%s/%s_%d.weights", backup_directory, base, i);
- save_weights(net, buff);
- }
- free_data(train);
- }
- char buff[256];
- sprintf(buff, "%s/%s_final.weights", backup_directory, base);
- save_weights(net, buff);
-}
-
-void convert_swag_detections(float *predictions, int classes, int num, int square, int side, int w, int h, float thresh, float **probs, box *boxes, int only_objectness)
-{
- int i,j,n;
- //int per_cell = 5*num+classes;
- for (i = 0; i < side*side; ++i){
- int row = i / side;
- int col = i % side;
- for(n = 0; n < num; ++n){
- int index = i*num + n;
- int p_index = side*side*classes + i*num + n;
- float scale = predictions[p_index];
- int box_index = side*side*(classes + num) + (i*num + n)*4;
- boxes[index].x = (predictions[box_index + 0] + col) / side * w;
- boxes[index].y = (predictions[box_index + 1] + row) / side * h;
- boxes[index].w = pow(predictions[box_index + 2], (square?2:1)) * w;
- boxes[index].h = pow(predictions[box_index + 3], (square?2:1)) * h;
- for(j = 0; j < classes; ++j){
- int class_index = i*classes;
- float prob = scale*predictions[class_index+j];
- probs[index][j] = (prob > thresh) ? prob : 0;
- }
- if(only_objectness){
- probs[index][0] = scale;
- }
- }
- }
-}
-
-void print_swag_detections(FILE **fps, char *id, box *boxes, float **probs, int total, int classes, int w, int h)
-{
- int i, j;
- for(i = 0; i < total; ++i){
- float xmin = boxes[i].x - boxes[i].w/2.;
- float xmax = boxes[i].x + boxes[i].w/2.;
- float ymin = boxes[i].y - boxes[i].h/2.;
- float ymax = boxes[i].y + boxes[i].h/2.;
-
- if (xmin < 0) xmin = 0;
- if (ymin < 0) ymin = 0;
- if (xmax > w) xmax = w;
- if (ymax > h) ymax = h;
-
- for(j = 0; j < classes; ++j){
- if (probs[i][j]) fprintf(fps[j], "%s %f %f %f %f %f\n", id, probs[i][j],
- xmin, ymin, xmax, ymax);
- }
- }
-}
-
-void validate_swag(char *cfgfile, char *weightfile)
-{
- network net = parse_network_cfg(cfgfile);
- if(weightfile){
- load_weights(&net, weightfile);
- }
- set_batch_network(&net, 1);
- fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
- srand(time(0));
-
- char *base = "results/comp4_det_test_";
- //base = "/home/pjreddie/comp4_det_test_";
- //list *plist = get_paths("/home/pjreddie/data/people-art/test.txt");
- //list *plist = get_paths("/home/pjreddie/data/cubist/test.txt");
- list *plist = get_paths("/home/pjreddie/data/voc/test/2007_test.txt");
- char **paths = (char **)list_to_array(plist);
-
- layer l = net.layers[net.n-1];
- int classes = l.classes;
- int square = l.sqrt;
- int side = l.side;
-
- int j;
- FILE **fps = calloc(classes, sizeof(FILE *));
- for(j = 0; j < classes; ++j){
- char buff[1024];
- snprintf(buff, 1024, "%s%s.txt", base, voc_names[j]);
- fps[j] = fopen(buff, "w");
- }
- box *boxes = calloc(side*side*l.n, sizeof(box));
- float **probs = calloc(side*side*l.n, sizeof(float *));
- for(j = 0; j < side*side*l.n; ++j) probs[j] = calloc(classes, sizeof(float *));
-
- int m = plist->size;
- int i=0;
- int t;
-
- float thresh = .001;
- int nms = 1;
- float iou_thresh = .5;
-
- int nthreads = 2;
- image *val = calloc(nthreads, sizeof(image));
- image *val_resized = calloc(nthreads, sizeof(image));
- image *buf = calloc(nthreads, sizeof(image));
- image *buf_resized = calloc(nthreads, sizeof(image));
- pthread_t *thr = calloc(nthreads, sizeof(pthread_t));
-
- load_args args = {0};
- args.w = net.w;
- args.h = net.h;
- args.type = IMAGE_DATA;
-
- for(t = 0; t < nthreads; ++t){
- args.path = paths[i+t];
- args.im = &buf[t];
- args.resized = &buf_resized[t];
- thr[t] = load_data_in_thread(args);
- }
- time_t start = time(0);
- for(i = nthreads; i < m+nthreads; i += nthreads){
- fprintf(stderr, "%d\n", i);
- for(t = 0; t < nthreads && i+t-nthreads < m; ++t){
- pthread_join(thr[t], 0);
- val[t] = buf[t];
- val_resized[t] = buf_resized[t];
- }
- for(t = 0; t < nthreads && i+t < m; ++t){
- args.path = paths[i+t];
- args.im = &buf[t];
- args.resized = &buf_resized[t];
- thr[t] = load_data_in_thread(args);
- }
- for(t = 0; t < nthreads && i+t-nthreads < m; ++t){
- char *path = paths[i+t-nthreads];
- char *id = basecfg(path);
- float *X = val_resized[t].data;
- float *predictions = network_predict(net, X);
- int w = val[t].w;
- int h = val[t].h;
- convert_swag_detections(predictions, classes, l.n, square, side, w, h, thresh, probs, boxes, 0);
- if (nms) do_nms_sort(boxes, probs, side*side*l.n, classes, iou_thresh);
- print_swag_detections(fps, id, boxes, probs, side*side*l.n, classes, w, h);
- free(id);
- free_image(val[t]);
- free_image(val_resized[t]);
- }
- }
- fprintf(stderr, "Total Detection Time: %f Seconds\n", (double)(time(0) - start));
-}
-
-void validate_swag_recall(char *cfgfile, char *weightfile)
-{
- network net = parse_network_cfg(cfgfile);
- if(weightfile){
- load_weights(&net, weightfile);
- }
- set_batch_network(&net, 1);
- fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
- srand(time(0));
-
- char *base = "results/comp4_det_test_";
- list *plist = get_paths("/home/pjreddie/data/voc/test/2007_test.txt");
- char **paths = (char **)list_to_array(plist);
-
- layer l = net.layers[net.n-1];
- int classes = l.classes;
- int square = l.sqrt;
- int side = l.side;
-
- int j, k;
- FILE **fps = calloc(classes, sizeof(FILE *));
- for(j = 0; j < classes; ++j){
- char buff[1024];
- snprintf(buff, 1024, "%s%s.txt", base, voc_names[j]);
- fps[j] = fopen(buff, "w");
- }
- box *boxes = calloc(side*side*l.n, sizeof(box));
- float **probs = calloc(side*side*l.n, sizeof(float *));
- for(j = 0; j < side*side*l.n; ++j) probs[j] = calloc(classes, sizeof(float *));
-
- int m = plist->size;
- int i=0;
-
- float thresh = .001;
- int nms = 0;
- float iou_thresh = .5;
- float nms_thresh = .5;
-
- int total = 0;
- int correct = 0;
- int proposals = 0;
- float avg_iou = 0;
-
- for(i = 0; i < m; ++i){
- char *path = paths[i];
- image orig = load_image_color(path, 0, 0);
- image sized = resize_image(orig, net.w, net.h);
- char *id = basecfg(path);
- float *predictions = network_predict(net, sized.data);
- convert_swag_detections(predictions, classes, l.n, square, side, 1, 1, thresh, probs, boxes, 1);
- if (nms) do_nms(boxes, probs, side*side*l.n, 1, nms_thresh);
-
- char *labelpath = find_replace(path, "images", "labels");
- labelpath = find_replace(labelpath, "JPEGImages", "labels");
- labelpath = find_replace(labelpath, ".jpg", ".txt");
- labelpath = find_replace(labelpath, ".JPEG", ".txt");
-
- int num_labels = 0;
- box_label *truth = read_boxes(labelpath, &num_labels);
- for(k = 0; k < side*side*l.n; ++k){
- if(probs[k][0] > thresh){
- ++proposals;
- }
- }
- for (j = 0; j < num_labels; ++j) {
- ++total;
- box t = {truth[j].x, truth[j].y, truth[j].w, truth[j].h};
- float best_iou = 0;
- for(k = 0; k < side*side*l.n; ++k){
- float iou = box_iou(boxes[k], t);
- if(probs[k][0] > thresh && iou > best_iou){
- best_iou = iou;
- }
- }
- avg_iou += best_iou;
- if(best_iou > iou_thresh){
- ++correct;
- }
- }
-
- fprintf(stderr, "%5d %5d %5d\tRPs/Img: %.2f\tIOU: %.2f%%\tRecall:%.2f%%\n", i, correct, total, (float)proposals/(i+1), avg_iou*100/total, 100.*correct/total);
- free(id);
- free_image(orig);
- free_image(sized);
- }
-}
-
-void test_swag(char *cfgfile, char *weightfile, char *filename, float thresh)
-{
-
- network net = parse_network_cfg(cfgfile);
- if(weightfile){
- load_weights(&net, weightfile);
- }
- region_layer l = net.layers[net.n-1];
- set_batch_network(&net, 1);
- srand(2222222);
- clock_t time;
- char buff[256];
- char *input = buff;
- int j;
- float nms=.5;
- box *boxes = calloc(l.side*l.side*l.n, sizeof(box));
- float **probs = calloc(l.side*l.side*l.n, sizeof(float *));
- for(j = 0; j < l.side*l.side*l.n; ++j) probs[j] = calloc(l.classes, sizeof(float *));
- while(1){
- if(filename){
- strncpy(input, filename, 256);
- } else {
- printf("Enter Image Path: ");
- fflush(stdout);
- input = fgets(input, 256, stdin);
- if(!input) return;
- strtok(input, "\n");
- }
- image im = load_image_color(input,0,0);
- image sized = resize_image(im, net.w, net.h);
- float *X = sized.data;
- time=clock();
- float *predictions = network_predict(net, X);
- printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time));
- convert_swag_detections(predictions, l.classes, l.n, l.sqrt, l.side, 1, 1, thresh, probs, boxes, 0);
- if (nms) do_nms_sort(boxes, probs, l.side*l.side*l.n, l.classes, nms);
- draw_swag(im, l.side*l.side*l.n, thresh, boxes, probs, "predictions");
-
- show_image(sized, "resized");
- free_image(im);
- free_image(sized);
-#ifdef OPENCV
- cvWaitKey(0);
- cvDestroyAllWindows();
-#endif
- if (filename) break;
- }
-}
-
-
-/*
-#ifdef OPENCV
-image ipl_to_image(IplImage* src);
-#include "opencv2/highgui/highgui_c.h"
-#include "opencv2/imgproc/imgproc_c.h"
-
-void demo_swag(char *cfgfile, char *weightfile, float thresh)
-{
-network net = parse_network_cfg(cfgfile);
-if(weightfile){
-load_weights(&net, weightfile);
-}
-region_layer layer = net.layers[net.n-1];
-CvCapture *capture = cvCaptureFromCAM(-1);
-set_batch_network(&net, 1);
-srand(2222222);
-while(1){
-IplImage* frame = cvQueryFrame(capture);
-image im = ipl_to_image(frame);
-cvReleaseImage(&frame);
-rgbgr_image(im);
-
-image sized = resize_image(im, net.w, net.h);
-float *X = sized.data;
-float *predictions = network_predict(net, X);
-draw_swag(im, predictions, layer.side, layer.n, "predictions", thresh);
-free_image(im);
-free_image(sized);
-cvWaitKey(10);
-}
-}
-#else
-void demo_swag(char *cfgfile, char *weightfile, float thresh){}
-#endif
- */
-
-void demo_swag(char *cfgfile, char *weightfile, float thresh);
-#ifndef GPU
-void demo_swag(char *cfgfile, char *weightfile, float thresh){}
-#endif
-
-void run_swag(int argc, char **argv)
-{
- float thresh = find_float_arg(argc, argv, "-thresh", .2);
- if(argc < 4){
- fprintf(stderr, "usage: %s %s [train/test/valid] [cfg] [weights (optional)]\n", argv[0], argv[1]);
- return;
- }
-
- char *cfg = argv[3];
- char *weights = (argc > 4) ? argv[4] : 0;
- char *filename = (argc > 5) ? argv[5]: 0;
- if(0==strcmp(argv[2], "test")) test_swag(cfg, weights, filename, thresh);
- else if(0==strcmp(argv[2], "train")) train_swag(cfg, weights);
- else if(0==strcmp(argv[2], "valid")) validate_swag(cfg, weights);
- else if(0==strcmp(argv[2], "recall")) validate_swag_recall(cfg, weights);
- else if(0==strcmp(argv[2], "demo")) demo_swag(cfg, weights, thresh);
-}
diff --git a/src/yolo.c b/src/yolo.c
index 4b241f3..77dae39 100644
--- a/src/yolo.c
+++ b/src/yolo.c
@@ -9,44 +9,36 @@
#include "opencv2/highgui/highgui_c.h"
#endif
-char *voc_class_names[] = {"aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"};
+char *voc_names[] = {"aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"};
-void draw_yolo(image im, float *box, int side, int objectness, char *label, float thresh)
+void draw_yolo(image im, int num, float thresh, box *boxes, float **probs, char *label)
{
int classes = 20;
- int elems = 4+classes+objectness;
- int j;
- int r, c;
+ int i;
- for(r = 0; r < side; ++r){
- for(c = 0; c < side; ++c){
- j = (r*side + c) * elems;
- float scale = 1;
- if(objectness) scale = 1 - box[j++];
- int class = max_index(box+j, classes);
- if(scale * box[j+class] > thresh){
- int width = sqrt(scale*box[j+class])*5 + 1;
- printf("%f %s\n", scale * box[j+class], voc_class_names[class]);
- float red = get_color(0,class,classes);
- float green = get_color(1,class,classes);
- float blue = get_color(2,class,classes);
+ for(i = 0; i < num; ++i){
+ int class = max_index(probs[i], classes);
+ float prob = probs[i][class];
+ if(prob > thresh){
+ int width = pow(prob, 1./2.)*10;
+ printf("%f %s\n", prob, voc_names[class]);
+ float red = get_color(0,class,classes);
+ float green = get_color(1,class,classes);
+ float blue = get_color(2,class,classes);
+ //red = green = blue = 0;
+ box b = boxes[i];
- j += classes;
- float x = box[j+0];
- float y = box[j+1];
- x = (x+c)/side;
- y = (y+r)/side;
- float w = box[j+2]; //*maxwidth;
- float h = box[j+3]; //*maxheight;
- h = h*h;
- w = w*w;
+ int left = (b.x-b.w/2.)*im.w;
+ int right = (b.x+b.w/2.)*im.w;
+ int top = (b.y-b.h/2.)*im.h;
+ int bot = (b.y+b.h/2.)*im.h;
- int left = (x-w/2)*im.w;
- int right = (x+w/2)*im.w;
- int top = (y-h/2)*im.h;
- int bot = (y+h/2)*im.h;
- draw_box_width(im, left, top, right, bot, width, red, green, blue);
- }
+ if(left < 0) left = 0;
+ if(right > im.w-1) right = im.w-1;
+ if(top < 0) top = 0;
+ if(bot > im.h-1) bot = im.h-1;
+
+ draw_box_width(im, left, top, right, bot, width, red, green, blue);
}
}
show_image(im, label);
@@ -54,7 +46,13 @@
void train_yolo(char *cfgfile, char *weightfile)
{
+ //char *train_images = "/home/pjreddie/data/voc/person_detection/2010_person.txt";
+ //char *train_images = "/home/pjreddie/data/people-art/train.txt";
+ //char *train_images = "/home/pjreddie/data/voc/test/2012_trainval.txt";
+ //char *train_images = "/home/pjreddie/data/voc/test/2010_trainval.txt";
char *train_images = "/home/pjreddie/data/voc/test/train.txt";
+ //char *train_images = "/home/pjreddie/data/voc/test/train_all.txt";
+ //char *train_images = "/home/pjreddie/data/voc/test/2007_trainval.txt";
char *backup_directory = "/home/pjreddie/backup/";
srand(time(0));
data_seed = time(0);
@@ -65,27 +63,21 @@
if(weightfile){
load_weights(&net, weightfile);
}
- int imgs = 128;
+ printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
+ int imgs = net.batch*net.subdivisions;
int i = *net.seen/imgs;
-
- char **paths;
- list *plist = get_paths(train_images);
- int N = plist->size;
- paths = (char **)list_to_array(plist);
-
- if(i*imgs > N*80){
- net.layers[net.n-1].objectness = 0;
- net.layers[net.n-1].joint = 1;
- }
- if(i*imgs > N*120){
- net.layers[net.n-1].rescore = 1;
- }
data train, buffer;
- detection_layer layer = get_network_detection_layer(net);
- int classes = layer.classes;
- int background = layer.objectness;
- int side = sqrt(get_detection_layer_locations(layer));
+
+ layer l = net.layers[net.n - 1];
+
+ int side = l.side;
+ int classes = l.classes;
+ float jitter = l.jitter;
+
+ list *plist = get_paths(train_images);
+ //int N = plist->size;
+ char **paths = (char **)list_to_array(plist);
load_args args = {0};
args.w = net.w;
@@ -94,13 +86,14 @@
args.n = imgs;
args.m = plist->size;
args.classes = classes;
+ args.jitter = jitter;
args.num_boxes = side;
- args.background = background;
args.d = &buffer;
- args.type = DETECTION_DATA;
+ args.type = REGION_DATA;
pthread_t load_thread = load_data_in_thread(args);
clock_t time;
+ //while(i*imgs < N*120){
while(get_current_batch(net) < net.max_batches){
i += 1;
time=clock();
@@ -109,36 +102,21 @@
load_thread = load_data_in_thread(args);
printf("Loaded: %lf seconds\n", sec(clock()-time));
+
+ /*
+ image im = float_to_image(net.w, net.h, 3, train.X.vals[113]);
+ image copy = copy_image(im);
+ draw_yolo(copy, train.y.vals[113], 7, "truth");
+ cvWaitKey(0);
+ free_image(copy);
+ */
+
time=clock();
float loss = train_network(net, train);
if (avg_loss < 0) avg_loss = loss;
avg_loss = avg_loss*.9 + loss*.1;
- printf("%d: %f, %f avg, %lf seconds, %f rate, %d images, epoch: %f\n", get_current_batch(net), loss, avg_loss, sec(clock()-time), get_current_rate(net), *net.seen, (float)*net.seen/N);
-
- if((i-1)*imgs <= 80*N && i*imgs > N*80){
- fprintf(stderr, "Second stage done.\n");
- char buff[256];
- sprintf(buff, "%s/%s_second_stage.weights", backup_directory, base);
- save_weights(net, buff);
- net.layers[net.n-1].joint = 1;
- net.layers[net.n-1].objectness = 0;
- background = 0;
-
- pthread_join(load_thread, 0);
- free_data(buffer);
- args.background = background;
- load_thread = load_data_in_thread(args);
- }
-
- if((i-1)*imgs <= 120*N && i*imgs > N*120){
- fprintf(stderr, "Third stage done.\n");
- char buff[256];
- sprintf(buff, "%s/%s_final.weights", backup_directory, base);
- net.layers[net.n-1].rescore = 1;
- save_weights(net, buff);
- }
-
+ printf("%d: %f, %f avg, %f rate, %lf seconds, %d images\n", i, loss, avg_loss, get_current_rate(net), sec(clock()-time), i*imgs);
if(i%1000==0){
char buff[256];
sprintf(buff, "%s/%s_%d.weights", backup_directory, base, i);
@@ -147,36 +125,42 @@
free_data(train);
}
char buff[256];
- sprintf(buff, "%s/%s_rescore.weights", backup_directory, base);
+ sprintf(buff, "%s/%s_final.weights", backup_directory, base);
save_weights(net, buff);
}
-void convert_yolo_detections(float *predictions, int classes, int objectness, int background, int num_boxes, int w, int h, float thresh, float **probs, box *boxes)
+void convert_yolo_detections(float *predictions, int classes, int num, int square, int side, int w, int h, float thresh, float **probs, box *boxes, int only_objectness)
{
- int i,j;
- int per_box = 4+classes+(background || objectness);
- for (i = 0; i < num_boxes*num_boxes; ++i){
- float scale = 1;
- if(objectness) scale = 1-predictions[i*per_box];
- int offset = i*per_box+(background||objectness);
- for(j = 0; j < classes; ++j){
- float prob = scale*predictions[offset+j];
- probs[i][j] = (prob > thresh) ? prob : 0;
+ int i,j,n;
+ //int per_cell = 5*num+classes;
+ for (i = 0; i < side*side; ++i){
+ int row = i / side;
+ int col = i % side;
+ for(n = 0; n < num; ++n){
+ int index = i*num + n;
+ int p_index = side*side*classes + i*num + n;
+ float scale = predictions[p_index];
+ int box_index = side*side*(classes + num) + (i*num + n)*4;
+ boxes[index].x = (predictions[box_index + 0] + col) / side * w;
+ boxes[index].y = (predictions[box_index + 1] + row) / side * h;
+ boxes[index].w = pow(predictions[box_index + 2], (square?2:1)) * w;
+ boxes[index].h = pow(predictions[box_index + 3], (square?2:1)) * h;
+ for(j = 0; j < classes; ++j){
+ int class_index = i*classes;
+ float prob = scale*predictions[class_index+j];
+ probs[index][j] = (prob > thresh) ? prob : 0;
+ }
+ if(only_objectness){
+ probs[index][0] = scale;
+ }
}
- int row = i / num_boxes;
- int col = i % num_boxes;
- offset += classes;
- boxes[i].x = (predictions[offset + 0] + col) / num_boxes * w;
- boxes[i].y = (predictions[offset + 1] + row) / num_boxes * h;
- boxes[i].w = pow(predictions[offset + 2], 2) * w;
- boxes[i].h = pow(predictions[offset + 3], 2) * h;
}
}
-void print_yolo_detections(FILE **fps, char *id, box *boxes, float **probs, int num_boxes, int classes, int w, int h)
+void print_yolo_detections(FILE **fps, char *id, box *boxes, float **probs, int total, int classes, int w, int h)
{
int i, j;
- for(i = 0; i < num_boxes*num_boxes; ++i){
+ for(i = 0; i < total; ++i){
float xmin = boxes[i].x - boxes[i].w/2.;
float xmax = boxes[i].x + boxes[i].w/2.;
float ymin = boxes[i].y - boxes[i].h/2.;
@@ -201,29 +185,33 @@
load_weights(&net, weightfile);
}
set_batch_network(&net, 1);
- detection_layer layer = get_network_detection_layer(net);
fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
srand(time(0));
char *base = "results/comp4_det_test_";
+ //base = "/home/pjreddie/comp4_det_test_";
+ //list *plist = get_paths("/home/pjreddie/data/people-art/test.txt");
+ //list *plist = get_paths("/home/pjreddie/data/cubist/test.txt");
+
list *plist = get_paths("/home/pjreddie/data/voc/test/2007_test.txt");
+ //list *plist = get_paths("/home/pjreddie/data/voc/test_2012.txt");
char **paths = (char **)list_to_array(plist);
- int classes = layer.classes;
- int objectness = layer.objectness;
- int background = layer.background;
- int num_boxes = sqrt(get_detection_layer_locations(layer));
+ layer l = net.layers[net.n-1];
+ int classes = l.classes;
+ int square = l.sqrt;
+ int side = l.side;
int j;
FILE **fps = calloc(classes, sizeof(FILE *));
for(j = 0; j < classes; ++j){
char buff[1024];
- snprintf(buff, 1024, "%s%s.txt", base, voc_class_names[j]);
+ snprintf(buff, 1024, "%s%s.txt", base, voc_names[j]);
fps[j] = fopen(buff, "w");
}
- box *boxes = calloc(num_boxes*num_boxes, sizeof(box));
- float **probs = calloc(num_boxes*num_boxes, sizeof(float *));
- for(j = 0; j < num_boxes*num_boxes; ++j) probs[j] = calloc(classes, sizeof(float *));
+ box *boxes = calloc(side*side*l.n, sizeof(box));
+ float **probs = calloc(side*side*l.n, sizeof(float *));
+ for(j = 0; j < side*side*l.n; ++j) probs[j] = calloc(classes, sizeof(float *));
int m = plist->size;
int i=0;
@@ -233,7 +221,7 @@
int nms = 1;
float iou_thresh = .5;
- int nthreads = 8;
+ int nthreads = 2;
image *val = calloc(nthreads, sizeof(image));
image *val_resized = calloc(nthreads, sizeof(image));
image *buf = calloc(nthreads, sizeof(image));
@@ -272,9 +260,9 @@
float *predictions = network_predict(net, X);
int w = val[t].w;
int h = val[t].h;
- convert_yolo_detections(predictions, classes, objectness, background, num_boxes, w, h, thresh, probs, boxes);
- if (nms) do_nms(boxes, probs, num_boxes*num_boxes, classes, iou_thresh);
- print_yolo_detections(fps, id, boxes, probs, num_boxes, classes, w, h);
+ convert_yolo_detections(predictions, classes, l.n, square, side, w, h, thresh, probs, boxes, 0);
+ if (nms) do_nms_sort(boxes, probs, side*side*l.n, classes, iou_thresh);
+ print_yolo_detections(fps, id, boxes, probs, side*side*l.n, classes, w, h);
free(id);
free_image(val[t]);
free_image(val_resized[t]);
@@ -283,6 +271,93 @@
fprintf(stderr, "Total Detection Time: %f Seconds\n", (double)(time(0) - start));
}
+void validate_yolo_recall(char *cfgfile, char *weightfile)
+{
+ network net = parse_network_cfg(cfgfile);
+ if(weightfile){
+ load_weights(&net, weightfile);
+ }
+ set_batch_network(&net, 1);
+ fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
+ srand(time(0));
+
+ char *base = "results/comp4_det_test_";
+ list *plist = get_paths("/home/pjreddie/data/voc/test/2007_test.txt");
+ char **paths = (char **)list_to_array(plist);
+
+ layer l = net.layers[net.n-1];
+ int classes = l.classes;
+ int square = l.sqrt;
+ int side = l.side;
+
+ int j, k;
+ FILE **fps = calloc(classes, sizeof(FILE *));
+ for(j = 0; j < classes; ++j){
+ char buff[1024];
+ snprintf(buff, 1024, "%s%s.txt", base, voc_names[j]);
+ fps[j] = fopen(buff, "w");
+ }
+ box *boxes = calloc(side*side*l.n, sizeof(box));
+ float **probs = calloc(side*side*l.n, sizeof(float *));
+ for(j = 0; j < side*side*l.n; ++j) probs[j] = calloc(classes, sizeof(float *));
+
+ int m = plist->size;
+ int i=0;
+
+ float thresh = .001;
+ int nms = 0;
+ float iou_thresh = .5;
+ float nms_thresh = .5;
+
+ int total = 0;
+ int correct = 0;
+ int proposals = 0;
+ float avg_iou = 0;
+
+ for(i = 0; i < m; ++i){
+ char *path = paths[i];
+ image orig = load_image_color(path, 0, 0);
+ image sized = resize_image(orig, net.w, net.h);
+ char *id = basecfg(path);
+ float *predictions = network_predict(net, sized.data);
+ convert_yolo_detections(predictions, classes, l.n, square, side, 1, 1, thresh, probs, boxes, 1);
+ if (nms) do_nms(boxes, probs, side*side*l.n, 1, nms_thresh);
+
+ char *labelpath = find_replace(path, "images", "labels");
+ labelpath = find_replace(labelpath, "JPEGImages", "labels");
+ labelpath = find_replace(labelpath, ".jpg", ".txt");
+ labelpath = find_replace(labelpath, ".JPEG", ".txt");
+
+ int num_labels = 0;
+ box_label *truth = read_boxes(labelpath, &num_labels);
+ for(k = 0; k < side*side*l.n; ++k){
+ if(probs[k][0] > thresh){
+ ++proposals;
+ }
+ }
+ for (j = 0; j < num_labels; ++j) {
+ ++total;
+ box t = {truth[j].x, truth[j].y, truth[j].w, truth[j].h};
+ float best_iou = 0;
+ for(k = 0; k < side*side*l.n; ++k){
+ float iou = box_iou(boxes[k], t);
+ if(probs[k][0] > thresh && iou > best_iou){
+ best_iou = iou;
+ }
+ }
+ avg_iou += best_iou;
+ if(best_iou > iou_thresh){
+ ++correct;
+ }
+ }
+
+ fprintf(stderr, "%5d %5d %5d\tRPs/Img: %.2f\tIOU: %.2f%%\tRecall:%.2f%%\n", i, correct, total, (float)proposals/(i+1), avg_iou*100/total, 100.*correct/total);
+ free(id);
+ free_image(orig);
+ free_image(sized);
+ }
+}
+
void test_yolo(char *cfgfile, char *weightfile, char *filename, float thresh)
{
@@ -290,12 +365,18 @@
if(weightfile){
load_weights(&net, weightfile);
}
- detection_layer layer = get_network_detection_layer(net);
+ detection_layer l = net.layers[net.n-1];
set_batch_network(&net, 1);
srand(2222222);
clock_t time;
char buff[256];
char *input = buff;
+ int j;
+ float nms=.5;
+ printf("%d %d %d", l.side, l.n, l.classes);
+ box *boxes = calloc(l.side*l.side*l.n, sizeof(box));
+ float **probs = calloc(l.side*l.side*l.n, sizeof(float *));
+ for(j = 0; j < l.side*l.side*l.n; ++j) probs[j] = calloc(l.classes, sizeof(float *));
while(1){
if(filename){
strncpy(input, filename, 256);
@@ -312,7 +393,11 @@
time=clock();
float *predictions = network_predict(net, X);
printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time));
- draw_yolo(im, predictions, 7, layer.objectness, "predictions", thresh);
+ convert_yolo_detections(predictions, l.classes, l.n, l.sqrt, l.side, 1, 1, thresh, probs, boxes, 0);
+ if (nms) do_nms_sort(boxes, probs, l.side*l.side*l.n, l.classes, nms);
+ draw_yolo(im, l.side*l.side*l.n, thresh, boxes, probs, "predictions");
+
+ show_image(sized, "resized");
free_image(im);
free_image(sized);
#ifdef OPENCV
@@ -323,6 +408,47 @@
}
}
+/*
+#ifdef OPENCV
+image ipl_to_image(IplImage* src);
+#include "opencv2/highgui/highgui_c.h"
+#include "opencv2/imgproc/imgproc_c.h"
+
+void demo_swag(char *cfgfile, char *weightfile, float thresh)
+{
+network net = parse_network_cfg(cfgfile);
+if(weightfile){
+load_weights(&net, weightfile);
+}
+detection_layer layer = net.layers[net.n-1];
+CvCapture *capture = cvCaptureFromCAM(-1);
+set_batch_network(&net, 1);
+srand(2222222);
+while(1){
+IplImage* frame = cvQueryFrame(capture);
+image im = ipl_to_image(frame);
+cvReleaseImage(&frame);
+rgbgr_image(im);
+
+image sized = resize_image(im, net.w, net.h);
+float *X = sized.data;
+float *predictions = network_predict(net, X);
+draw_swag(im, predictions, layer.side, layer.n, "predictions", thresh);
+free_image(im);
+free_image(sized);
+cvWaitKey(10);
+}
+}
+#else
+void demo_swag(char *cfgfile, char *weightfile, float thresh){}
+#endif
+ */
+
+void demo_yolo(char *cfgfile, char *weightfile, float thresh);
+#ifndef GPU
+void demo_yolo(char *cfgfile, char *weightfile, float thresh){}
+#endif
+
void run_yolo(int argc, char **argv)
{
float thresh = find_float_arg(argc, argv, "-thresh", .2);
@@ -337,4 +463,6 @@
if(0==strcmp(argv[2], "test")) test_yolo(cfg, weights, filename, thresh);
else if(0==strcmp(argv[2], "train")) train_yolo(cfg, weights);
else if(0==strcmp(argv[2], "valid")) validate_yolo(cfg, weights);
+ else if(0==strcmp(argv[2], "recall")) validate_yolo_recall(cfg, weights);
+ else if(0==strcmp(argv[2], "demo")) demo_yolo(cfg, weights, thresh);
}
diff --git a/src/swag_kernels.cu b/src/yolo_kernels.cu
similarity index 77%
rename from src/swag_kernels.cu
rename to src/yolo_kernels.cu
index 5cba15c..f02b7a2 100644
--- a/src/swag_kernels.cu
+++ b/src/yolo_kernels.cu
@@ -1,6 +1,5 @@
extern "C" {
#include "network.h"
-#include "region_layer.h"
#include "detection_layer.h"
#include "cost_layer.h"
#include "utils.h"
@@ -13,16 +12,16 @@
#include "opencv2/highgui/highgui.hpp"
#include "opencv2/imgproc/imgproc.hpp"
extern "C" image ipl_to_image(IplImage* src);
-extern "C" void convert_swag_detections(float *predictions, int classes, int num, int square, int side, int w, int h, float thresh, float **probs, box *boxes, int only_objectness);
-extern "C" void draw_swag(image im, int num, float thresh, box *boxes, float **probs, char *label);
+extern "C" void convert_yolo_detections(float *predictions, int classes, int num, int square, int side, int w, int h, float thresh, float **probs, box *boxes, int only_objectness);
+extern "C" void draw_yolo(image im, int num, float thresh, box *boxes, float **probs, char *label);
-extern "C" void demo_swag(char *cfgfile, char *weightfile, float thresh)
+extern "C" void demo_yolo(char *cfgfile, char *weightfile, float thresh)
{
network net = parse_network_cfg(cfgfile);
if(weightfile){
load_weights(&net, weightfile);
}
- region_layer l = net.layers[net.n-1];
+ detection_layer l = net.layers[net.n-1];
cv::VideoCapture cap(0);
set_batch_network(&net, 1);
@@ -43,12 +42,12 @@
image sized = resize_image(im, net.w, net.h);
float *X = sized.data;
float *predictions = network_predict(net, X);
- convert_swag_detections(predictions, l.classes, l.n, l.sqrt, l.side, 1, 1, thresh, probs, boxes, 0);
+ convert_yolo_detections(predictions, l.classes, l.n, l.sqrt, l.side, 1, 1, thresh, probs, boxes, 0);
if (nms > 0) do_nms(boxes, probs, l.side*l.side*l.n, l.classes, nms);
printf("\033[2J");
printf("\033[1;1H");
printf("\nObjects:\n\n");
- draw_swag(im, l.side*l.side*l.n, thresh, boxes, probs, "predictions");
+ draw_yolo(im, l.side*l.side*l.n, thresh, boxes, probs, "predictions");
free_image(im);
free_image(sized);
@@ -56,6 +55,6 @@
}
}
#else
-extern "C" void demo_swag(char *cfgfile, char *weightfile, float thresh){}
+extern "C" void demo_yolo(char *cfgfile, char *weightfile, float thresh){}
#endif
--
Gitblit v1.10.0