From ae43c2bc32fbb838bfebeeaf2c2b058ccab5c83c Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@burninator.cs.washington.edu>
Date: Thu, 23 Jun 2016 05:31:14 +0000
Subject: [PATCH] hi
---
src/writing.c | 110 +++++++++++++++++++++++++++++-------------------------
1 files changed, 59 insertions(+), 51 deletions(-)
diff --git a/src/writing.c b/src/writing.c
index 71dd53b..32e4f6d 100644
--- a/src/writing.c
+++ b/src/writing.c
@@ -25,16 +25,18 @@
clock_t time;
int N = plist->size;
printf("N: %d\n", N);
+ image out = get_network_image(net);
data train, buffer;
load_args args = {0};
args.w = net.w;
args.h = net.h;
+ args.out_w = out.w;
+ args.out_h = out.h;
args.paths = paths;
args.n = imgs;
args.m = N;
- args.downsample = 1;
args.d = &buffer;
args.type = WRITING_DATA;
@@ -51,9 +53,9 @@
float loss = train_network(net, train);
/*
- image pred = float_to_image(64, 64, 1, out);
- print_image(pred);
- */
+ image pred = float_to_image(64, 64, 1, out);
+ print_image(pred);
+ */
/*
image im = float_to_image(256, 256, 3, train.X.vals[0]);
@@ -69,22 +71,22 @@
if(avg_loss == -1) avg_loss = loss;
avg_loss = avg_loss*.9 + loss*.1;
printf("%d, %.3f: %f, %f avg, %f rate, %lf seconds, %d images\n", get_current_batch(net), (float)(*net.seen)/N, loss, avg_loss, get_current_rate(net), sec(clock()-time), *net.seen);
- free_data(train);
- if(get_current_batch(net)%100 == 0){
- char buff[256];
- sprintf(buff, "%s/%s_batch_%d.weights", backup_directory, base, get_current_batch(net));
- save_weights(net, buff);
- }
- if(*net.seen/N > epoch){
- epoch = *net.seen/N;
- char buff[256];
- sprintf(buff, "%s/%s_%d.weights",backup_directory,base, epoch);
- save_weights(net, buff);
- }
+ free_data(train);
+ if(get_current_batch(net)%100 == 0){
+ char buff[256];
+ sprintf(buff, "%s/%s_batch_%d.weights", backup_directory, base, get_current_batch(net));
+ save_weights(net, buff);
+ }
+ if(*net.seen/N > epoch){
+ epoch = *net.seen/N;
+ char buff[256];
+ sprintf(buff, "%s/%s_%d.weights",backup_directory,base, epoch);
+ save_weights(net, buff);
+ }
}
}
-void test_writing(char *cfgfile, char *weightfile, char *outfile)
+void test_writing(char *cfgfile, char *weightfile, char *filename)
{
network net = parse_network_cfg(cfgfile);
if(weightfile){
@@ -93,51 +95,57 @@
set_batch_network(&net, 1);
srand(2222222);
clock_t time;
- char filename[256];
+ char buff[256];
+ char *input = buff;
+ while(1){
+ if(filename){
+ strncpy(input, filename, 256);
+ }else{
+ printf("Enter Image Path: ");
+ fflush(stdout);
+ input = fgets(input, 256, stdin);
+ if(!input) return;
+ strtok(input, "\n");
+ }
- fgets(filename, 256, stdin);
- strtok(filename, "\n");
- image im = load_image_color(filename, 0, 0);
- //image im = load_image_color("/home/pjreddie/darknet/data/figs/C02-1001-Figure-1.png", 0, 0);
- image sized = resize_image(im, net.w, net.h);
- printf("%d %d %d\n", im.h, im.w, im.c);
- float *X = sized.data;
- time=clock();
- network_predict(net, X);
- printf("%s: Predicted in %f seconds.\n", filename, sec(clock()-time));
- image pred = get_network_image(net);
+ image im = load_image_color(input, 0, 0);
+ resize_network(&net, im.w, im.h);
+ printf("%d %d %d\n", im.h, im.w, im.c);
+ float *X = im.data;
+ time=clock();
+ network_predict(net, X);
+ printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time));
+ image pred = get_network_image(net);
- image t = threshold_image(pred, .5);
- free_image(pred);
- pred = t;
+ image upsampled = resize_image(pred, im.w, im.h);
+ image thresh = threshold_image(upsampled, .5);
+ pred = thresh;
- if (outfile) {
- printf("Save image as %s.png (shape: %d %d)\n", outfile, pred.w, pred.h);
- save_image(pred, outfile);
- } else {
- show_image(sized, "orig");
show_image(pred, "prediction");
+ show_image(im, "orig");
#ifdef OPENCV
- cvWaitKey(0);
- cvDestroyAllWindows();
+ cvWaitKey(0);
+ cvDestroyAllWindows();
#endif
- }
- free_image(im);
- free_image(sized);
+ free_image(upsampled);
+ free_image(thresh);
+ free_image(im);
+ if (filename) break;
+ }
}
void run_writing(int argc, char **argv)
{
- if(argc < 4){
- fprintf(stderr, "usage: %s %s [train/test/valid] [cfg] [weights (optional)]\n", argv[0], argv[1]);
- return;
- }
+ if(argc < 4){
+ fprintf(stderr, "usage: %s %s [train/test/valid] [cfg] [weights (optional)]\n", argv[0], argv[1]);
+ return;
+ }
- char *cfg = argv[3];
- char *weights = (argc > 4) ? argv[4] : 0;
- char *outfile = (argc > 5) ? argv[5] : 0;
- if(0==strcmp(argv[2], "train")) train_writing(cfg, weights);
- else if(0==strcmp(argv[2], "test")) test_writing(cfg, weights, outfile);
+ char *cfg = argv[3];
+ char *weights = (argc > 4) ? argv[4] : 0;
+ char *filename = (argc > 5) ? argv[5] : 0;
+ if(0==strcmp(argv[2], "train")) train_writing(cfg, weights);
+ else if(0==strcmp(argv[2], "test")) test_writing(cfg, weights, filename);
}
--
Gitblit v1.10.0