From 0d6bb5d44d8e815ebf6ccce1dae2f83178780e7b Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Tue, 03 Dec 2013 00:41:40 +0000
Subject: [PATCH] Working?

---
 src/data.c |   31 +++++++++++++++++++------------
 1 files changed, 19 insertions(+), 12 deletions(-)

diff --git a/src/data.c b/src/data.c
index 7ef0d80..9e5791f 100644
--- a/src/data.c
+++ b/src/data.c
@@ -30,13 +30,18 @@
     return lines;
 }
 
-int get_truth(char *path)
+void fill_truth(char *path, char **labels, int k, double *truth)
 {
-    if(strstr(path, "dog")) return 1;
-    return 0;
+    int i;
+    memset(truth, 0, k*sizeof(double));
+    for(i = 0; i < k; ++i){
+        if(strstr(path, labels[i])){
+            truth[i] = 1;
+        }
+    }
 }
 
-batch load_list(list *paths)
+batch load_list(list *paths, char **labels, int k)
 {
     char *path;
     batch data = make_batch(paths->size, 2);
@@ -45,16 +50,16 @@
     for(i = 0; i < data.n; ++i){
         path = (char *)n->val;
         data.images[i] = load_image(path);
-        data.truth[i][0] = get_truth(path);
+        fill_truth(path, labels, k, data.truth[i]);
         n = n->next;
     }
     return data;
 }
 
-batch get_all_data(char *filename)
+batch get_all_data(char *filename, char **labels, int k)
 {
     list *paths = get_paths(filename);
-    batch b = load_list(paths);
+    batch b = load_list(paths, labels, k);
     free_list_contents(paths);
     free_list(paths);
     return b;
@@ -71,7 +76,7 @@
     free(b.truth);
 }
 
-batch get_batch(char *filename, int curr, int total)
+batch get_batch(char *filename, int curr, int total, char **labels, int k)
 {
     list *plist = get_paths(filename);
     char **paths = (char **)list_to_array(plist);
@@ -81,7 +86,7 @@
     batch b = make_batch(end-start, 2);
     for(i = start; i < end; ++i){
         b.images[i-start] = load_image(paths[i]);
-        b.truth[i-start][0] = get_truth(paths[i]);
+        fill_truth(paths[i], labels, k, b.truth[i-start]);
     }
     free_list_contents(plist);
     free_list(plist);
@@ -89,7 +94,7 @@
     return b;
 }
 
-batch random_batch(char *filename, int n)
+batch random_batch(char *filename, int n, char **labels, int k)
 {
     list *plist = get_paths(filename);
     char **paths = (char **)list_to_array(plist);
@@ -98,8 +103,10 @@
     for(i = 0; i < n; ++i){
         int index = rand()%plist->size;
         b.images[i] = load_image(paths[index]);
-        normalize_image(b.images[i]);
-        b.truth[i][0] = get_truth(paths[index]);
+        //scale_image(b.images[i], 1./255.);
+        z_normalize_image(b.images[i]);
+        fill_truth(paths[index], labels, k, b.truth[i]);
+        //print_image(b.images[i]);
     }
     free_list_contents(plist);
     free_list(plist);

--
Gitblit v1.10.0