From 23d94e4846bf4ec13069703a28b1d776f4bbe44f Mon Sep 17 00:00:00 2001
From: Edmond Yoo <hj3yoo@uwaterloo.ca>
Date: Sat, 13 Oct 2018 18:49:47 +0000
Subject: [PATCH] Cleaning & commenting #3 - refactoring constants to Config class
---
transform_data.py | 43 ++++++++++++++++++++++---------------------
1 files changed, 22 insertions(+), 21 deletions(-)
diff --git a/transform_data.py b/transform_data.py
index 6b5f477..8e681a7 100644
--- a/transform_data.py
+++ b/transform_data.py
@@ -1,20 +1,18 @@
-import os
-import random
-import math
import cv2
-import numpy as np
-import imutils
-import pandas as pd
-import fetch_data
-import generate_data
-from shapely import geometry
import imgaug as ia
from imgaug import augmenters as iaa
from imgaug import parameters as iap
+import imutils
+import math
+import numpy as np
+import os
+import pandas as pd
+import random
+from shapely import geometry
-card_mask = cv2.imread('data/mask.png')
-data_dir = os.path.abspath('/media/win10/data')
-darknet_dir = os.path.abspath('.')
+import fetch_data
+import generate_data
+from config import Config
def key_pts_to_yolo(key_pts, w_img, h_img):
@@ -104,6 +102,7 @@
"""
self.check_visibility(visibility=visibility)
img_result = np.zeros((self.height, self.width, 3), dtype=np.uint8)
+ card_mask = cv2.imread(Config.card_mask_path)
for card in self.cards:
card_x = int(card.x + 0.5)
@@ -494,15 +493,15 @@
random.seed()
ia.seed(random.randrange(10000))
- bg_images = generate_data.load_dtd(dtd_dir='%s/dtd/images' % data_dir, dump_it=False)
+ bg_images = generate_data.load_dtd(dtd_dir='%s/dtd/images' % Config.data_dir, dump_it=False)
background = generate_data.Backgrounds(images=bg_images)
card_pool = pd.DataFrame()
- for set_name in fetch_data.all_set_list:
- df = fetch_data.load_all_cards_text('%s/csv/%s.csv' % (data_dir, set_name))
+ for set_name in Config.all_set_list:
+ df = fetch_data.load_all_cards_text('%s/csv/%s.csv' % (Config.data_dir, set_name))
card_pool = card_pool.append(df)
class_ids = {}
- with open('%s/obj.names' % data_dir) as names_file:
+ with open('%s/obj.names' % Config.data_dir) as names_file:
class_name_list = names_file.read().splitlines()
for i in range(len(class_name_list)):
class_ids[class_name_list[i]] = i
@@ -522,12 +521,14 @@
# Use 2 to 5 cards per generator
for _, card_info in card_pool.sample(random.randint(2, 5)).iterrows():
- img_name = '%s/card_img/png/%s/%s_%s.png' % (data_dir, card_info['set'], card_info['collector_number'],
+ img_name = '%s/card_img/png/%s/%s_%s.png' % (Config.data_dir, card_info['set'],
+ card_info['collector_number'],
fetch_data.get_valid_filename(card_info['name']))
out_name += '%s%s_' % (card_info['set'], card_info['collector_number'])
card_img = cv2.imread(img_name)
if card_img is None:
- fetch_data.fetch_card_image(card_info, out_dir='%s/card_img/png/%s' % (data_dir, card_info['set']))
+ fetch_data.fetch_card_image(card_info, out_dir='%s/card_img/png/%s' % (Config.data_dir,
+ card_info['set']))
card_img = cv2.imread(img_name)
if card_img is None:
print('WARNING: card %s is not found!' % img_name)
@@ -547,15 +548,15 @@
if i % 3 == 0:
generator.generate_non_obstructive()
generator.export_training_data(visibility=0.0, out_name='%s/train/non_obstructive_update/%s%d'
- % (data_dir, out_name, j), aug=seq)
+ % (Config.data_dir, out_name, j), aug=seq)
elif i % 3 == 1:
generator.generate_horizontal_span(theta=random.uniform(-math.pi, math.pi))
generator.export_training_data(visibility=0.0, out_name='%s/train/horizontal_span_update/%s%d'
- % (data_dir, out_name, j), aug=seq)
+ % (Config.data_dir, out_name, j), aug=seq)
else:
generator.generate_vertical_span(theta=random.uniform(-math.pi, math.pi))
generator.export_training_data(visibility=0.0, out_name='%s/train/vertical_span_update/%s%d'
- % (data_dir, out_name, j), aug=seq)
+ % (Config.data_dir, out_name, j), aug=seq)
#generator.generate_horizontal_span(theta=random.uniform(-math.pi, math.pi))
#generator.render(display=True, aug=seq, debug=True)
--
Gitblit v1.10.0