From 37d7c1e79f65a75caf87e29a562d30c51cd654e5 Mon Sep 17 00:00:00 2001
From: Joe Redmon <pjreddie@gmail.com>
Date: Thu, 26 Nov 2015 21:52:56 +0000
Subject: [PATCH] fixed label linking

---
 src/data.c |   90 +++++++++++++++++++++++++++++++++------------
 1 files changed, 66 insertions(+), 24 deletions(-)

diff --git a/src/data.c b/src/data.c
index ec2b304..9b84c5a 100644
--- a/src/data.c
+++ b/src/data.c
@@ -54,7 +54,12 @@
     X.cols = 0;
 
     for(i = 0; i < n; ++i){
-        image im = load_image(paths[i], w, h, 1);
+        image im = load_image(paths[i], w, h, 3);
+
+        image gray = grayscale_image(im);
+        free_image(im);
+        im = gray;
+
         X.vals[i] = im.data;
         X.cols = im.h*im.w*im.c;
     }
@@ -148,7 +153,9 @@
 {
     char *labelpath = find_replace(path, "images", "labels");
     labelpath = find_replace(labelpath, "JPEGImages", "labels");
+
     labelpath = find_replace(labelpath, ".jpg", ".txt");
+    labelpath = find_replace(labelpath, ".JPG", ".txt");
     labelpath = find_replace(labelpath, ".JPEG", ".txt");
     int count = 0;
     box_label *boxes = read_boxes(labelpath, &count);
@@ -176,8 +183,10 @@
         int index = (col+row*num_boxes)*(5+classes);
         if (truth[index]) continue;
         truth[index++] = 1;
-        if (classes) truth[index+id] = 1;
+
+        if (id < classes) truth[index+id] = 1;
         index += classes;
+
         truth[index++] = x;
         truth[index++] = y;
         truth[index++] = w;
@@ -359,7 +368,7 @@
     }
 }
 
-data load_data_region(int n, char **paths, int m, int w, int h, int size, int classes)
+data load_data_region(int n, char **paths, int m, int w, int h, int size, int classes, float jitter)
 {
     char **random_paths = get_random_paths(paths, n, m);
     int i;
@@ -378,8 +387,8 @@
         int oh = orig.h;
         int ow = orig.w;
 
-        int dw = ow/10;
-        int dh = oh/10;
+        int dw = (ow*jitter);
+        int dh = (oh*jitter);
 
         int pleft  = (rand_uniform() * 2*dw - dw);
         int pright = (rand_uniform() * 2*dw - dw);
@@ -413,8 +422,8 @@
 
 data load_data_compare(int n, char **paths, int m, int classes, int w, int h)
 {
-    char **random_paths = get_random_paths(paths, 2*n, m);
-    int i;
+    if(m) paths = get_random_paths(paths, 2*n, m);
+    int i,j;
     data d;
     d.shallow = 0;
 
@@ -425,20 +434,51 @@
     int k = 2*(classes);
     d.y = make_matrix(n, k);
     for(i = 0; i < n; ++i){
-        image im1 = load_image_color(random_paths[i*2],   w, h);
-        image im2 = load_image_color(random_paths[i*2+1], w, h);
+        image im1 = load_image_color(paths[i*2],   w, h);
+        image im2 = load_image_color(paths[i*2+1], w, h);
 
         d.X.vals[i] = calloc(d.X.cols, sizeof(float));
         memcpy(d.X.vals[i],         im1.data, h*w*3*sizeof(float));
         memcpy(d.X.vals[i] + h*w*3, im2.data, h*w*3*sizeof(float));
 
-        //char *imlabel1 = find_replace(random_paths[i*2],   "imgs", "labels");
-        //char *imlabel2 = find_replace(random_paths[i*2+1], "imgs", "labels");
+        int id;
+        float iou;
+
+        char *imlabel1 = find_replace(paths[i*2],   "imgs", "labels");
+        imlabel1 = find_replace(imlabel1, "jpg", "txt");
+        FILE *fp1 = fopen(imlabel1, "r");
+
+        while(fscanf(fp1, "%d %f", &id, &iou) == 2){
+            if (d.y.vals[i][2*id] < iou) d.y.vals[i][2*id] = iou;
+        }
+
+        char *imlabel2 = find_replace(paths[i*2+1], "imgs", "labels");
+        imlabel2 = find_replace(imlabel2, "jpg", "txt");
+        FILE *fp2 = fopen(imlabel2, "r");
+
+        while(fscanf(fp2, "%d %f", &id, &iou) == 2){
+            if (d.y.vals[i][2*id + 1] < iou) d.y.vals[i][2*id + 1] = iou;
+        }
+        
+        for (j = 0; j < classes; ++j){
+            if (d.y.vals[i][2*j] > .5 &&  d.y.vals[i][2*j+1] < .5){
+                d.y.vals[i][2*j] = 1;
+                d.y.vals[i][2*j+1] = 0;
+            } else if (d.y.vals[i][2*j] < .5 &&  d.y.vals[i][2*j+1] > .5){
+                d.y.vals[i][2*j] = 0;
+                d.y.vals[i][2*j+1] = 1;
+            } else {
+                d.y.vals[i][2*j]   = SECRET_NUM;
+                d.y.vals[i][2*j+1] = SECRET_NUM;
+            }
+        }
+        fclose(fp1);
+        fclose(fp2);
 
         free_image(im1);
         free_image(im2);
     }
-    free(random_paths);
+    if(m) free(paths);
     return d;
 }
 
@@ -503,20 +543,24 @@
 
 void *load_thread(void *ptr)
 {
-    
-    #ifdef GPU
-        cudaError_t status = cudaSetDevice(gpu_index);
-        check_error(status);
-    #endif
 
-    printf("Loading data: %d\n", rand_r(&data_seed));
+#ifdef GPU
+    cudaError_t status = cudaSetDevice(gpu_index);
+    check_error(status);
+#endif
+
+    //printf("Loading data: %d\n", rand_r(&data_seed));
     load_args a = *(struct load_args*)ptr;
     if (a.type == CLASSIFICATION_DATA){
         *a.d = load_data(a.paths, a.n, a.m, a.labels, a.classes, a.w, a.h);
     } else if (a.type == DETECTION_DATA){
         *a.d = load_data_detection(a.n, a.paths, a.m, a.classes, a.w, a.h, a.num_boxes, a.background);
+    } else if (a.type == WRITING_DATA){
+        *a.d = load_data_writing(a.paths, a.n, a.m, a.w, a.h, a.out_w, a.out_h);
     } else if (a.type == REGION_DATA){
-        *a.d = load_data_region(a.n, a.paths, a.m, a.w, a.h, a.num_boxes, a.classes);
+        *a.d = load_data_region(a.n, a.paths, a.m, a.w, a.h, a.num_boxes, a.classes, a.jitter);
+    } else if (a.type == COMPARE_DATA){
+        *a.d = load_data_compare(a.n, a.paths, a.m, a.classes, a.w, a.h);
     } else if (a.type == IMAGE_DATA){
         *(a.im) = load_image_color(a.path, 0, 0);
         *(a.resized) = resize_image(*(a.im), a.w, a.h);
@@ -530,20 +574,18 @@
     pthread_t thread;
     struct load_args *ptr = calloc(1, sizeof(struct load_args));
     *ptr = args;
-    if(pthread_create(&thread, 0, load_thread, ptr)) {
-        error("Thread creation failed");
-    }
+    if(pthread_create(&thread, 0, load_thread, ptr)) error("Thread creation failed");
     return thread;
 }
 
-data load_data_writing(char **paths, int n, int m, int w, int h)
+data load_data_writing(char **paths, int n, int m, int w, int h, int out_w, int out_h)
 {
     if(m) paths = get_random_paths(paths, n, m);
     char **replace_paths = find_replace_paths(paths, n, ".png", "-label.png");
     data d;
     d.shallow = 0;
     d.X = load_image_paths(paths, n, w, h);
-    d.y = load_image_paths_gray(replace_paths, n, w/8, h/8);
+    d.y = load_image_paths_gray(replace_paths, n, out_w, out_h);
     if(m) free(paths);
     int i;
     for(i = 0; i < n; ++i) free(replace_paths[i]);

--
Gitblit v1.10.0