From d1965bdb969920c85f72785ec6e1f3d7bda957de Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Mon, 14 Mar 2016 06:18:42 +0000
Subject: [PATCH] Go
---
src/classifier.c | 93 +++++++++++++++++++++++++++++++++++++++-------
1 files changed, 79 insertions(+), 14 deletions(-)
diff --git a/src/classifier.c b/src/classifier.c
index fdbe534..2e974a5 100644
--- a/src/classifier.c
+++ b/src/classifier.c
@@ -3,6 +3,7 @@
#include "parser.h"
#include "option_list.h"
#include "blas.h"
+#include <sys/time.h>
#ifdef OPENCV
#include "opencv2/highgui/highgui_c.h"
@@ -239,8 +240,8 @@
}
int w = net.w;
int h = net.h;
- image im = load_image_color(paths[i], w, h);
int shift = 32;
+ image im = load_image_color(paths[i], w+shift, h+shift);
image images[10];
images[0] = crop_image(im, -shift, -shift, w, h);
images[1] = crop_image(im, shift, -shift, w, h);
@@ -299,6 +300,7 @@
float avg_topk = 0;
int *indexes = calloc(topk, sizeof(int));
+ int size = net.w;
for(i = 0; i < m; ++i){
int class = -1;
char *path = paths[i];
@@ -309,13 +311,15 @@
}
}
image im = load_image_color(paths[i], 0, 0);
- resize_network(&net, im.w, im.h);
+ image resized = resize_min(im, size);
+ resize_network(&net, resized.w, resized.h);
//show_image(im, "orig");
//show_image(crop, "cropped");
//cvWaitKey(0);
- float *pred = network_predict(net, im.data);
+ float *pred = network_predict(net, resized.data);
free_image(im);
+ free_image(resized);
top_k(pred, classes, topk, indexes);
if(indexes[0] == class) avg_acc += 1;
@@ -406,7 +410,7 @@
char **labels = get_labels(label_list);
list *plist = get_paths(valid_list);
- int scales[] = {224, 256, 384, 480, 512};
+ int scales[] = {192, 224, 288, 320, 352};
int nscales = sizeof(scales)/sizeof(scales[0]);
char **paths = (char **)list_to_array(plist);
@@ -429,16 +433,8 @@
float *pred = calloc(classes, sizeof(float));
image im = load_image_color(paths[i], 0, 0);
for(j = 0; j < nscales; ++j){
- int w, h;
- if(im.w < im.h){
- w = scales[j];
- h = (im.h*w)/im.w;
- } else {
- h = scales[j];
- w = (im.w * h) / im.h;
- }
- resize_network(&net, w, h);
- image r = resize_image(im, w, h);
+ image r = resize_min(im, scales[j]);
+ resize_network(&net, r.w, r.h);
float *p = network_predict(net, r.data);
axpy_cpu(classes, 1, p, 1, pred, 1);
flip_image(r);
@@ -577,6 +573,73 @@
}
+void demo_classifier(char *datacfg, char *cfgfile, char *weightfile, int cam_index, const char *filename)
+{
+#ifdef OPENCV
+ printf("Classifier Demo\n");
+ network net = parse_network_cfg(cfgfile);
+ if(weightfile){
+ load_weights(&net, weightfile);
+ }
+ set_batch_network(&net, 1);
+ list *options = read_data_cfg(datacfg);
+
+ srand(2222222);
+ CvCapture * cap;
+
+ if(filename){
+ cap = cvCaptureFromFile(filename);
+ }else{
+ cap = cvCaptureFromCAM(cam_index);
+ }
+
+ int top = option_find_int(options, "top", 1);
+
+ char *name_list = option_find_str(options, "names", 0);
+ char **names = get_labels(name_list);
+
+ int *indexes = calloc(top, sizeof(int));
+
+ if(!cap) error("Couldn't connect to webcam.\n");
+ cvNamedWindow("Classifier", CV_WINDOW_NORMAL);
+ cvResizeWindow("Classifier", 512, 512);
+ float fps = 0;
+ int i;
+
+ while(1){
+ struct timeval tval_before, tval_after, tval_result;
+ gettimeofday(&tval_before, NULL);
+
+ image in = get_image_from_stream(cap);
+ image in_s = resize_image(in, net.w, net.h);
+ show_image(in, "Classifier");
+
+ float *predictions = network_predict(net, in_s.data);
+ top_predictions(net, top, indexes);
+
+ printf("\033[2J");
+ printf("\033[1;1H");
+ printf("\nFPS:%.0f\n",fps);
+
+ for(i = 0; i < top; ++i){
+ int index = indexes[i];
+ printf("%.1f%%: %s\n", predictions[index]*100, names[index]);
+ }
+
+ free_image(in_s);
+ free_image(in);
+
+ cvWaitKey(10);
+
+ gettimeofday(&tval_after, NULL);
+ timersub(&tval_after, &tval_before, &tval_result);
+ float curr = 1000000.f/((long int)tval_result.tv_usec);
+ fps = .9*fps + .1*curr;
+ }
+#endif
+}
+
+
void run_classifier(int argc, char **argv)
{
if(argc < 4){
@@ -584,6 +647,7 @@
return;
}
+ int cam_index = find_int_arg(argc, argv, "-c", 0);
char *data = argv[3];
char *cfg = argv[4];
char *weights = (argc > 5) ? argv[5] : 0;
@@ -592,6 +656,7 @@
int layer = layer_s ? atoi(layer_s) : -1;
if(0==strcmp(argv[2], "predict")) predict_classifier(data, cfg, weights, filename);
else if(0==strcmp(argv[2], "train")) train_classifier(data, cfg, weights);
+ else if(0==strcmp(argv[2], "demo")) demo_classifier(data, cfg, weights, cam_index, filename);
else if(0==strcmp(argv[2], "test")) test_classifier(data, cfg, weights, layer);
else if(0==strcmp(argv[2], "valid")) validate_classifier(data, cfg, weights);
else if(0==strcmp(argv[2], "valid10")) validate_classifier_10(data, cfg, weights);
--
Gitblit v1.10.0