From dea64611730c84a59c711c61f7f80948f82bcd31 Mon Sep 17 00:00:00 2001
From: Edmond Yoo <hj3yoo@uwaterloo.ca>
Date: Fri, 12 Oct 2018 20:12:47 +0000
Subject: [PATCH] Commit before removing YOLO
---
transform_data.py | 21 ++++++++++-----------
1 files changed, 10 insertions(+), 11 deletions(-)
diff --git a/transform_data.py b/transform_data.py
index dcfaed9..b22084d 100644
--- a/transform_data.py
+++ b/transform_data.py
@@ -15,7 +15,7 @@
card_mask = cv2.imread('data/mask.png')
data_dir = os.path.abspath('/media/win10/data')
-darknet_dir = os.path.abspath('darknet')
+darknet_dir = os.path.abspath('.')
def key_pts_to_yolo(key_pts, w_img, h_img):
@@ -347,7 +347,8 @@
coords_in_gen = [card.coordinate_in_generator(key_pt[0], key_pt[1]) for key_pt in ext_obj.key_pts]
obj_yolo_info = key_pts_to_yolo(coords_in_gen, self.width, self.height)
if ext_obj.label == 'card':
- class_id = self.class_ids[card.info['name']]
+ #class_id = self.class_ids[card.info['name']]
+ class_id = 0
out_txt.write(str(class_id) + ' %.6f %.6f %.6f %.6f\n' % obj_yolo_info)
pass
elif ext_obj.label[:ext_obj.label.find[':']] == 'mana_symbol':
@@ -499,17 +500,15 @@
#bg_images = [cv2.imread('data/frilly_0007.jpg')]
background = generate_data.Backgrounds(images=bg_images)
- #card_pool = pd.DataFrame()
- #for set_name in fetch_data.all_set_list:
- # df = fetch_data.load_all_cards_text('%s/csv/%s.csv' % (data_dir, set_name))
- # card_pool = card_pool.append(df)
- card_pool = fetch_data.load_all_cards_text('%s/csv/custom.csv' % data_dir)
+ card_pool = pd.DataFrame()
+ for set_name in fetch_data.all_set_list:
+ df = fetch_data.load_all_cards_text('%s/csv/%s.csv' % (data_dir, set_name))
+ card_pool = card_pool.append(df)
class_ids = {}
with open('%s/obj.names' % data_dir) as names_file:
class_name_list = names_file.read().splitlines()
for i in range(len(class_name_list)):
class_ids[class_name_list[i]] = i
- print(class_ids)
num_gen = 60000
num_iter = 1
@@ -545,15 +544,15 @@
if i % 3 == 0:
generator.generate_non_obstructive()
- generator.export_training_data(visibility=0.0, out_name='%s/train/non_obstructive_custom/%s_%d'
+ generator.export_training_data(visibility=0.0, out_name='%s/train/non_obstructive_update/%s%d'
% (data_dir, out_name, j), aug=seq)
elif i % 3 == 1:
generator.generate_horizontal_span(theta=random.uniform(-math.pi, math.pi))
- generator.export_training_data(visibility=0.0, out_name='%s/train/horizontal_span_custom/%s_%d'
+ generator.export_training_data(visibility=0.0, out_name='%s/train/horizontal_span_update/%s%d'
% (data_dir, out_name, j), aug=seq)
else:
generator.generate_vertical_span(theta=random.uniform(-math.pi, math.pi))
- generator.export_training_data(visibility=0.0, out_name='%s/train/vertical_span_custom/%s_%d'
+ generator.export_training_data(visibility=0.0, out_name='%s/train/vertical_span_update/%s%d'
% (data_dir, out_name, j), aug=seq)
#generator.generate_horizontal_span(theta=random.uniform(-math.pi, math.pi))
--
Gitblit v1.10.0