From 0d6bb5d44d8e815ebf6ccce1dae2f83178780e7b Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Tue, 03 Dec 2013 00:41:40 +0000
Subject: [PATCH] Working?
---
src/data.c | 31 +++++++++++++++++++------------
1 files changed, 19 insertions(+), 12 deletions(-)
diff --git a/src/data.c b/src/data.c
index 7ef0d80..9e5791f 100644
--- a/src/data.c
+++ b/src/data.c
@@ -30,13 +30,18 @@
return lines;
}
-int get_truth(char *path)
+void fill_truth(char *path, char **labels, int k, double *truth)
{
- if(strstr(path, "dog")) return 1;
- return 0;
+ int i;
+ memset(truth, 0, k*sizeof(double));
+ for(i = 0; i < k; ++i){
+ if(strstr(path, labels[i])){
+ truth[i] = 1;
+ }
+ }
}
-batch load_list(list *paths)
+batch load_list(list *paths, char **labels, int k)
{
char *path;
batch data = make_batch(paths->size, 2);
@@ -45,16 +50,16 @@
for(i = 0; i < data.n; ++i){
path = (char *)n->val;
data.images[i] = load_image(path);
- data.truth[i][0] = get_truth(path);
+ fill_truth(path, labels, k, data.truth[i]);
n = n->next;
}
return data;
}
-batch get_all_data(char *filename)
+batch get_all_data(char *filename, char **labels, int k)
{
list *paths = get_paths(filename);
- batch b = load_list(paths);
+ batch b = load_list(paths, labels, k);
free_list_contents(paths);
free_list(paths);
return b;
@@ -71,7 +76,7 @@
free(b.truth);
}
-batch get_batch(char *filename, int curr, int total)
+batch get_batch(char *filename, int curr, int total, char **labels, int k)
{
list *plist = get_paths(filename);
char **paths = (char **)list_to_array(plist);
@@ -81,7 +86,7 @@
batch b = make_batch(end-start, 2);
for(i = start; i < end; ++i){
b.images[i-start] = load_image(paths[i]);
- b.truth[i-start][0] = get_truth(paths[i]);
+ fill_truth(paths[i], labels, k, b.truth[i-start]);
}
free_list_contents(plist);
free_list(plist);
@@ -89,7 +94,7 @@
return b;
}
-batch random_batch(char *filename, int n)
+batch random_batch(char *filename, int n, char **labels, int k)
{
list *plist = get_paths(filename);
char **paths = (char **)list_to_array(plist);
@@ -98,8 +103,10 @@
for(i = 0; i < n; ++i){
int index = rand()%plist->size;
b.images[i] = load_image(paths[index]);
- normalize_image(b.images[i]);
- b.truth[i][0] = get_truth(paths[index]);
+ //scale_image(b.images[i], 1./255.);
+ z_normalize_image(b.images[i]);
+ fill_truth(paths[index], labels, k, b.truth[i]);
+ //print_image(b.images[i]);
}
free_list_contents(plist);
free_list(plist);
--
Gitblit v1.10.0