From dea64611730c84a59c711c61f7f80948f82bcd31 Mon Sep 17 00:00:00 2001
From: Edmond Yoo <hj3yoo@uwaterloo.ca>
Date: Fri, 12 Oct 2018 20:12:47 +0000
Subject: [PATCH] Commit before removing YOLO

---
 transform_data.py |   21 ++++++++++-----------
 1 files changed, 10 insertions(+), 11 deletions(-)

diff --git a/transform_data.py b/transform_data.py
index dcfaed9..b22084d 100644
--- a/transform_data.py
+++ b/transform_data.py
@@ -15,7 +15,7 @@
 
 card_mask = cv2.imread('data/mask.png')
 data_dir = os.path.abspath('/media/win10/data')
-darknet_dir = os.path.abspath('darknet')
+darknet_dir = os.path.abspath('.')
 
 
 def key_pts_to_yolo(key_pts, w_img, h_img):
@@ -347,7 +347,8 @@
                 coords_in_gen = [card.coordinate_in_generator(key_pt[0], key_pt[1]) for key_pt in ext_obj.key_pts]
                 obj_yolo_info = key_pts_to_yolo(coords_in_gen, self.width, self.height)
                 if ext_obj.label == 'card':
-                    class_id = self.class_ids[card.info['name']]
+                    #class_id = self.class_ids[card.info['name']]
+                    class_id = 0
                     out_txt.write(str(class_id) + ' %.6f %.6f %.6f %.6f\n' % obj_yolo_info)
                     pass
                 elif ext_obj.label[:ext_obj.label.find[':']] == 'mana_symbol':
@@ -499,17 +500,15 @@
     #bg_images = [cv2.imread('data/frilly_0007.jpg')]
     background = generate_data.Backgrounds(images=bg_images)
 
-    #card_pool = pd.DataFrame()
-    #for set_name in fetch_data.all_set_list:
-    #    df = fetch_data.load_all_cards_text('%s/csv/%s.csv' % (data_dir, set_name))
-    #    card_pool = card_pool.append(df)
-    card_pool = fetch_data.load_all_cards_text('%s/csv/custom.csv' % data_dir)
+    card_pool = pd.DataFrame()
+    for set_name in fetch_data.all_set_list:
+        df = fetch_data.load_all_cards_text('%s/csv/%s.csv' % (data_dir, set_name))
+        card_pool = card_pool.append(df)
     class_ids = {}
     with open('%s/obj.names' % data_dir) as names_file:
         class_name_list = names_file.read().splitlines()
         for i in range(len(class_name_list)):
             class_ids[class_name_list[i]] = i
-    print(class_ids)
 
     num_gen = 60000
     num_iter = 1
@@ -545,15 +544,15 @@
 
             if i % 3 == 0:
                 generator.generate_non_obstructive()
-                generator.export_training_data(visibility=0.0, out_name='%s/train/non_obstructive_custom/%s_%d'
+                generator.export_training_data(visibility=0.0, out_name='%s/train/non_obstructive_update/%s%d'
                                                                         % (data_dir, out_name, j), aug=seq)
             elif i % 3 == 1:
                 generator.generate_horizontal_span(theta=random.uniform(-math.pi, math.pi))
-                generator.export_training_data(visibility=0.0, out_name='%s/train/horizontal_span_custom/%s_%d'
+                generator.export_training_data(visibility=0.0, out_name='%s/train/horizontal_span_update/%s%d'
                                                                         % (data_dir, out_name, j), aug=seq)
             else:
                 generator.generate_vertical_span(theta=random.uniform(-math.pi, math.pi))
-                generator.export_training_data(visibility=0.0, out_name='%s/train/vertical_span_custom/%s_%d'
+                generator.export_training_data(visibility=0.0, out_name='%s/train/vertical_span_update/%s%d'
                                                                         % (data_dir, out_name, j), aug=seq)
 
             #generator.generate_horizontal_span(theta=random.uniform(-math.pi, math.pi))

--
Gitblit v1.10.0