From c53e03348c65462bcba33f6352087dd6b9268e8f Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Wed, 16 Sep 2015 21:12:10 +0000
Subject: [PATCH] yolo working w/ regions
---
src/network.c | 2
src/swag.c | 145 ++++++++++-------------
src/coco.c | 14 --
Makefile | 2
src/network_kernels.cu | 1
src/parser.c | 4
src/data.c | 4
cfg/darknet.cfg | 10 +
src/region_layer.c | 143 ++++++++++++-----------
src/darknet.c | 6
src/layer.h | 3
src/compare.c | 4
12 files changed, 160 insertions(+), 178 deletions(-)
diff --git a/Makefile b/Makefile
index 581b6d7..40bfcec 100644
--- a/Makefile
+++ b/Makefile
@@ -34,7 +34,7 @@
LDFLAGS+= -L/usr/local/cuda/lib64 -lcuda -lcudart -lcublas -lcurand
endif
-OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o darknet.o detection_layer.o imagenet.o captcha.o route_layer.o writing.o box.o nightmare.o normalization_layer.o avgpool_layer.o coco.o dice.o yolo.o region_layer.o layer.o compare.o yoloplus.o
+OBJ=gemm.o utils.o cuda.o deconvolutional_layer.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o darknet.o detection_layer.o imagenet.o captcha.o route_layer.o writing.o box.o nightmare.o normalization_layer.o avgpool_layer.o coco.o dice.o yolo.o region_layer.o layer.o compare.o swag.o
ifeq ($(GPU), 1)
OBJ+=convolutional_kernels.o deconvolutional_kernels.o activation_kernels.o im2col_kernels.o col2im_kernels.o blas_kernels.o crop_layer_kernels.o dropout_layer_kernels.o maxpool_layer_kernels.o softmax_layer_kernels.o network_kernels.o avgpool_layer_kernels.o
endif
diff --git a/cfg/darknet.cfg b/cfg/darknet.cfg
index eb1310a..64aab1e 100644
--- a/cfg/darknet.cfg
+++ b/cfg/darknet.cfg
@@ -1,15 +1,17 @@
[net]
-batch=128
+batch=256
subdivisions=1
height=256
width=256
channels=3
momentum=0.9
decay=0.0005
+
learning_rate=0.01
-policy=poly
-power=.5
-max_batches=600000
+policy=step
+scale=.1
+step=100000
+max_batches=400000
[crop]
crop_height=224
diff --git a/src/coco.c b/src/coco.c
index 87f3dca..234f342 100644
--- a/src/coco.c
+++ b/src/coco.c
@@ -111,20 +111,6 @@
avg_loss = avg_loss*.9 + loss*.1;
printf("%d: %f, %f avg, %lf seconds, %d images\n", i, loss, avg_loss, sec(clock()-time), i*imgs);
- if((i-1)*imgs <= N && i*imgs > N){
- fprintf(stderr, "First stage done\n");
- net.learning_rate *= 10;
- char buff[256];
- sprintf(buff, "%s/%s_first_stage.weights", backup_directory, base);
- save_weights(net, buff);
- }
-
- if((i-1)*imgs <= 80*N && i*imgs > N*80){
- fprintf(stderr, "Second stage done.\n");
- char buff[256];
- sprintf(buff, "%s/%s_second_stage.weights", backup_directory, base);
- save_weights(net, buff);
- }
if(i%1000==0){
char buff[256];
sprintf(buff, "%s/%s_%d.weights", backup_directory, base, i);
diff --git a/src/compare.c b/src/compare.c
index 9b6d6bf..0408f80 100644
--- a/src/compare.c
+++ b/src/compare.c
@@ -175,8 +175,8 @@
image im1 = load_image_color(box1.filename, net.w, net.h);
image im2 = load_image_color(box2.filename, net.w, net.h);
float *X = calloc(net.w*net.h*net.c, sizeof(float));
- memcpy(X, im1.data, im1.w*im1.h*im1.c);
- memcpy(X+im1.w*im1.h*im1.c, im2.data, im2.w*im2.h*im2.c);
+ memcpy(X, im1.data, im1.w*im1.h*im1.c*sizeof(float));
+ memcpy(X+im1.w*im1.h*im1.c, im2.data, im2.w*im2.h*im2.c*sizeof(float));
float *predictions = network_predict(net, X);
free_image(im1);
diff --git a/src/darknet.c b/src/darknet.c
index 833f89e..9632f91 100644
--- a/src/darknet.c
+++ b/src/darknet.c
@@ -13,7 +13,7 @@
extern void run_imagenet(int argc, char **argv);
extern void run_yolo(int argc, char **argv);
-extern void run_yoloplus(int argc, char **argv);
+extern void run_swag(int argc, char **argv);
extern void run_coco(int argc, char **argv);
extern void run_writing(int argc, char **argv);
extern void run_captcha(int argc, char **argv);
@@ -179,8 +179,8 @@
average(argc, argv);
} else if (0 == strcmp(argv[1], "yolo")){
run_yolo(argc, argv);
- } else if (0 == strcmp(argv[1], "yoloplus")){
- run_yoloplus(argc, argv);
+ } else if (0 == strcmp(argv[1], "swag")){
+ run_swag(argc, argv);
} else if (0 == strcmp(argv[1], "coco")){
run_coco(argc, argv);
} else if (0 == strcmp(argv[1], "compare")){
diff --git a/src/data.c b/src/data.c
index 003338e..17772c1 100644
--- a/src/data.c
+++ b/src/data.c
@@ -176,8 +176,10 @@
int index = (col+row*num_boxes)*(5+classes);
if (truth[index]) continue;
truth[index++] = 1;
- if (classes) truth[index+id] = 1;
+
+ if (id < classes) truth[index+id] = 1;
index += classes;
+
truth[index++] = x;
truth[index++] = y;
truth[index++] = w;
diff --git a/src/layer.h b/src/layer.h
index 77d7f08..1eb7351 100644
--- a/src/layer.h
+++ b/src/layer.h
@@ -30,6 +30,7 @@
int batch;
int inputs;
int outputs;
+ int truths;
int h,w,c;
int out_h, out_w, out_c;
int n;
@@ -40,10 +41,12 @@
int pad;
int crop_width;
int crop_height;
+ int sqrt;
int flip;
float angle;
float saturation;
float exposure;
+ int softmax;
int classes;
int coords;
int background;
diff --git a/src/network.c b/src/network.c
index af4861a..80ee291 100644
--- a/src/network.c
+++ b/src/network.c
@@ -48,7 +48,7 @@
case POLY:
return net.learning_rate * pow(1 - (float)batch_num / net.max_batches, net.power);
case SIG:
- return net.learning_rate * (1/(1+exp(net.gamma*(batch_num - net.step))));
+ return net.learning_rate * (1./(1.+exp(net.gamma*(batch_num - net.step))));
default:
fprintf(stderr, "Policy is weird!\n");
return net.learning_rate;
diff --git a/src/network_kernels.cu b/src/network_kernels.cu
index 1f0a654..cfc6e83 100644
--- a/src/network_kernels.cu
+++ b/src/network_kernels.cu
@@ -134,6 +134,7 @@
network_state state;
int x_size = get_network_input_size(net)*net.batch;
int y_size = get_network_output_size(net)*net.batch;
+ if(net.layers[net.n-1].type == REGION) y_size = net.layers[net.n-1].truths*net.batch;
if(!*net.input_gpu){
*net.input_gpu = cuda_make_array(x, x_size);
*net.truth_gpu = cuda_make_array(y, y_size);
diff --git a/src/parser.c b/src/parser.c
index 94dc0fa..53e8461 100644
--- a/src/parser.c
+++ b/src/parser.c
@@ -182,6 +182,10 @@
int num = option_find_int(options, "num", 1);
int side = option_find_int(options, "side", 7);
region_layer layer = make_region_layer(params.batch, params.inputs, num, side, classes, coords, rescore);
+ int softmax = option_find_int(options, "softmax", 0);
+ int sqrt = option_find_int(options, "sqrt", 0);
+ layer.softmax = softmax;
+ layer.sqrt = sqrt;
return layer;
}
diff --git a/src/region_layer.c b/src/region_layer.c
index dcdcfad..d65c1a8 100644
--- a/src/region_layer.c
+++ b/src/region_layer.c
@@ -14,7 +14,7 @@
{
region_layer l = {0};
l.type = REGION;
-
+
l.n = n;
l.batch = batch;
l.inputs = inputs;
@@ -22,15 +22,15 @@
l.coords = coords;
l.rescore = rescore;
l.side = side;
- assert(side*side*l.coords*l.n == inputs);
+ assert(side*side*((1 + l.coords)*l.n + l.classes) == inputs);
l.cost = calloc(1, sizeof(float));
- int outputs = l.n*5*side*side;
- l.outputs = outputs;
- l.output = calloc(batch*outputs, sizeof(float));
- l.delta = calloc(batch*inputs, sizeof(float));
- #ifdef GPU
- l.output_gpu = cuda_make_array(l.output, batch*outputs);
- l.delta_gpu = cuda_make_array(l.delta, batch*inputs);
+ l.outputs = l.inputs;
+ l.truths = l.side*l.side*(1+l.coords+l.classes);
+ l.output = calloc(batch*l.outputs, sizeof(float));
+ l.delta = calloc(batch*l.outputs, sizeof(float));
+#ifdef GPU
+ l.output_gpu = cuda_make_array(l.output, batch*l.outputs);
+ l.delta_gpu = cuda_make_array(l.delta, batch*l.outputs);
#endif
fprintf(stderr, "Region Layer\n");
@@ -43,64 +43,69 @@
{
int locations = l.side*l.side;
int i,j;
+ memcpy(l.output, state.input, l.outputs*l.batch*sizeof(float));
for(i = 0; i < l.batch*locations; ++i){
- for(j = 0; j < l.n; ++j){
- int in_index = i*l.n*l.coords + j*l.coords;
- int out_index = i*l.n*5 + j*5;
-
- float prob = state.input[in_index+0];
- float x = state.input[in_index+1];
- float y = state.input[in_index+2];
- float w = state.input[in_index+3];
- float h = state.input[in_index+4];
- /*
- float min_w = state.input[in_index+5];
- float max_w = state.input[in_index+6];
- float min_h = state.input[in_index+7];
- float max_h = state.input[in_index+8];
- */
-
- l.output[out_index+0] = prob;
- l.output[out_index+1] = x;
- l.output[out_index+2] = y;
- l.output[out_index+3] = w;
- l.output[out_index+4] = h;
-
+ int index = i*((1+l.coords)*l.n + l.classes);
+ if(l.softmax){
+ activate_array(l.output + index, l.n*(1+l.coords), LOGISTIC);
+ int offset = l.n*(1+l.coords);
+ softmax_array(l.output + index + offset, l.classes,
+ l.output + index + offset);
}
}
if(state.train){
float avg_iou = 0;
+ float avg_cat = 0;
+ float avg_obj = 0;
+ float avg_anyobj = 0;
int count = 0;
*(l.cost) = 0;
int size = l.inputs * l.batch;
memset(l.delta, 0, size * sizeof(float));
for (i = 0; i < l.batch*locations; ++i) {
-
+ int index = i*((1+l.coords)*l.n + l.classes);
for(j = 0; j < l.n; ++j){
- int in_index = i*l.n*l.coords + j*l.coords;
- l.delta[in_index+0] = .1*(0-state.input[in_index+0]);
+ int prob_index = index + j*(1 + l.coords);
+ l.delta[prob_index] = (1./l.n)*(0-l.output[prob_index]);
+ if(l.softmax){
+ l.delta[prob_index] = 1./(l.n*l.side)*(0-l.output[prob_index]);
+ }
+ *(l.cost) += (1./l.n)*pow(l.output[prob_index], 2);
+ //printf("%f\n", l.output[prob_index]);
+ avg_anyobj += l.output[prob_index];
}
- int truth_index = i*5;
+ int truth_index = i*(1 + l.coords + l.classes);
int best_index = -1;
float best_iou = 0;
float best_rmse = 4;
int bg = !state.truth[truth_index];
- if(bg) continue;
+ if(bg) {
+ continue;
+ }
- box truth = {state.truth[truth_index+1], state.truth[truth_index+2], state.truth[truth_index+3], state.truth[truth_index+4]};
+ int class_index = index + l.n*(1+l.coords);
+ for(j = 0; j < l.classes; ++j) {
+ l.delta[class_index+j] = state.truth[truth_index+1+j] - l.output[class_index+j];
+ *(l.cost) += pow(state.truth[truth_index+1+j] - l.output[class_index+j], 2);
+ if(state.truth[truth_index + 1 + j]) avg_cat += l.output[class_index+j];
+ }
+ truth_index += l.classes + 1;
+ box truth = {state.truth[truth_index+0], state.truth[truth_index+1], state.truth[truth_index+2], state.truth[truth_index+3]};
truth.x /= l.side;
truth.y /= l.side;
for(j = 0; j < l.n; ++j){
- int out_index = i*l.n*5 + j*5;
+ int out_index = index + j*(1+l.coords);
box out = {l.output[out_index+1], l.output[out_index+2], l.output[out_index+3], l.output[out_index+4]};
- //printf("\n%f %f %f %f %f\n", l.output[out_index+0], out.x, out.y, out.w, out.h);
-
out.x /= l.side;
out.y /= l.side;
+ if (l.sqrt){
+ out.w = out.w*out.w;
+ out.h = out.h*out.h;
+ }
float iou = box_iou(out, truth);
float rmse = box_rmse(out, truth);
@@ -116,46 +121,41 @@
}
}
}
- printf("%d", best_index);
- //int out_index = i*l.n*5 + best_index*5;
- //box out = {l.output[out_index+1], l.output[out_index+2], l.output[out_index+3], l.output[out_index+4]};
- int in_index = i*l.n*l.coords + best_index*l.coords;
+ //printf("%d", best_index);
+ int in_index = index + best_index*(1+l.coords);
+ *(l.cost) -= pow(l.output[in_index], 2);
+ *(l.cost) += pow(1-l.output[in_index], 2);
+ avg_obj += l.output[in_index];
+ l.delta[in_index+0] = (1.-l.output[in_index]);
+ if(l.softmax){
+ l.delta[in_index+0] = 5*(1.-l.output[in_index]);
+ }
+ //printf("%f\n", l.output[in_index]);
- l.delta[in_index+0] = (1-state.input[in_index+0]);
- l.delta[in_index+1] = state.truth[truth_index+1] - state.input[in_index+1];
- l.delta[in_index+2] = state.truth[truth_index+2] - state.input[in_index+2];
- l.delta[in_index+3] = state.truth[truth_index+3] - state.input[in_index+3];
- l.delta[in_index+4] = state.truth[truth_index+4] - state.input[in_index+4];
- /*
- l.delta[in_index+5] = 0 - state.input[in_index+5];
- l.delta[in_index+6] = 1 - state.input[in_index+6];
- l.delta[in_index+7] = 0 - state.input[in_index+7];
- l.delta[in_index+8] = 1 - state.input[in_index+8];
- */
+ l.delta[in_index+1] = 5*(state.truth[truth_index+0] - l.output[in_index+1]);
+ l.delta[in_index+2] = 5*(state.truth[truth_index+1] - l.output[in_index+2]);
+ if(l.sqrt){
+ l.delta[in_index+3] = 5*(sqrt(state.truth[truth_index+2]) - l.output[in_index+3]);
+ l.delta[in_index+4] = 5*(sqrt(state.truth[truth_index+3]) - l.output[in_index+4]);
+ }else{
+ l.delta[in_index+3] = 5*(state.truth[truth_index+2] - l.output[in_index+3]);
+ l.delta[in_index+4] = 5*(state.truth[truth_index+3] - l.output[in_index+4]);
+ }
- /*
- float x = state.input[in_index+1];
- float y = state.input[in_index+2];
- float w = state.input[in_index+3];
- float h = state.input[in_index+4];
- float min_w = state.input[in_index+5];
- float max_w = state.input[in_index+6];
- float min_h = state.input[in_index+7];
- float max_h = state.input[in_index+8];
- */
-
-
+ *(l.cost) += pow(1-best_iou, 2);
avg_iou += best_iou;
++count;
+ if(l.softmax){
+ gradient_array(l.output + index, l.n*(1+l.coords), LOGISTIC, l.delta + index);
+ }
}
- printf("\nAvg IOU: %f %d\n", avg_iou/count, count);
+ printf("Avg IOU: %f, Avg Cat Pred: %f, Avg Obj: %f, Avg Any: %f, count: %d\n", avg_iou/count, avg_cat/count, avg_obj/count, avg_anyobj/(l.batch*locations*l.n), count);
}
}
void backward_region_layer(const region_layer l, network_state state)
{
axpy_cpu(l.batch*l.inputs, 1, l.delta, 1, state.delta, 1);
- //copy_cpu(l.batch*l.inputs, l.delta, 1, state.delta, 1);
}
#ifdef GPU
@@ -165,8 +165,9 @@
float *in_cpu = calloc(l.batch*l.inputs, sizeof(float));
float *truth_cpu = 0;
if(state.truth){
- truth_cpu = calloc(l.batch*l.outputs, sizeof(float));
- cuda_pull_array(state.truth, truth_cpu, l.batch*l.outputs);
+ int num_truth = l.batch*l.side*l.side*(1+l.coords+l.classes);
+ truth_cpu = calloc(num_truth, sizeof(float));
+ cuda_pull_array(state.truth, truth_cpu, num_truth);
}
cuda_pull_array(state.input, in_cpu, l.batch*l.inputs);
network_state cpu_state;
diff --git a/src/yoloplus.c b/src/swag.c
similarity index 64%
rename from src/yoloplus.c
rename to src/swag.c
index dcae7bc..4dcf36b 100644
--- a/src/yoloplus.c
+++ b/src/swag.c
@@ -11,7 +11,7 @@
char *voc_names[] = {"aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"};
-void draw_yoloplus(image im, float *box, int side, int objectness, char *label, float thresh)
+void draw_swag(image im, float *box, int side, int objectness, char *label, float thresh)
{
int classes = 20;
int elems = 4+classes+objectness;
@@ -52,7 +52,7 @@
show_image(im, label);
}
-void train_yoloplus(char *cfgfile, char *weightfile)
+void train_swag(char *cfgfile, char *weightfile)
{
char *train_images = "/home/pjreddie/data/voc/test/train.txt";
char *backup_directory = "/home/pjreddie/backup/";
@@ -65,23 +65,20 @@
if(weightfile){
load_weights(&net, weightfile);
}
- detection_layer layer = get_network_detection_layer(net);
- int imgs = 128;
+ printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
+ int imgs = net.batch*net.subdivisions;
int i = *net.seen/imgs;
-
- char **paths;
- list *plist = get_paths(train_images);
- int N = plist->size;
- paths = (char **)list_to_array(plist);
-
- if(i*imgs > N*120){
- net.layers[net.n-1].rescore = 1;
- }
data train, buffer;
- int classes = layer.classes;
- int background = layer.objectness;
- int side = sqrt(get_detection_layer_locations(layer));
+
+ layer l = net.layers[net.n - 1];
+
+ int side = l.side;
+ int classes = l.classes;
+
+ list *plist = get_paths(train_images);
+ int N = plist->size;
+ char **paths = (char **)list_to_array(plist);
load_args args = {0};
args.w = net.w;
@@ -91,12 +88,12 @@
args.m = plist->size;
args.classes = classes;
args.num_boxes = side;
- args.background = background;
args.d = &buffer;
- args.type = DETECTION_DATA;
+ args.type = REGION_DATA;
pthread_t load_thread = load_data_in_thread(args);
clock_t time;
+ //while(i*imgs < N*120){
while(get_current_batch(net) < net.max_batches){
i += 1;
time=clock();
@@ -105,36 +102,21 @@
load_thread = load_data_in_thread(args);
printf("Loaded: %lf seconds\n", sec(clock()-time));
+
+/*
+ image im = float_to_image(net.w, net.h, 3, train.X.vals[113]);
+ image copy = copy_image(im);
+ draw_swag(copy, train.y.vals[113], 7, "truth");
+ cvWaitKey(0);
+ free_image(copy);
+ */
+
time=clock();
float loss = train_network(net, train);
if (avg_loss < 0) avg_loss = loss;
avg_loss = avg_loss*.9 + loss*.1;
- printf("%d: %f, %f avg, %lf seconds, %f rate, %d images, epoch: %f\n", get_current_batch(net), loss, avg_loss, sec(clock()-time), get_current_rate(net), *net.seen, (float)*net.seen/N);
-
- if((i-1)*imgs <= 80*N && i*imgs > N*80){
- fprintf(stderr, "Second stage done.\n");
- char buff[256];
- sprintf(buff, "%s/%s_second_stage.weights", backup_directory, base);
- save_weights(net, buff);
- net.layers[net.n-1].joint = 1;
- net.layers[net.n-1].objectness = 0;
- background = 0;
-
- pthread_join(load_thread, 0);
- free_data(buffer);
- args.background = background;
- load_thread = load_data_in_thread(args);
- }
-
- if((i-1)*imgs <= 120*N && i*imgs > N*120){
- fprintf(stderr, "Third stage done.\n");
- char buff[256];
- sprintf(buff, "%s/%s_final.weights", backup_directory, base);
- net.layers[net.n-1].rescore = 1;
- save_weights(net, buff);
- }
-
+ printf("%d: %f, %f avg, %lf seconds, %d images\n", i, loss, avg_loss, sec(clock()-time), i*imgs);
if(i%1000==0){
char buff[256];
sprintf(buff, "%s/%s_%d.weights", backup_directory, base, i);
@@ -143,36 +125,38 @@
free_data(train);
}
char buff[256];
- sprintf(buff, "%s/%s_rescore.weights", backup_directory, base);
+ sprintf(buff, "%s/%s_final.weights", backup_directory, base);
save_weights(net, buff);
}
-void convert_yoloplus_detections(float *predictions, int classes, int objectness, int background, int num_boxes, int w, int h, float thresh, float **probs, box *boxes)
+void convert_swag_detections(float *predictions, int classes, int num, int square, int side, int w, int h, float thresh, float **probs, box *boxes)
{
- int i,j;
- int per_box = 4+classes+(background || objectness);
- for (i = 0; i < num_boxes*num_boxes; ++i){
- float scale = 1;
- if(objectness) scale = 1-predictions[i*per_box];
- int offset = i*per_box+(background||objectness);
- for(j = 0; j < classes; ++j){
- float prob = scale*predictions[offset+j];
- probs[i][j] = (prob > thresh) ? prob : 0;
+ int i,j,n;
+ int per_cell = 5*num+classes;
+ for (i = 0; i < side*side; ++i){
+ int row = i / side;
+ int col = i % side;
+ for(n = 0; n < num; ++n){
+ int offset = i*per_cell + 5*n;
+ float scale = predictions[offset];
+ int index = i*num + n;
+ boxes[index].x = (predictions[offset + 1] + col) / side * w;
+ boxes[index].y = (predictions[offset + 2] + row) / side * h;
+ boxes[index].w = pow(predictions[offset + 3], (square?2:1)) * w;
+ boxes[index].h = pow(predictions[offset + 4], (square?2:1)) * h;
+ for(j = 0; j < classes; ++j){
+ offset = i*per_cell + 5*num;
+ float prob = scale*predictions[offset+j];
+ probs[index][j] = (prob > thresh) ? prob : 0;
+ }
}
- int row = i / num_boxes;
- int col = i % num_boxes;
- offset += classes;
- boxes[i].x = (predictions[offset + 0] + col) / num_boxes * w;
- boxes[i].y = (predictions[offset + 1] + row) / num_boxes * h;
- boxes[i].w = pow(predictions[offset + 2], 2) * w;
- boxes[i].h = pow(predictions[offset + 3], 2) * h;
}
}
-void print_yoloplus_detections(FILE **fps, char *id, box *boxes, float **probs, int num_boxes, int classes, int w, int h)
+void print_swag_detections(FILE **fps, char *id, box *boxes, float **probs, int total, int classes, int w, int h)
{
int i, j;
- for(i = 0; i < num_boxes*num_boxes; ++i){
+ for(i = 0; i < total; ++i){
float xmin = boxes[i].x - boxes[i].w/2.;
float xmax = boxes[i].x + boxes[i].w/2.;
float ymin = boxes[i].y - boxes[i].h/2.;
@@ -190,14 +174,13 @@
}
}
-void validate_yoloplus(char *cfgfile, char *weightfile)
+void validate_swag(char *cfgfile, char *weightfile)
{
network net = parse_network_cfg(cfgfile);
if(weightfile){
load_weights(&net, weightfile);
}
set_batch_network(&net, 1);
- detection_layer layer = get_network_detection_layer(net);
fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
srand(time(0));
@@ -205,10 +188,10 @@
list *plist = get_paths("/home/pjreddie/data/voc/test/2007_test.txt");
char **paths = (char **)list_to_array(plist);
- int classes = layer.classes;
- int objectness = layer.objectness;
- int background = layer.background;
- int num_boxes = sqrt(get_detection_layer_locations(layer));
+ layer l = net.layers[net.n-1];
+ int classes = l.classes;
+ int square = l.sqrt;
+ int side = l.side;
int j;
FILE **fps = calloc(classes, sizeof(FILE *));
@@ -217,9 +200,9 @@
snprintf(buff, 1024, "%s%s.txt", base, voc_names[j]);
fps[j] = fopen(buff, "w");
}
- box *boxes = calloc(num_boxes*num_boxes, sizeof(box));
- float **probs = calloc(num_boxes*num_boxes, sizeof(float *));
- for(j = 0; j < num_boxes*num_boxes; ++j) probs[j] = calloc(classes, sizeof(float *));
+ box *boxes = calloc(side*side*l.n, sizeof(box));
+ float **probs = calloc(side*side*l.n, sizeof(float *));
+ for(j = 0; j < side*side*l.n; ++j) probs[j] = calloc(classes, sizeof(float *));
int m = plist->size;
int i=0;
@@ -268,9 +251,9 @@
float *predictions = network_predict(net, X);
int w = val[t].w;
int h = val[t].h;
- convert_yoloplus_detections(predictions, classes, objectness, background, num_boxes, w, h, thresh, probs, boxes);
- if (nms) do_nms(boxes, probs, num_boxes*num_boxes, classes, iou_thresh);
- print_yoloplus_detections(fps, id, boxes, probs, num_boxes, classes, w, h);
+ convert_swag_detections(predictions, classes, l.n, square, side, w, h, thresh, probs, boxes);
+ if (nms) do_nms(boxes, probs, side*side*l.n, classes, iou_thresh);
+ print_swag_detections(fps, id, boxes, probs, side*side*l.n, classes, w, h);
free(id);
free_image(val[t]);
free_image(val_resized[t]);
@@ -279,7 +262,7 @@
fprintf(stderr, "Total Detection Time: %f Seconds\n", (double)(time(0) - start));
}
-void test_yoloplus(char *cfgfile, char *weightfile, char *filename, float thresh)
+void test_swag(char *cfgfile, char *weightfile, char *filename, float thresh)
{
network net = parse_network_cfg(cfgfile);
@@ -306,7 +289,7 @@
time=clock();
float *predictions = network_predict(net, X);
printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time));
- draw_yoloplus(im, predictions, 7, layer.objectness, "predictions", thresh);
+ draw_swag(im, predictions, 7, layer.objectness, "predictions", thresh);
free_image(im);
free_image(sized);
#ifdef OPENCV
@@ -317,7 +300,7 @@
}
}
-void run_yoloplus(int argc, char **argv)
+void run_swag(int argc, char **argv)
{
float thresh = find_float_arg(argc, argv, "-thresh", .2);
if(argc < 4){
@@ -328,7 +311,7 @@
char *cfg = argv[3];
char *weights = (argc > 4) ? argv[4] : 0;
char *filename = (argc > 5) ? argv[5]: 0;
- if(0==strcmp(argv[2], "test")) test_yoloplus(cfg, weights, filename, thresh);
- else if(0==strcmp(argv[2], "train")) train_yoloplus(cfg, weights);
- else if(0==strcmp(argv[2], "valid")) validate_yoloplus(cfg, weights);
+ if(0==strcmp(argv[2], "test")) test_swag(cfg, weights, filename, thresh);
+ else if(0==strcmp(argv[2], "train")) train_swag(cfg, weights);
+ else if(0==strcmp(argv[2], "valid")) validate_swag(cfg, weights);
}
--
Gitblit v1.10.0