Constantin Wenger
2019-08-10 f9d5508010c4e67e9b1af6bb8347ba2a3023fa78
generate_data.py
@@ -1,18 +1,24 @@
from glob import glob
import matplotlib.pyplot as plt
import matplotlib.image as mpimage
import pickle
import math
import random
import os
import re
import cv2
import fetch_data
import sys
from glob import glob
import math
import matplotlib.pyplot as plt
import matplotlib.image as mpimage
import numpy as np
import os
import pandas as pd
import pickle
import random
import transform_data
# Referenced from geaxgx's playing-card-detection: https://github.com/geaxgx/playing-card-detection
from config import Config
class Backgrounds:
    """
    Container class for all background images for generator
    Referenced from geaxgx's playing-card-detection: https://github.com/geaxgx/playing-card-detection
    """
    def __init__(self, images=None, dumps_dir='data/dtd/images'):
        if images is not None:
            self._images = images
@@ -38,8 +44,15 @@
def load_dtd(dtd_dir='data/dtd/images', dump_it=True, dump_batch_size=1000):
    """
    Load Describable Texture Dataset (DTD) from local
    :param dtd_dir: path of the DTD images folder
    :param dump_it: flag for pickling it
    :param dump_batch_size: # of images stored per pickle file
    :return: list of all DTD images
    """
    if not os.path.exists(dtd_dir):
        print('Warning: directory for DTD 5s doesn\'t exist.' % dtd_dir)
        print('Warning: directory for DTD %s doesn\'t exist.' % dtd_dir)
        print('You can download the dataset using this command:'
              '!wget https://www.robots.ox.ac.uk/~vgg/data/dtd/download/dtd-r1.0.1.tar.gz')
        return []
@@ -62,46 +75,62 @@
# Matches each "{...}" token of a mana cost string, e.g. "{2}{W/U}" -> ['2', 'W/U'].
# Raw string: "\{" in a normal string literal is an invalid escape sequence.
_MANA_SYMBOL_RE = re.compile(r'\{(.*?)\}')


def _preview_mana_symbols(img, card_info):
    """
    Crop each mana-cost symbol from the card image and show it in an OpenCV window, one at a time.
    The crops are for visual inspection only; they are not part of the returned feature list.
    :param img: image of the card
    :param card_info: characteristics of this card; reads 'mana_cost' and 'type_line'
    :return: None
    """
    # Mana symbol - They are located on the top right side of the card, next to the name.
    # Their position is stationary, and is right-aligned.
    mana_cost = card_info['mana_cost']
    if not isinstance(mana_cost, str):  # Cards with no mana cost will have nan
        return
    # Planeswalker frames use a smaller top offset for the symbol row
    y1 = 50 if 'Planeswalker' in card_info['type_line'] else 67
    x2 = 683  # starting right edge; moves left after each symbol
    # Walk the symbols right-to-left because the row is right-aligned
    for symbol in reversed(_MANA_SYMBOL_RE.findall(mana_cost)):
        if '/' in symbol:
            # Hybrid symbols get a larger crop box
            (bx1, by1), (bx2, by2) = (x2 - 47, y1 - 8), (x2 + 2, y1 + 43)
            x2 -= 45
        else:
            (bx1, by1), (bx2, by2) = (x2 - 39, y1), (x2, y1 + 41)
            x2 -= 37
        cv2.imshow('symbol', img[by1:by2, bx1:bx2])
        cv2.waitKey(0)  # waits indefinitely for a key press


def apply_bounding_box(img, card_info, display=False):
    """
    Given a card image, extract specific features that can be used to train a model.
    Note: Mana & set symbols are deprecated from the feature list. Refer to previous commits for their implementation:
    https://github.com/hj3yoo/mtg_card_detector/tree/bb34d4e13da0f4753fbdefee837f54b16149d3ef
    :param img: image of the card, indexed as img[row][col] (e.g. the array returned by cv2.imread)
    :param card_info: characteristics of this card
    :param display: if True, preview each mana symbol crop in an OpenCV window (waits for a key press per symbol);
                    card_info must then provide 'mana_cost' and 'type_line'
    :return: list of transform_data.ExtractedObject; currently only the entire card, with corners
             (0, 0), (width, 0), (width, height), (0, height)
    """
    if display:
        _preview_mana_symbols(img, card_info)
    height, width = len(img), len(img[0])
    # List of detected objects to be fed into the neural net
    # The first object is the entire card
    detected_object_list = [transform_data.ExtractedObject('card', [(0, 0), (width, 0), (width, height),
                                                                    (0, height)])]
    return detected_object_list
def main():
    """
    Load the card lists of every set in Config.all_set_list, make sure each card image is available
    locally (fetching it when missing), and print the objects extracted from each image.
    """
    # Build the pool with one concat: DataFrame.append was deprecated and removed in pandas 2.0,
    # and appending per set re-copies the accumulated frame each time.
    frames = [fetch_data.load_all_cards_text('%s/csv/%s.csv' % (Config.data_dir, set_name))
              for set_name in Config.all_set_list]
    card_pool = pd.concat(frames) if frames else pd.DataFrame()  # pd.concat([]) raises ValueError

    for _, card_info in card_pool.iterrows():
        img_dir = '%s/card_img/png/%s' % (Config.data_dir, card_info['set'])
        img_name = '%s/%s_%s.png' % (img_dir, card_info['collector_number'],
                                     fetch_data.get_valid_filename(card_info['name']))
        print(img_name)
        card_img = cv2.imread(img_name)  # cv2.imread returns None when the file is missing/unreadable
        if card_img is None:
            # Fetch into the same directory the image is read from, so the re-read can find it.
            # NOTE(review): assumes fetch_card_image saves '<collector_number>_<valid name>.png' under out_dir — confirm
            fetch_data.fetch_card_image(card_info, out_dir=img_dir)
            card_img = cv2.imread(img_name)
        if card_img is None:
            print('Warning: could not load or fetch %s, skipping.' % img_name)
            continue
        # display=True may open OpenCV preview windows and wait for key presses
        detected_object_list = apply_bounding_box(card_img, card_info, display=True)
        print(detected_object_list)