From c7b10ceadb1a78e7480d281444a31ae2a7dc1b05 Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Fri, 06 May 2016 23:25:16 +0000
Subject: [PATCH] so much need to commit
---
src/captcha.c | 179 ++++++++++++++++++++++-------------------------------------
1 files changed, 68 insertions(+), 111 deletions(-)
diff --git a/src/captcha.c b/src/captcha.c
index ccefa45..79b4e4e 100644
--- a/src/captcha.c
+++ b/src/captcha.c
@@ -26,7 +26,7 @@
}
}
-void train_captcha2(char *cfgfile, char *weightfile)
+void train_captcha(char *cfgfile, char *weightfile)
{
data_seed = time(0);
srand(time(0));
@@ -38,16 +38,15 @@
load_weights(&net, weightfile);
}
printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
- //net.seen=0;
int imgs = 1024;
- int i = net.seen/imgs;
+ int i = *net.seen/imgs;
int solved = 1;
list *plist;
- char **labels = get_labels("/data/captcha/reimgs.labels2.list");
+ char **labels = get_labels("/data/captcha/reimgs.labels.list");
if (solved){
plist = get_paths("/data/captcha/reimgs.solved.list");
}else{
- plist = get_paths("/data/captcha/reimgs.train.list");
+ plist = get_paths("/data/captcha/reimgs.raw.list");
}
char **paths = (char **)list_to_array(plist);
printf("%d\n", plist->size);
@@ -55,7 +54,19 @@
pthread_t load_thread;
data train;
data buffer;
- load_thread = load_data_thread(paths, imgs, plist->size, labels, 26, net.w, net.h, &buffer);
+
+ load_args args = {0};
+ args.w = net.w;
+ args.h = net.h;
+ args.paths = paths;
+ args.classes = 26;
+ args.n = imgs;
+ args.m = plist->size;
+ args.labels = labels;
+ args.d = &buffer;
+ args.type = CLASSIFICATION_DATA;
+
+ load_thread = load_data_in_thread(args);
while(1){
++i;
time=clock();
@@ -69,107 +80,13 @@
cvWaitKey(0);
*/
- load_thread = load_data_thread(paths, imgs, plist->size, labels, 26, net.w, net.h, &buffer);
+ load_thread = load_data_in_thread(args);
printf("Loaded: %lf seconds\n", sec(clock()-time));
time=clock();
float loss = train_network(net, train);
- net.seen += imgs;
if(avg_loss == -1) avg_loss = loss;
avg_loss = avg_loss*.9 + loss*.1;
- printf("%d: %f, %f avg, %lf seconds, %d images\n", i, loss, avg_loss, sec(clock()-time), net.seen);
- free_data(train);
- if(i%100==0){
- char buff[256];
- sprintf(buff, "/home/pjreddie/imagenet_backup/%s_%d.weights",base, i);
- save_weights(net, buff);
- }
- }
-}
-
-void test_captcha2(char *cfgfile, char *weightfile, char *filename)
-{
- network net = parse_network_cfg(cfgfile);
- if(weightfile){
- load_weights(&net, weightfile);
- }
- set_batch_network(&net, 1);
- srand(2222222);
- int i = 0;
- char **names = get_labels("/data/captcha/reimgs.labels2.list");
- clock_t time;
- char input[256];
- int indexes[26];
- while(1){
- if(filename){
- strncpy(input, filename, 256);
- }else{
- //printf("Enter Image Path: ");
- //fflush(stdout);
- fgets(input, 256, stdin);
- strtok(input, "\n");
- }
- image im = load_image_color(input, net.w, net.h);
- float *X = im.data;
- time=clock();
- float *predictions = network_predict(net, X);
- top_predictions(net, 26, indexes);
- //printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time));
- for(i = 0; i < 26; ++i){
- int index = indexes[i];
- if(i != 0) printf(", ");
- printf("%s %f", names[index], predictions[index]);
- }
- printf("\n");
- fflush(stdout);
- free_image(im);
- if (filename) break;
- }
-}
-
-void train_captcha(char *cfgfile, char *weightfile)
-{
- data_seed = time(0);
- srand(time(0));
- float avg_loss = -1;
- char *base = basecfg(cfgfile);
- printf("%s\n", base);
- network net = parse_network_cfg(cfgfile);
- if(weightfile){
- load_weights(&net, weightfile);
- }
- printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
- //net.seen=0;
- int imgs = 1024;
- int i = net.seen/imgs;
- char **labels = get_labels("/data/captcha/reimgs.labels.list");
- list *plist = get_paths("/data/captcha/reimgs.train.list");
- char **paths = (char **)list_to_array(plist);
- printf("%d\n", plist->size);
- clock_t time;
- pthread_t load_thread;
- data train;
- data buffer;
- load_thread = load_data_thread(paths, imgs, plist->size, labels, 13, net.w, net.h, &buffer);
- while(1){
- ++i;
- time=clock();
- pthread_join(load_thread, 0);
- train = buffer;
-
- /*
- image im = float_to_image(256, 256, 3, train.X.vals[114]);
- show_image(im, "training");
- cvWaitKey(0);
- */
-
- load_thread = load_data_thread(paths, imgs, plist->size, labels, 13, net.w, net.h, &buffer);
- printf("Loaded: %lf seconds\n", sec(clock()-time));
- time=clock();
- float loss = train_network(net, train);
- net.seen += imgs;
- if(avg_loss == -1) avg_loss = loss;
- avg_loss = avg_loss*.9 + loss*.1;
- printf("%d: %f, %f avg, %lf seconds, %d images\n", i, loss, avg_loss, sec(clock()-time), net.seen);
+ printf("%d: %f, %f avg, %lf seconds, %d images\n", i, loss, avg_loss, sec(clock()-time), *net.seen);
free_data(train);
if(i%100==0){
char buff[256];
@@ -189,25 +106,25 @@
srand(2222222);
int i = 0;
char **names = get_labels("/data/captcha/reimgs.labels.list");
- clock_t time;
- char input[256];
- int indexes[13];
+ char buff[256];
+ char *input = buff;
+ int indexes[26];
while(1){
if(filename){
strncpy(input, filename, 256);
}else{
//printf("Enter Image Path: ");
//fflush(stdout);
- fgets(input, 256, stdin);
+ input = fgets(input, 256, stdin);
+ if(!input) return;
strtok(input, "\n");
}
image im = load_image_color(input, net.w, net.h);
float *X = im.data;
- time=clock();
float *predictions = network_predict(net, X);
- top_predictions(net, 13, indexes);
+ top_predictions(net, 26, indexes);
//printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time));
- for(i = 0; i < 13; ++i){
+ for(i = 0; i < 26; ++i){
int index = indexes[i];
if(i != 0) printf(", ");
printf("%s %f", names[index], predictions[index]);
@@ -219,7 +136,46 @@
}
}
+void valid_captcha(char *cfgfile, char *weightfile, char *filename)
+{
+ char **labels = get_labels("/data/captcha/reimgs.labels.list");
+ network net = parse_network_cfg(cfgfile);
+ if(weightfile){
+ load_weights(&net, weightfile);
+ }
+ list *plist = get_paths("/data/captcha/reimgs.fg.list");
+ char **paths = (char **)list_to_array(plist);
+ int N = plist->size;
+ int outputs = net.outputs;
+ set_batch_network(&net, 1);
+ srand(2222222);
+ int i, j;
+ for(i = 0; i < N; ++i){
+ if (i%100 == 0) fprintf(stderr, "%d\n", i);
+ image im = load_image_color(paths[i], net.w, net.h);
+ float *X = im.data;
+ float *predictions = network_predict(net, X);
+ //printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time));
+ int truth = -1;
+ for(j = 0; j < 13; ++j){
+ if (strstr(paths[i], labels[j])) truth = j;
+ }
+ if (truth == -1){
+ fprintf(stderr, "bad: %s\n", paths[i]);
+ return;
+ }
+ printf("%d, ", truth);
+ for(j = 0; j < outputs; ++j){
+ if (j != 0) printf(", ");
+ printf("%f", predictions[j]);
+ }
+ printf("\n");
+ fflush(stdout);
+ free_image(im);
+ if (filename) break;
+ }
+}
/*
void train_captcha(char *cfgfile, char *weightfile)
@@ -398,8 +354,9 @@
char *cfg = argv[3];
char *weights = (argc > 4) ? argv[4] : 0;
char *filename = (argc > 5) ? argv[5]: 0;
- if(0==strcmp(argv[2], "train")) train_captcha2(cfg, weights);
- else if(0==strcmp(argv[2], "test")) test_captcha2(cfg, weights, filename);
+ if(0==strcmp(argv[2], "train")) train_captcha(cfg, weights);
+ else if(0==strcmp(argv[2], "test")) test_captcha(cfg, weights, filename);
+ else if(0==strcmp(argv[2], "valid")) valid_captcha(cfg, weights, filename);
//if(0==strcmp(argv[2], "test")) test_captcha(cfg, weights);
//else if(0==strcmp(argv[2], "encode")) encode_captcha(cfg, weights);
//else if(0==strcmp(argv[2], "decode")) decode_captcha(cfg, weights);
--
Gitblit v1.10.0