From 68213b835b9f15cb449ad2037a8b51c17a3de07b Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Mon, 14 Mar 2016 22:10:14 +0000
Subject: [PATCH] Makefile
---
src/data.c | 176 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++--
1 files changed, 170 insertions(+), 6 deletions(-)
diff --git a/src/data.c b/src/data.c
index 88c8991..4d52d11 100644
--- a/src/data.c
+++ b/src/data.c
@@ -82,6 +82,32 @@
return X;
}
+/*
+ * Build a matrix with one row per entry of `paths`.
+ *
+ * Each image is loaded with load_image_color, passed through
+ * random_crop_image(im, min, max, size), and flip_image is applied to the
+ * crop when the low bit of rand_r(&data_seed) is set.  The crop's pixel
+ * buffer becomes the matrix row directly (no copy); the loaded image is
+ * released.  X.cols is the h*w*c of the last crop, or 0 when n is 0.
+ */
+matrix load_image_cropped_paths(char **paths, int n, int min, int max, int size)
+{
+    matrix X;
+    int idx = 0;
+
+    X.cols = 0;
+    X.rows = n;
+    X.vals = calloc(X.rows, sizeof(float*));
+
+    while(idx < n){
+        image full = load_image_color(paths[idx], 0, 0);
+        image patch = random_crop_image(full, min, max, size);
+
+        /* Drawn after the crop so the sequence of random calls is unchanged. */
+        if(rand_r(&data_seed)%2 != 0){
+            flip_image(patch);
+        }
+        free_image(full);
+
+        /* The row takes over the crop's buffer, so the crop is not freed. */
+        X.cols = patch.h*patch.w*patch.c;
+        X.vals[idx] = patch.data;
+        ++idx;
+    }
+    return X;
+}
+
+
box_label *read_boxes(char *filename, int *n)
{
box_label *boxes = calloc(1, sizeof(box_label));
@@ -386,6 +412,33 @@
return y;
}
+/*
+ * Build an n x k tag matrix for the images in `paths`.
+ *
+ * For each path the label file name is derived by replacing "imgs" with
+ * "labels" and "_iconl.jpeg" with ".txt".  If that file cannot be opened,
+ * the name with "labels" replaced by "labels2" is tried; if that fails too
+ * the row is left exactly as make_matrix produced it.
+ *
+ * Every integer read from the label file that is a valid column index
+ * (0 <= tag < k) sets y.vals[i][tag] to 1.  Reading stops at the first token
+ * that is not an integer.
+ *
+ * Prints "<label files opened>/<n>" to stdout before returning.
+ */
+matrix load_tags_paths(char **paths, int n, int k)
+{
+    matrix y = make_matrix(n, k);
+    int i;
+    int count = 0;
+    for(i = 0; i < n; ++i){
+        char *label = find_replace(paths[i], "imgs", "labels");
+        label = find_replace(label, "_iconl.jpeg", ".txt");
+        FILE *file = fopen(label, "r");
+        if(!file){
+            label = find_replace(label, "labels", "labels2");
+            file = fopen(label, "r");
+            if(!file) continue;
+        }
+        ++count;
+        int tag;
+        while(fscanf(file, "%d", &tag) == 1){
+            /* Negative ids must be rejected too: like ids >= k they would
+             * index outside the k-wide row. */
+            if(tag >= 0 && tag < k){
+                y.vals[i][tag] = 1;
+            }
+        }
+        fclose(file);
+    }
+    printf("%d/%d\n", count, n);
+    return y;
+}
+
char **get_labels(char *filename)
{
list *plist = get_paths(filename);
@@ -641,8 +694,10 @@
//printf("Loading data: %d\n", rand_r(&data_seed));
load_args a = *(struct load_args*)ptr;
- if (a.type == CLASSIFICATION_DATA){
+ if (a.type == OLD_CLASSIFICATION_DATA){
*a.d = load_data(a.paths, a.n, a.m, a.labels, a.classes, a.w, a.h);
+ } else if (a.type == CLASSIFICATION_DATA){
+ *a.d = load_data_augment(a.paths, a.n, a.m, a.labels, a.classes, a.min, a.max, a.size);
} else if (a.type == DETECTION_DATA){
*a.d = load_data_detection(a.n, a.paths, a.m, a.classes, a.w, a.h, a.num_boxes, a.background);
} else if (a.type == WRITING_DATA){
@@ -656,6 +711,9 @@
} else if (a.type == IMAGE_DATA){
*(a.im) = load_image_color(a.path, 0, 0);
*(a.resized) = resize_image(*(a.im), a.w, a.h);
+ } else if (a.type == TAG_DATA){
+ *a.d = load_data_tag(a.paths, a.n, a.m, a.classes, a.min, a.max, a.size);
+ //*a.d = load_data(a.paths, a.n, a.m, a.labels, a.classes, a.w, a.h);
}
free(ptr);
return 0;
@@ -696,6 +754,30 @@
return d;
}
+/*
+ * Assemble a classification batch of n cropped images and their labels.
+ *
+ * When m is non-zero, n paths are first drawn with get_random_paths(paths,
+ * n, m) and that temporary list is freed before returning; otherwise the
+ * first n entries of `paths` are used as given.  X comes from
+ * load_image_cropped_paths(paths, n, min, max, size) and y from
+ * load_labels_paths(paths, n, labels, k).
+ *
+ * The result is zero-initialised first, so every field not set here
+ * (including w and h) is 0 rather than indeterminate.
+ */
+data load_data_augment(char **paths, int n, int m, char **labels, int k, int min, int max, int size)
+{
+    if(m) paths = get_random_paths(paths, n, m);
+    data d = {0};
+    d.shallow = 0;
+    d.X = load_image_cropped_paths(paths, n, min, max, size);
+    d.y = load_labels_paths(paths, n, labels, k);
+    if(m) free(paths);
+    return d;
+}
+
+/*
+ * Assemble a tagging batch: n cropped images plus their n x k tag rows.
+ *
+ * A non-zero m means n paths are drawn through get_random_paths(paths, n, m)
+ * and that drawn list is freed before returning; with m == 0 the caller's
+ * list is used directly.  w and h of the result are both set to `size`.
+ */
+data load_data_tag(char **paths, int n, int m, int k, int min, int max, int size)
+{
+    char **batch = paths;
+    const int sampled = (m != 0);
+    data d = {0};
+
+    if(sampled){
+        batch = get_random_paths(paths, n, m);
+    }
+
+    d.shallow = 0;
+    d.h = size;
+    d.w = size;
+    d.X = load_image_cropped_paths(batch, n, min, max, size);
+    d.y = load_tags_paths(batch, n, k);
+
+    /* Only the sampled list belongs to this function. */
+    if(sampled){
+        free(batch);
+    }
+    return d;
+}
+
matrix concat_matrix(matrix m1, matrix m2)
{
int i, count = 0;
@@ -759,8 +841,8 @@
X.vals[i][j] = (double)bytes[j+1];
}
}
- translate_data_rows(d, -128);
- scale_data_rows(d, 1./128);
+ //translate_data_rows(d, -128);
+ scale_data_rows(d, 1./255);
//normalize_data_rows(d);
fclose(fp);
return d;
@@ -786,6 +868,17 @@
}
}
+/*
+ * Blend every entry of d.y in place with a uniform value:
+ *     y = eps * (1 / d.y.cols) + (1 - eps) * y,   with eps fixed at 0.1.
+ * d is passed by value, but the rows it points to are modified.
+ */
+void smooth_data(data d)
+{
+    float eps = .1;
+    float scale = 1. / d.y.cols;
+    int r;
+
+    for(r = 0; r < d.y.rows; ++r){
+        float *target = d.y.vals[r];
+        int c = 0;
+        while(c < d.y.cols){
+            target[c] = eps * scale + (1-eps) * target[c];
+            ++c;
+        }
+    }
+}
data load_all_cifar10()
{
@@ -800,7 +893,7 @@
for(b = 0; b < 5; ++b){
char buff[256];
- sprintf(buff, "data/cifar10/data_batch_%d.bin", b+1);
+ sprintf(buff, "data/cifar/cifar-10-batches-bin/data_batch_%d.bin", b+1);
FILE *fp = fopen(buff, "rb");
if(!fp) file_error(buff);
for(i = 0; i < 10000; ++i){
@@ -815,11 +908,59 @@
fclose(fp);
}
//normalize_data_rows(d);
- translate_data_rows(d, -128);
- scale_data_rows(d, 1./128);
+ //translate_data_rows(d, -128);
+ scale_data_rows(d, 1./255);
+ // smooth_data(d);
return d;
}
+/*
+ * Read Go training records from `filename` into a data set.
+ *
+ * The file is consumed two lines at a time through fgetl: a label line
+ * holding "row col", then a board line of 19*19 characters.  The label sets
+ * column row*19 + col of the y row to 1.  In the X row each board character
+ * becomes 1 for '1', -1 for '2' and 0 for anything else.
+ *
+ * Malformed input is handled instead of being dereferenced blindly:
+ *   - a label that does not parse as two integers, or whose move index falls
+ *     outside the 361-wide row, causes the whole record to be skipped;
+ *   - a board line shorter than 361 characters is padded with 0;
+ *   - a final label with no board line ends the read.
+ *
+ * Both matrices are trimmed to the number of records kept.  Fields of the
+ * result other than shallow, X and y are zero.
+ */
+data load_go(char *filename)
+{
+    FILE *fp = fopen(filename, "rb");
+    if(!fp) file_error(filename);
+
+    /* Initial row capacity; both matrices are doubled below when it fills. */
+    matrix X = make_matrix(3363059, 361);
+    matrix y = make_matrix(3363059, 361);
+    int row, col;
+    char *label;
+    int count = 0;
+
+    while((label = fgetl(fp))){
+        int i;
+        char *board = fgetl(fp);
+        if(!board){
+            /* Label without a board line: the file is truncated. */
+            free(label);
+            break;
+        }
+
+        /* Widened so an absurd row value cannot overflow the multiply. */
+        long long index = -1;
+        if(sscanf(label, "%d %d", &row, &col) == 2){
+            index = (long long)row*19 + col;
+        }
+        if(index < 0 || index >= 19*19){
+            /* Storing this move would write outside the y row. */
+            free(label);
+            free(board);
+            continue;
+        }
+
+        if(count == X.rows){
+            X = resize_matrix(X, count*2);
+            y = resize_matrix(y, count*2);
+        }
+
+        y.vals[count][index] = 1;
+
+        /* Stop reading at the terminator so a short line is never overrun;
+         * the remaining cells are written as 0. */
+        int ended = 0;
+        for(i = 0; i < 19*19; ++i){
+            float val = 0;
+            if(!ended){
+                if(board[i] == '\0') ended = 1;
+                else if(board[i] == '1') val = 1;
+                else if(board[i] == '2') val = -1;
+            }
+            X.vals[count][i] = val;
+        }
+        ++count;
+        free(label);
+        free(board);
+    }
+    X = resize_matrix(X, count);
+    y = resize_matrix(y, count);
+
+    data d = {0};
+    d.shallow = 0;
+    d.X = X;
+    d.y = y;
+
+    fclose(fp);
+
+    return d;
+}
+
+
void randomize_data(data d)
{
int i;
@@ -859,6 +1000,29 @@
}
}
+/*
+ * Draw `num` rows from d, each chosen independently with rand() % d.X.rows
+ * (so the same row may be picked more than once).
+ *
+ * The result is marked shallow: its X and y row pointers refer to the rows
+ * of d rather than to copies.  Only the two pointer arrays are newly
+ * allocated.  Column counts are taken from d.
+ */
+data get_random_data(data d, int num)
+{
+    data r = {0};
+    matrix *xs = &r.X;
+    matrix *ys = &r.y;
+    int k;
+
+    r.shallow = 1;
+
+    xs->rows = num;
+    ys->rows = num;
+    xs->cols = d.X.cols;
+    ys->cols = d.y.cols;
+
+    xs->vals = calloc(num, sizeof(float *));
+    ys->vals = calloc(num, sizeof(float *));
+
+    for(k = 0; k < num; ++k){
+        /* One draw per output row; X and y use the same pick so they stay paired. */
+        const int pick = rand()%d.X.rows;
+        xs->vals[k] = d.X.vals[pick];
+        ys->vals[k] = d.y.vals[pick];
+    }
+    return r;
+}
+
data *split_data(data d, int part, int total)
{
data *split = calloc(2, sizeof(data));
--
Gitblit v1.10.0