From ae43c2bc32fbb838bfebeeaf2c2b058ccab5c83c Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@burninator.cs.washington.edu>
Date: Thu, 23 Jun 2016 05:31:14 +0000
Subject: [PATCH] hi
---
src/captcha.c | 307 ++++++++++++++++++++++++++++++++++++++++----------
1 files changed, 244 insertions(+), 63 deletions(-)
diff --git a/src/captcha.c b/src/captcha.c
index 40a4082..79b4e4e 100644
--- a/src/captcha.c
+++ b/src/captcha.c
@@ -2,11 +2,35 @@
#include "utils.h"
#include "parser.h"
+void fix_data_captcha(data d, int mask)
+{
+ matrix labels = d.y;
+ int i, j;
+ for(i = 0; i < d.y.rows; ++i){
+ for(j = 0; j < d.y.cols; j += 2){
+ if (mask){
+ if(!labels.vals[i][j]){
+ labels.vals[i][j] = SECRET_NUM;
+ labels.vals[i][j+1] = SECRET_NUM;
+ }else if(labels.vals[i][j+1]){
+ labels.vals[i][j] = 0;
+ }
+ } else{
+ if (labels.vals[i][j]) {
+ labels.vals[i][j+1] = 0;
+ } else {
+ labels.vals[i][j+1] = 1;
+ }
+ }
+ }
+ }
+}
void train_captcha(char *cfgfile, char *weightfile)
{
- float avg_loss = -1;
+ data_seed = time(0);
srand(time(0));
+ float avg_loss = -1;
char *base = basecfg(cfgfile);
printf("%s\n", base);
network net = parse_network_cfg(cfgfile);
@@ -15,24 +39,54 @@
}
printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
int imgs = 1024;
- int i = net.seen/imgs;
- list *plist = get_paths("/data/captcha/train.base");
+ int i = *net.seen/imgs;
+ int solved = 1;
+ list *plist;
+ char **labels = get_labels("/data/captcha/reimgs.labels.list");
+ if (solved){
+ plist = get_paths("/data/captcha/reimgs.solved.list");
+ }else{
+ plist = get_paths("/data/captcha/reimgs.raw.list");
+ }
char **paths = (char **)list_to_array(plist);
printf("%d\n", plist->size);
clock_t time;
+ pthread_t load_thread;
+ data train;
+ data buffer;
+
+ load_args args = {0};
+ args.w = net.w;
+ args.h = net.h;
+ args.paths = paths;
+ args.classes = 26;
+ args.n = imgs;
+ args.m = plist->size;
+ args.labels = labels;
+ args.d = &buffer;
+ args.type = CLASSIFICATION_DATA;
+
+ load_thread = load_data_in_thread(args);
while(1){
++i;
time=clock();
- data train = load_data_captcha(paths, imgs, plist->size, 10, 60, 200);
- translate_data_rows(train, -128);
- scale_data_rows(train, 1./128);
+ pthread_join(load_thread, 0);
+ train = buffer;
+ fix_data_captcha(train, solved);
+
+ /*
+ image im = float_to_image(256, 256, 3, train.X.vals[114]);
+ show_image(im, "training");
+ cvWaitKey(0);
+ */
+
+ load_thread = load_data_in_thread(args);
printf("Loaded: %lf seconds\n", sec(clock()-time));
time=clock();
float loss = train_network(net, train);
- net.seen += imgs;
if(avg_loss == -1) avg_loss = loss;
avg_loss = avg_loss*.9 + loss*.1;
- printf("%d: %f, %f avg, %lf seconds, %d images\n", i, loss, avg_loss, sec(clock()-time), net.seen);
+ printf("%d: %f, %f avg, %lf seconds, %d images\n", i, loss, avg_loss, sec(clock()-time), *net.seen);
free_data(train);
if(i%100==0){
char buff[256];
@@ -42,67 +96,190 @@
}
}
-void decode_captcha(char *cfgfile, char *weightfile)
+void test_captcha(char *cfgfile, char *weightfile, char *filename)
{
- setbuf(stdout, NULL);
- srand(time(0));
network net = parse_network_cfg(cfgfile);
- set_batch_network(&net, 1);
if(weightfile){
load_weights(&net, weightfile);
}
- char filename[256];
+ set_batch_network(&net, 1);
+ srand(2222222);
+ int i = 0;
+ char **names = get_labels("/data/captcha/reimgs.labels.list");
+ char buff[256];
+ char *input = buff;
+ int indexes[26];
while(1){
- printf("Enter filename: ");
- fgets(filename, 256, stdin);
- strtok(filename, "\n");
- image im = load_image_color(filename, 60, 200);
- scale_image(im, 1./255.);
+ if(filename){
+ strncpy(input, filename, 256);
+ }else{
+ //printf("Enter Image Path: ");
+ //fflush(stdout);
+ input = fgets(input, 256, stdin);
+ if(!input) return;
+ strtok(input, "\n");
+ }
+ image im = load_image_color(input, net.w, net.h);
float *X = im.data;
float *predictions = network_predict(net, X);
- image out = float_to_image(60, 200, 3, predictions);
- show_image(out, "decoded");
- cvWaitKey(0);
+ top_predictions(net, 26, indexes);
+ //printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time));
+ for(i = 0; i < 26; ++i){
+ int index = indexes[i];
+ if(i != 0) printf(", ");
+ printf("%s %f", names[index], predictions[index]);
+ }
+ printf("\n");
+ fflush(stdout);
free_image(im);
+ if (filename) break;
}
}
+void valid_captcha(char *cfgfile, char *weightfile, char *filename)
+{
+ char **labels = get_labels("/data/captcha/reimgs.labels.list");
+ network net = parse_network_cfg(cfgfile);
+ if(weightfile){
+ load_weights(&net, weightfile);
+ }
+ list *plist = get_paths("/data/captcha/reimgs.fg.list");
+ char **paths = (char **)list_to_array(plist);
+ int N = plist->size;
+ int outputs = net.outputs;
+
+ set_batch_network(&net, 1);
+ srand(2222222);
+ int i, j;
+ for(i = 0; i < N; ++i){
+ if (i%100 == 0) fprintf(stderr, "%d\n", i);
+ image im = load_image_color(paths[i], net.w, net.h);
+ float *X = im.data;
+ float *predictions = network_predict(net, X);
+ //printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time));
+ int truth = -1;
+ for(j = 0; j < 13; ++j){
+ if (strstr(paths[i], labels[j])) truth = j;
+ }
+ if (truth == -1){
+ fprintf(stderr, "bad: %s\n", paths[i]);
+ return;
+ }
+ printf("%d, ", truth);
+ for(j = 0; j < outputs; ++j){
+ if (j != 0) printf(", ");
+ printf("%f", predictions[j]);
+ }
+ printf("\n");
+ fflush(stdout);
+ free_image(im);
+ if (filename) break;
+ }
+}
+
+/*
+ void train_captcha(char *cfgfile, char *weightfile)
+ {
+ float avg_loss = -1;
+ srand(time(0));
+ char *base = basecfg(cfgfile);
+ printf("%s\n", base);
+ network net = parse_network_cfg(cfgfile);
+ if(weightfile){
+ load_weights(&net, weightfile);
+ }
+ printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
+ int imgs = 1024;
+ int i = net.seen/imgs;
+ list *plist = get_paths("/data/captcha/train.auto5");
+ char **paths = (char **)list_to_array(plist);
+ printf("%d\n", plist->size);
+ clock_t time;
+ while(1){
+ ++i;
+ time=clock();
+ data train = load_data_captcha(paths, imgs, plist->size, 10, 200, 60);
+ translate_data_rows(train, -128);
+ scale_data_rows(train, 1./128);
+ printf("Loaded: %lf seconds\n", sec(clock()-time));
+ time=clock();
+ float loss = train_network(net, train);
+ net.seen += imgs;
+ if(avg_loss == -1) avg_loss = loss;
+ avg_loss = avg_loss*.9 + loss*.1;
+ printf("%d: %f, %f avg, %lf seconds, %d images\n", i, loss, avg_loss, sec(clock()-time), net.seen);
+ free_data(train);
+ if(i%10==0){
+ char buff[256];
+ sprintf(buff, "/home/pjreddie/imagenet_backup/%s_%d.weights",base, i);
+ save_weights(net, buff);
+ }
+ }
+ }
+
+ void decode_captcha(char *cfgfile, char *weightfile)
+ {
+ setbuf(stdout, NULL);
+ srand(time(0));
+ network net = parse_network_cfg(cfgfile);
+ set_batch_network(&net, 1);
+ if(weightfile){
+ load_weights(&net, weightfile);
+ }
+ char filename[256];
+ while(1){
+ printf("Enter filename: ");
+ fgets(filename, 256, stdin);
+ strtok(filename, "\n");
+ image im = load_image_color(filename, 300, 57);
+ scale_image(im, 1./255.);
+ float *X = im.data;
+ float *predictions = network_predict(net, X);
+ image out = float_to_image(300, 57, 1, predictions);
+ show_image(out, "decoded");
+#ifdef OPENCV
+cvWaitKey(0);
+#endif
+free_image(im);
+}
+}
+
void encode_captcha(char *cfgfile, char *weightfile)
{
- float avg_loss = -1;
- srand(time(0));
- char *base = basecfg(cfgfile);
- printf("%s\n", base);
- network net = parse_network_cfg(cfgfile);
- if(weightfile){
- load_weights(&net, weightfile);
+float avg_loss = -1;
+srand(time(0));
+char *base = basecfg(cfgfile);
+printf("%s\n", base);
+network net = parse_network_cfg(cfgfile);
+if(weightfile){
+ load_weights(&net, weightfile);
+}
+printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
+int imgs = 1024;
+int i = net.seen/imgs;
+list *plist = get_paths("/data/captcha/encode.list");
+char **paths = (char **)list_to_array(plist);
+printf("%d\n", plist->size);
+clock_t time;
+while(1){
+ ++i;
+ time=clock();
+ data train = load_data_captcha_encode(paths, imgs, plist->size, 300, 57);
+ scale_data_rows(train, 1./255);
+ printf("Loaded: %lf seconds\n", sec(clock()-time));
+ time=clock();
+ float loss = train_network(net, train);
+ net.seen += imgs;
+ if(avg_loss == -1) avg_loss = loss;
+ avg_loss = avg_loss*.9 + loss*.1;
+ printf("%d: %f, %f avg, %lf seconds, %d images\n", i, loss, avg_loss, sec(clock()-time), net.seen);
+ free_matrix(train.X);
+ if(i%100==0){
+ char buff[256];
+ sprintf(buff, "/home/pjreddie/imagenet_backup/%s_%d.weights",base, i);
+ save_weights(net, buff);
}
- printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
- int imgs = 1024;
- int i = net.seen/imgs;
- list *plist = get_paths("/data/captcha/encode.list");
- char **paths = (char **)list_to_array(plist);
- printf("%d\n", plist->size);
- clock_t time;
- while(1){
- ++i;
- time=clock();
- data train = load_data_captcha_encode(paths, imgs, plist->size, 60, 200);
- scale_data_rows(train, 1./255);
- printf("Loaded: %lf seconds\n", sec(clock()-time));
- time=clock();
- float loss = train_network(net, train);
- net.seen += imgs;
- if(avg_loss == -1) avg_loss = loss;
- avg_loss = avg_loss*.9 + loss*.1;
- printf("%d: %f, %f avg, %lf seconds, %d images\n", i, loss, avg_loss, sec(clock()-time), net.seen);
- free_matrix(train.X);
- if(i%100==0){
- char buff[256];
- sprintf(buff, "/home/pjreddie/imagenet_backup/%s_%d.weights",base, i);
- save_weights(net, buff);
- }
- }
+}
}
void validate_captcha(char *cfgfile, char *weightfile)
@@ -114,11 +291,11 @@
if(weightfile){
load_weights(&net, weightfile);
}
- int imgs = 1000;
int numchars = 37;
- list *plist = get_paths("/data/captcha/valid.base");
+ list *plist = get_paths("/data/captcha/solved.hard");
char **paths = (char **)list_to_array(plist);
- data valid = load_data_captcha(paths, imgs, 0, 10, 60, 200);
+ int imgs = plist->size;
+ data valid = load_data_captcha(paths, imgs, 0, 10, 200, 60);
translate_data_rows(valid, -128);
scale_data_rows(valid, 1./128);
matrix pred = network_predict_data(net, valid);
@@ -157,7 +334,7 @@
//printf("Enter filename: ");
fgets(filename, 256, stdin);
strtok(filename, "\n");
- image im = load_image_color(filename, 60, 200);
+ image im = load_image_color(filename, 200, 60);
translate_image(im, -128);
scale_image(im, 1/128.);
float *X = im.data;
@@ -166,6 +343,7 @@
free_image(im);
}
}
+ */
void run_captcha(int argc, char **argv)
{
if(argc < 4){
@@ -175,10 +353,13 @@
char *cfg = argv[3];
char *weights = (argc > 4) ? argv[4] : 0;
- if(0==strcmp(argv[2], "test")) test_captcha(cfg, weights);
- else if(0==strcmp(argv[2], "train")) train_captcha(cfg, weights);
- else if(0==strcmp(argv[2], "encode")) encode_captcha(cfg, weights);
- else if(0==strcmp(argv[2], "decode")) decode_captcha(cfg, weights);
- else if(0==strcmp(argv[2], "valid")) validate_captcha(cfg, weights);
+ char *filename = (argc > 5) ? argv[5]: 0;
+ if(0==strcmp(argv[2], "train")) train_captcha(cfg, weights);
+ else if(0==strcmp(argv[2], "test")) test_captcha(cfg, weights, filename);
+ else if(0==strcmp(argv[2], "valid")) valid_captcha(cfg, weights, filename);
+ //if(0==strcmp(argv[2], "test")) test_captcha(cfg, weights);
+ //else if(0==strcmp(argv[2], "encode")) encode_captcha(cfg, weights);
+ //else if(0==strcmp(argv[2], "decode")) decode_captcha(cfg, weights);
+ //else if(0==strcmp(argv[2], "valid")) validate_captcha(cfg, weights);
}
--
Gitblit v1.10.0