From 989ab8c38a02fa7ea9c25108151736c62e81c972 Mon Sep 17 00:00:00 2001
From: Joseph Redmon <pjreddie@gmail.com>
Date: Fri, 24 Apr 2015 17:27:50 +0000
Subject: [PATCH] IOU loss function
---
src/data.c | 45 +++++++++++++++++++++++++++++++++++----------
1 files changed, 35 insertions(+), 10 deletions(-)
diff --git a/src/data.c b/src/data.c
index 012d7cf..f1f5b80 100644
--- a/src/data.c
+++ b/src/data.c
@@ -65,22 +65,22 @@
return X;
}
-typedef struct box{
+typedef struct{
int id;
float x,y,w,h;
float left, right, top, bottom;
-} box;
+} box_label;
-box *read_boxes(char *filename, int *n)
+box_label *read_boxes(char *filename, int *n)
{
- box *boxes = calloc(1, sizeof(box));
+ box_label *boxes = calloc(1, sizeof(box_label));
FILE *file = fopen(filename, "r");
if(!file) file_error(filename);
float x, y, h, w;
int id;
int count = 0;
while(fscanf(file, "%d %f %f %f %f", &id, &x, &y, &w, &h) == 5){
- boxes = realloc(boxes, (count+1)*sizeof(box));
+ boxes = realloc(boxes, (count+1)*sizeof(box_label));
boxes[count].id = id;
boxes[count].x = x;
boxes[count].y = y;
@@ -97,11 +97,11 @@
return boxes;
}
-void randomize_boxes(box *b, int n)
+void randomize_boxes(box_label *b, int n)
{
int i;
for(i = 0; i < n; ++i){
- box swap = b[i];
+ box_label swap = b[i];
int index = rand_r(&data_seed)%n;
b[i] = b[index];
b[index] = swap;
@@ -114,7 +114,7 @@
labelpath = find_replace(labelpath, ".jpg", ".txt");
labelpath = find_replace(labelpath, ".JPEG", ".txt");
int count = 0;
- box *boxes = read_boxes(labelpath, &count);
+ box_label *boxes = read_boxes(labelpath, &count);
randomize_boxes(boxes, count);
float x,y,w,h;
float left, top, right, bot;
@@ -174,10 +174,10 @@
if(background) truth[index++] = 0;
truth[index+id] = 1;
index += classes;
- truth[index++] = y;
truth[index++] = x;
- truth[index++] = h;
+ truth[index++] = y;
truth[index++] = w;
+ truth[index++] = h;
}
free(boxes);
}
@@ -408,6 +408,31 @@
return thread;
}
+matrix concat_matrix(matrix m1, matrix m2)
+{
+ int i, count = 0;
+ matrix m;
+ m.cols = m1.cols;
+ m.rows = m1.rows+m2.rows;
+ m.vals = calloc(m1.rows + m2.rows, sizeof(float*));
+ for(i = 0; i < m1.rows; ++i){
+ m.vals[count++] = m1.vals[i];
+ }
+ for(i = 0; i < m2.rows; ++i){
+ m.vals[count++] = m2.vals[i];
+ }
+ return m;
+}
+
+data concat_data(data d1, data d2)
+{
+ data d;
+ d.shallow = 1;
+ d.X = concat_matrix(d1.X, d2.X);
+ d.y = concat_matrix(d1.y, d2.y);
+ return d;
+}
+
data load_categorical_data_csv(char *filename, int target, int k)
{
data d;
--
Gitblit v1.10.0