From f98efe6c32a064e77712f25fb40a673c3249cfd4 Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Fri, 10 Jul 2015 23:34:38 +0000
Subject: [PATCH] what happened?
---
src/detection.c | 143 +++++++++++++++++++++++++++++++++++++++++++++++
1 files changed, 143 insertions(+), 0 deletions(-)
diff --git a/src/detection.c b/src/detection.c
index e21e120..2553115 100644
--- a/src/detection.c
+++ b/src/detection.c
@@ -3,6 +3,7 @@
#include "cost_layer.h"
#include "utils.h"
#include "parser.h"
+#include "box.h"
char *class_names[] = {"aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"};
@@ -206,6 +207,147 @@
fprintf(stderr, "Total Detection Time: %f Seconds\n", (double)(time(0) - start));
}
+
+void convert_detections(float *predictions, int classes, int objectness, int background, int num_boxes, int w, int h, float thresh, float **probs, box *boxes)
+{
+ int i,j;
+ int per_box = 4+classes+(background || objectness);
+ for (i = 0; i < num_boxes*num_boxes; ++i){
+ float scale = 1;
+ if(objectness) scale = 1-predictions[i*per_box];
+ int offset = i*per_box+(background||objectness);
+ for(j = 0; j < classes; ++j){
+ float prob = scale*predictions[offset+j];
+ probs[i][j] = (prob > thresh) ? prob : 0;
+ }
+ int row = i / num_boxes;
+ int col = i % num_boxes;
+ offset += classes;
+ boxes[i].x = (predictions[offset + 0] + col) / num_boxes * w;
+ boxes[i].y = (predictions[offset + 1] + row) / num_boxes * h;
+ boxes[i].w = pow(predictions[offset + 2], 2) * w;
+ boxes[i].h = pow(predictions[offset + 3], 2) * h;
+ }
+}
+
+void do_nms(box *boxes, float **probs, int num_boxes, int classes, float thresh)
+{
+ int i, j, k;
+ for(i = 0; i < num_boxes*num_boxes; ++i){
+ int any = 0;
+ for(k = 0; k < classes; ++k) any = any || (probs[i][k] > 0);
+ if(!any) {
+ continue;
+ }
+ for(j = i+1; j < num_boxes*num_boxes; ++j){
+ if (box_iou(boxes[i], boxes[j]) > thresh){
+ for(k = 0; k < classes; ++k){
+ if (probs[i][k] < probs[j][k]) probs[i][k] = 0;
+ else probs[j][k] = 0;
+ }
+ }
+ }
+ }
+}
+
+void print_detections(FILE **fps, char *id, box *boxes, float **probs, int num_boxes, int classes, int w, int h)
+{
+ int i, j;
+ for(i = 0; i < num_boxes*num_boxes; ++i){
+ float xmin = boxes[i].x - boxes[i].w/2.;
+ float xmax = boxes[i].x + boxes[i].w/2.;
+ float ymin = boxes[i].y - boxes[i].h/2.;
+ float ymax = boxes[i].y + boxes[i].h/2.;
+
+ if (xmin < 0) xmin = 0;
+ if (ymin < 0) ymin = 0;
+ if (xmax > w) xmax = w;
+ if (ymax > h) ymax = h;
+
+ for(j = 0; j < classes; ++j){
+ if (probs[i][j]) fprintf(fps[j], "%s %f %f %f %f %f\n", id, probs[i][j],
+ xmin, ymin, xmax, ymax);
+ }
+ }
+}
+
+void valid_detection(char *cfgfile, char *weightfile)
+{
+ network net = parse_network_cfg(cfgfile);
+ if(weightfile){
+ load_weights(&net, weightfile);
+ }
+ set_batch_network(&net, 1);
+ detection_layer layer = get_network_detection_layer(net);
+ fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
+ srand(time(0));
+
+ char *base = "/home/pjreddie/data/voc/devkit/results/VOC2012/Main/comp4_det_test_";
+ list *plist = get_paths("/home/pjreddie/data/voc/test.txt");
+ char **paths = (char **)list_to_array(plist);
+
+ int classes = layer.classes;
+ int objectness = layer.objectness;
+ int background = layer.background;
+ int num_boxes = sqrt(get_detection_layer_locations(layer));
+
+ int j;
+ FILE **fps = calloc(classes, sizeof(FILE *));
+ for(j = 0; j < classes; ++j){
+ char buff[1024];
+ snprintf(buff, 1024, "%s%s.txt", base, class_names[j]);
+ fps[j] = fopen(buff, "w");
+ }
+ box *boxes = calloc(num_boxes*num_boxes, sizeof(box));
+ float **probs = calloc(num_boxes*num_boxes, sizeof(float *));
+ for(j = 0; j < num_boxes*num_boxes; ++j) probs[j] = calloc(classes, sizeof(float *));
+
+ int m = plist->size;
+ int i=0;
+ int t;
+
+ float thresh = .001;
+ int nms = 1;
+ float iou_thresh = .5;
+
+ int nthreads = 8;
+ image *val = calloc(nthreads, sizeof(image));
+ image *val_resized = calloc(nthreads, sizeof(image));
+ image *buf = calloc(nthreads, sizeof(image));
+ image *buf_resized = calloc(nthreads, sizeof(image));
+ pthread_t *thr = calloc(nthreads, sizeof(pthread_t));
+ for(t = 0; t < nthreads; ++t){
+ thr[t] = load_image_thread(paths[i+t], &buf[t], &buf_resized[t], net.w, net.h);
+ }
+ time_t start = time(0);
+ for(i = nthreads; i < m+nthreads; i += nthreads){
+ fprintf(stderr, "%d\n", i);
+ for(t = 0; t < nthreads && i+t-nthreads < m; ++t){
+ pthread_join(thr[t], 0);
+ val[t] = buf[t];
+ val_resized[t] = buf_resized[t];
+ }
+ for(t = 0; t < nthreads && i+t < m; ++t){
+ thr[t] = load_image_thread(paths[i+t], &buf[t], &buf_resized[t], net.w, net.h);
+ }
+ for(t = 0; t < nthreads && i+t-nthreads < m; ++t){
+ char *path = paths[i+t-nthreads];
+ char *id = basecfg(path);
+ float *X = val_resized[t].data;
+ float *predictions = network_predict(net, X);
+ int w = val[t].w;
+ int h = val[t].h;
+ convert_detections(predictions, classes, objectness, background, num_boxes, w, h, thresh, probs, boxes);
+ if (nms) do_nms(boxes, probs, num_boxes, classes, iou_thresh);
+ print_detections(fps, id, boxes, probs, num_boxes, classes, w, h);
+ free(id);
+ free_image(val[t]);
+ free_image(val_resized[t]);
+ }
+ }
+ fprintf(stderr, "Total Detection Time: %f Seconds\n", (double)(time(0) - start));
+}
+
void test_detection(char *cfgfile, char *weightfile, char *filename)
{
@@ -259,4 +401,5 @@
if(0==strcmp(argv[2], "test")) test_detection(cfg, weights, filename);
else if(0==strcmp(argv[2], "train")) train_detection(cfg, weights);
else if(0==strcmp(argv[2], "valid")) validate_detection(cfg, weights);
+ else if(0==strcmp(argv[2], "run")) valid_detection(cfg, weights);
}
--
Gitblit v1.10.0