Edmond Yoo
2018-09-17 504ece5b00f192d5c1b343fd06ce1648f9139180
Code cleaning & training new YOLO model
8 files modified
1 files added
107535 ■■■■■ changed files
README.md 2 ●●● patch | view | raw | blame | history
cfg/tiny_yolo.cfg 8 ●●●● patch | view | raw | blame | history
cfg/tiny_yolo_old.cfg 134 ●●●●● patch | view | raw | blame | history
fetch_data.py 63 ●●●●● patch | view | raw | blame | history
opencv_dnn.py 222 ●●●● patch | view | raw | blame | history
setup_train.py 6 ●●●● patch | view | raw | blame | history
test.txt 10707 ●●●●● patch | view | raw | blame | history
train.txt 96372 ●●●●● patch | view | raw | blame | history
transform_data.py 21 ●●●● patch | view | raw | blame | history
README.md
@@ -97,7 +97,7 @@
I've made a quick OpenCV algorithm to extract cards from the image, and it works decently well:
<img src="https://github.com/hj3yoo/darknet/blob/master/figures/4_detection_result_5.png" width="360">
<img src="https://github.com/hj3yoo/darknet/blob/master/figures/4_detection_result_5.jpg" width="360">
At the moment, it's fairly limited: the entire card must be shown without obstruction or cropping; otherwise, it won't be detected at all.
cfg/tiny_yolo.cfg
@@ -12,7 +12,7 @@
hue=.1
learning_rate=0.001
max_batches = 30000
max_batches = 50000
policy=steps
steps=-1,100,80000,100000
scales=.1,10,.1,.1
@@ -111,15 +111,15 @@
size=1
stride=1
pad=1
filters=30
filters=54
activation=linear
[region]
anchors = 118.3429,137.0897, 95.8160,181.9724, 140.4955,166.7423, 112.7262,220.6808, 129.2741,198.9876
anchors = 118.5833,137.6176, 95.9314,181.2641, 140.8155,166.6721, 113.2913,220.6833, 128.8767,197.7804, 158.9573,196.7482, 138.7983,242.8474, 167.2494,227.1953, 165.0362,253.8635
bias_match=1
classes=1
coords=4
num=5
num=9
softmax=1
jitter=.2
rescore=1
cfg/tiny_yolo_old.cfg
New file
@@ -0,0 +1,134 @@
[net]
batch=64
subdivisions=4
width=416
height=416
channels=3
momentum=0.9
decay=0.0005
angle=0
saturation = 1.5
exposure = 1.5
hue=.1
learning_rate=0.001
max_batches = 50000
policy=steps
steps=-1,100,80000,100000
scales=.1,10,.1,.1
[convolutional]
batch_normalize=1
filters=16
size=3
stride=1
pad=1
activation=leaky
[maxpool]
size=2
stride=2
[convolutional]
batch_normalize=1
filters=32
size=3
stride=1
pad=1
activation=leaky
[maxpool]
size=2
stride=2
[convolutional]
batch_normalize=1
filters=64
size=3
stride=1
pad=1
activation=leaky
[maxpool]
size=2
stride=2
[convolutional]
batch_normalize=1
filters=128
size=3
stride=1
pad=1
activation=leaky
[maxpool]
size=2
stride=2
[convolutional]
batch_normalize=1
filters=256
size=3
stride=1
pad=1
activation=leaky
[maxpool]
size=2
stride=2
[convolutional]
batch_normalize=1
filters=512
size=3
stride=1
pad=1
activation=leaky
[maxpool]
size=2
stride=1
[convolutional]
batch_normalize=1
filters=1024
size=3
stride=1
pad=1
activation=leaky
###########
[convolutional]
batch_normalize=1
size=3
stride=1
pad=1
filters=1024
activation=leaky
[convolutional]
size=1
stride=1
pad=1
filters=30
activation=linear
[region]
anchors = 118.3429,137.0897, 95.8160,181.9724, 140.4955,166.7423, 112.7262,220.6808, 129.2741,198.9876
bias_match=1
classes=1
coords=4
num=5
softmax=1
jitter=.2
rescore=1
object_scale=5
noobject_scale=1
class_scale=1
coord_scale=1
absolute=1
thresh = .6
random=0
fetch_data.py
@@ -1,4 +1,4 @@
from urllib import request
from urllib import request, error
import ast
import json
import pandas as pd
@@ -7,10 +7,20 @@
import transform_data
import time
# Set codes whose card data/images are fetched and combined into the card pool.
# Each code must appear only once: callers iterate this list to fetch and concatenate per-set data,
# so a duplicated code fetches that set twice and duplicates its cards.
all_set_list = [
    # Core & expansion sets with 2003 frame
    'mrd', 'dst', '5dn', 'chk', 'bok', 'sok', 'rav', 'gpt', 'dis', 'csp', 'tsp', 'plc', 'fut', '10e', 'lrw',
    'mor', 'shm', 'eve', 'ala', 'con', 'arb', 'm10', 'zen', 'wwk', 'roe', 'm11', 'som', 'mbs', 'nph', 'm12',
    'isd', 'dka', 'avr', 'm13', 'rtr', 'gtc', 'dgm', 'm14', 'ths', 'bng', 'jou',
    # Core & expansion sets with 2015 frame
    'm15', 'ktk', 'frf', 'dtk', 'bfz', 'ogw', 'soi', 'emn', 'kld', 'aer', 'akh', 'hou', 'xln', 'rix', 'dom',
    # Box sets
    'evg', 'drb', 'dd2', 'ddc', 'td0', 'v09', 'ddd', 'h09', 'dde', 'dpa', 'v10', 'ddf', 'pd2', 'ddg',
    'cmd', 'v11', 'ddh', 'pd3', 'ddi', 'v12', 'ddj', 'cm1', 'td2', 'ddk', 'v13', 'ddl', 'c13', 'ddm', 'md1',
    'v14', 'ddn', 'c14', 'ddo', 'v15', 'ddp', 'c15', 'ddq', 'v16', 'ddr', 'c16', 'pca', 'dds', 'cma', 'c17',
    'ddt', 'v17', 'ddu', 'cm2', 'ss1', 'gs1', 'c18',
    # Supplemental sets
    'HOP', 'ARC', 'PC2', 'CNS', 'CN2', 'E01', 'E02', 'BBD',
]
def fetch_all_cards_text(url='https://api.scryfall.com/cards/search?q=layout:normal+format:modern+lang:en+frame:2003',
@@ -32,9 +42,9 @@
    df = pd.DataFrame.from_dict(cards)
    if csv_name != '':
        df = df[['artist', 'border_color', 'collector_number', 'color_identity', 'colors', 'flavor_text', 'image_uris',
                 'mana_cost', 'legalities', 'name', 'oracle_text', 'rarity', 'type_line', 'set', 'set_name', 'power',
                 'toughness']]
        #df = df[['artist', 'border_color', 'collector_number', 'color_identity', 'colors', 'flavor_text', 'image_uris',
        #         'mana_cost', 'legalities', 'name', 'oracle_text', 'rarity', 'type_line', 'set', 'set_name', 'power',
        #         'toughness']]
        #df.to_json(csv_name)
        df.to_csv(csv_name, sep=';')  # Comma doesn't work, since some columns are saved as a dict
@@ -72,18 +82,33 @@
def fetch_card_image(row, out_dir='', size='png'):
    """
    Download the image(s) of one card, skipping any image file that already exists.
    Cards whose layout is 'transform' or 'double_faced_token' get one image per face (URLs and names taken from
    'card_faces'); every other layout gets a single image taken from the top-level 'image_uris'.
    Each image is saved as '<out_dir>/<collector_number>_<valid filename of the face/card name>.png'.
    :param row: one card record; keys read are 'layout', 'set', 'name', 'collector_number', 'image_uris'
                and (for double-faced layouts) 'card_faces'
    :param out_dir: destination directory; when empty, '<transform_data.data_dir>/card_img/<size>/<set>' is used
    :param size: key into the image URI dict selecting which rendition to download (default 'png')
    :return: None
    """
    def _as_obj(value):
        # Dict/list columns come back as their string repr after a CSV round trip, so parse them back.
        # ast.literal_eval only evaluates Python literals, never arbitrary code.
        return ast.literal_eval(value) if isinstance(value, str) else value

    if out_dir == '':
        out_dir = '%s/card_img/%s/%s' % (transform_data.data_dir, size, row['set'])
    os.makedirs(out_dir, exist_ok=True)

    # Build (url, file-name stem) pairs. The top-level 'image_uris' is only consulted for single-faced layouts;
    # double-faced layouts carry their images per face instead.
    if row['layout'] in ('transform', 'double_faced_token'):
        card_faces = _as_obj(row['card_faces'])
        downloads = [(face['image_uris'][size], get_valid_filename(face['name'])) for face in card_faces]
    else:
        downloads = [(_as_obj(row['image_uris'])[size], get_valid_filename(row['name']))]

    for png_url, card_name in downloads:
        img_name = '%s/%s_%s.png' % (out_dir, row['collector_number'], card_name)
        if not os.path.isfile(img_name):
            request.urlretrieve(png_url, filename=img_name)
            print(img_name)
def main():
@@ -95,8 +120,8 @@
                                          % set_name, csv_name=csv_name)
        else:
            df = load_all_cards_text(csv_name)
        time.sleep(1)
        #fetch_all_cards_image(df, out_dir='../usb/data/png/%s' % set_name)
        df.sort_values('collector_number')
        fetch_all_cards_image(df, out_dir='%s/card_img/png/%s' % (transform_data.data_dir, set_name))
    #df = fetch_all_cards_text(url='https://api.scryfall.com/cards/search?q=layout:normal+lang:en+frame:2003',
    #                          csv_name='data/csv/all.csv')
    pass
opencv_dnn.py
@@ -6,6 +6,7 @@
import sys
import math
import random
import time
from PIL import Image
import fetch_data
import transform_data
@@ -29,7 +30,7 @@
            card_img = cv2.imread(img_name)
        if card_img is None:
            print('WARNING: card %s is not found!' % img_name)
        img_art = Image.fromarray(card_img[121:580, 63:685])
        img_art = Image.fromarray(card_img[121:580, 63:685])  # For 745*1040 size card image
        art_hash = ih.phash(img_art, hash_size=32, highfreq_factor=4)
        card_pool.at[ind, 'art_hash'] = art_hash
        img_card = Image.fromarray(card_img)
@@ -42,27 +43,6 @@
        card_pool.to_pickle(save_to)
    return card_pool
# How card_pool.pck is built (kept for reference only, not executed here):
#   df_list = []
#   for set_name in fetch_data.all_set_list:
#       csv_name = '%s/csv/%s.csv' % (transform_data.data_dir, set_name)
#       df_list.append(fetch_data.load_all_cards_text(csv_name))
#   card_pool = pd.concat(df_list)
#   card_pool.reset_index(drop=True, inplace=True)
#   card_pool.drop('Unnamed: 0', axis=1, inplace=True, errors='ignore')
#   card_pool = calc_image_hashes(card_pool, save_to='card_pool.pck')
# Single-set variant:
#   csv_name = '%s/csv/%s.csv' % (transform_data.data_dir, 'rtr')
#   card_pool = calc_image_hashes(fetch_data.load_all_cards_text(csv_name))
#
# Load the precomputed pool into a module-level name; the detection code reads card_pool as a global.
# (Unpickling runs arbitrary code, so only load a card_pool.pck you generated yourself.)
card_pool = pd.read_pickle('card_pool.pck')
# Disclaimer: majority of the basic framework in this file is modified from the following tutorial:
# https://www.learnopencv.com/deep-learning-based-object-detection-using-yolov3-with-opencv-python-c/
# www.pyimagesearch.com/2014/08/25/4-point-opencv-getperspective-transform-example/
def order_points(pts):
@@ -89,6 +69,7 @@
    return rect
# www.pyimagesearch.com/2014/08/25/4-point-opencv-getperspective-transform-example/
def four_point_transform(image, pts):
    # obtain a consistent order of the points and unpack them
    # individually
@@ -121,14 +102,14 @@
        [0, maxHeight - 1]], dtype="float32")
    # compute the perspective transform matrix and then apply it
    M = cv2.getPerspectiveTransform(rect, dst)
    warped = cv2.warpPerspective(image, M, (maxWidth, maxHeight))
    mat = cv2.getPerspectiveTransform(rect, dst)
    warped = cv2.warpPerspective(image, mat, (maxWidth, maxHeight))
    # If the image is horizontally long, rotate it by 90
    if maxWidth > maxHeight:
        center = (maxHeight / 2, maxHeight / 2)
        M_rot = cv2.getRotationMatrix2D(center, 270, 1.0)
        warped = cv2.warpAffine(warped, M_rot, (maxHeight, maxWidth))
        mat_rot = cv2.getRotationMatrix2D(center, 270, 1.0)
        warped = cv2.warpAffine(warped, mat_rot, (maxHeight, maxWidth))
    # return the warped image
    return warped
@@ -143,11 +124,11 @@
# Remove the bounding boxes with low confidence using non-maxima suppression
# https://www.learnopencv.com/deep-learning-based-object-detection-using-yolov3-with-opencv-python-c/
def post_process(frame, outs, thresh_conf, thresh_nms):
    frame_height = frame.shape[0]
    frame_width = frame.shape[1]
    # Scan through all the bounding boxes output from the network and keep only the
    # ones with high confidence scores. Assign the box's class label as the class with the highest score.
    class_ids = []
@@ -159,6 +140,7 @@
            class_id = np.argmax(scores)
            confidence = scores[class_id]
            if confidence > thresh_conf:
                #print(detection[0:3])
                center_x = int(detection[0] * frame_width)
                center_y = int(detection[1] * frame_height)
                width = int(detection[2] * frame_width)
@@ -221,7 +203,7 @@
    return corrected
def find_card(img, thresh_c=5, kernel_size=(3, 3), size_ratio=0.3):
def find_card(img, thresh_c=5, kernel_size=(3, 3), size_ratio=0.2):
    # Typical pre-processing - grayscale, blurring, thresholding
    img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    img_blur = cv2.medianBlur(img_gray, 5)
@@ -255,7 +237,8 @@
    return cnts_rect
def detect_frame(net, classes, img, thresh_conf=0.5, thresh_nms=0.4, in_dim=(416, 416), display=True, out_path=None):
def detect_frame(net, classes, img, thresh_conf=0.1, thresh_nms=0.4, in_dim=(416, 416), out_path=None, display=True,
                 debug=False):
    img_copy = img.copy()
    # Create a 4D blob from a frame.
    blob = cv2.dnn.blobFromImage(img, 1 / 255, in_dim, [0, 0, 0], 1, crop=False)
@@ -266,125 +249,107 @@
    # Runs the forward pass to get output of the output layers
    outs = net.forward(get_outputs_names(net))
    img_result = img.copy()
    # Remove the bounding boxes with low confidence
    obj_list = post_process(img, outs, thresh_conf, thresh_nms)
    for obj in obj_list:
        class_id, confidence, box = obj
        left, top, width, height = box
        draw_pred(img, class_id, classes, confidence, left, top, left + width, top + height)
        draw_pred(img_result, class_id, classes, confidence, left, top, left + width, top + height)
    # Put efficiency information. The function getPerfProfile returns the
    # overall time for inference(t) and the timings for each of the layers(in layersTimes)
    t, _ = net.getPerfProfile()
    label = 'Inference time: %.2f ms' % (t * 1000.0 / cv2.getTickFrequency())
    cv2.putText(img, label, (0, 15), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255))
    #if display:
    #    t, _ = net.getPerfProfile()
    #    label = 'Inference time: %.2f ms' % (t * 1000.0 / cv2.getTickFrequency())
    #    cv2.putText(img_result, label, (0, 15), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255))
    '''
    Assuming that the model has properly identified all cards, there should be 1 card that can be classified per
    bounding box. Find the largest rectangular contour from the region of interest, and identify the card by
    comparing the perceptual hashing of the image with the other cards' image from the database.
    '''
    card_name_list = []
    for i in range(len(obj_list)):
        _, _, box = obj_list[i]
        left, top, width, height = box
        # Just in case the bounding box trimmed the edge of the cards, give it a bit of offset around the edge
        offset_ratio = 0.1
        x1 = max(0, int(left - offset_ratio * width))
        x2 = min(img.shape[1], int(left + (1 + offset_ratio) * width))
        y1 = max(0, int(top - offset_ratio * height))
        y2 = min(img.shape[0], int(top + (1 + offset_ratio) * height))
        img_snip = img[y1:y2, x1:x2]
        cnts = find_card(img_snip)
        if len(cnts) > 0:
            cnt = cnts[0]  # The largest (rectangular) contour
            pts = np.float32([p[0] for p in cnt])
            img_warp = four_point_transform(img_snip, pts)
            img_warp = cv2.resize(img_warp, (card_width, card_height))
            '''
            img_art = img_warp[47:249, 22:294]
            img_art = Image.fromarray(img_art.astype('uint8'), 'RGB')
            art_hash = ih.phash(img_art, hash_size=32, highfreq_factor=4)
            card_pool['hash_diff'] = card_pool['art_hash'] - art_hash
            min_cards = card_pool[card_pool['hash_diff'] == min(card_pool['hash_diff'])]
            card_name = min_cards.iloc[0]['name']
            '''
            img_card = Image.fromarray(img_warp.astype('uint8'), 'RGB')
            card_hash = ih.phash(img_card, hash_size=32, highfreq_factor=4)
            card_pool['hash_diff'] = card_pool['card_hash'] - card_hash
            min_cards = card_pool[card_pool['hash_diff'] == min(card_pool['hash_diff'])]
            card_name = min_cards.iloc[0]['name']
            card_name_list.append(card_name)
            hash_diff = min_cards.iloc[0]['hash_diff']
            # Display the result
            if debug:
                # cv2.rectangle(img_warp, (22, 47), (294, 249), (0, 255, 0), 2)
                cv2.putText(img_warp, card_name + ', ' + str(hash_diff), (0, 50),
                            cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
            cv2.putText(img_result, card_name , (x1, y1), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)
            if debug:
                cv2.imshow('card#%d' % i, img_warp)
        elif debug:
            cv2.imshow('card#%d' % i, np.zeros((1, 1), dtype=np.uint8))
    if out_path is not None:
        cv2.imwrite(out_path, img.astype(np.uint8))
    if display:
        #no_glare = remove_glare(img_copy)
        #img_concat = np.concatenate((img, no_glare), axis=1)
        cv2.imshow('result', img)
        '''
        for i in range(len(obj_list)):
            class_id, confidence, box = obj_list[i]
            left, top, width, height = box
            img_snip = img_copy[max(0, top):min(img.shape[0], top + height),
                                max(0, left):min(img.shape[1], left + width)]
            img_thresh, img_dilate, img_canny, img_hough = find_card(img_snip)
            img_concat = np.concatenate((img_snip, img_thresh, img_dilate, img_canny, img_hough), axis=1)
            cv2.imshow('feature#%d' % i, img_concat)
        '''
        cv2.waitKey(0)
        cv2.destroyAllWindows()
        cv2.imwrite(out_path, img_result.astype(np.uint8))
    return obj_list
    return obj_list, card_name_list, img_result
def detect_video(net, classes, capture, thresh_conf=0.5, thresh_nms=0.4, in_dim=(416, 416), display=True, out_path=None):
def detect_video(net, classes, capture, thresh_conf=0.5, thresh_nms=0.4, in_dim=(416, 416), out_path=None, display=True,
                 debug=False):
    if out_path is not None:
        vid_writer = cv2.VideoWriter(out_path, cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'), 30,
                                     (round(capture.get(cv2.CAP_PROP_FRAME_WIDTH)),
                                      round(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))))
    max_num_obj = 0
    while True:
        start_time = time.time()
        ret, frame = capture.read()
        if not ret:
            # End of video
            print("End of video. Press any key to exit")
            cv2.waitKey(0)
            break
        img = frame.copy()
        obj_list = detect_frame(net, classes, frame, thresh_conf=thresh_conf, thresh_nms=thresh_nms, in_dim=in_dim,
                                display=False, out_path=None)
        #cnts_rect = find_card(img)
        max_num_obj = max(max_num_obj, len(obj_list))
        if display:
            img_result = frame.copy()
            #img_result = cv2.drawContours(img_result, cnts_rect, -1, (0, 255, 0), 2)
            #for i in range(len(cnts_rect)):
            #    pts = np.float32([p[0] for p in cnts_rect[i]])
            #    img_warp = four_point_transform(img, pts)
            #    cv2.imshow('card#%d' % i, img_warp)
            #for i in range(len(cnts_rect), max_num_obj):
            #    cv2.imshow('card#%d' % i, np.zeros((1, 1), dtype=np.uint8))
            #no_glare = remove_glare(img)
            #img_thresh, img_erode, img_contour = find_card(no_glare)
            #img_concat = np.concatenate((no_glare, img_contour), axis=1)
            for i in range(len(obj_list)):
                class_id, confidence, box = obj_list[i]
                left, top, width, height = box
                offset_ratio = 0.1
                x1 = max(0, int(left - offset_ratio * width))
                x2 = min(img.shape[1], int(left + (1 + offset_ratio) * width))
                y1 = max(0, int(top - offset_ratio * height))
                y2 = min(img.shape[0], int(top + (1 + offset_ratio) * height))
                img_snip = img[y1:y2, x1:x2]
                cnts = find_card(img_snip)
                if len(cnts) > 0:
                    cnt = cnts[-1]
                    pts = np.float32([p[0] for p in cnt])
                    img_warp = four_point_transform(img_snip, pts)
                    img_warp = cv2.resize(img_warp, (card_width, card_height))
                    '''
                    img_art = img_warp[47:249, 22:294]
                    img_art = Image.fromarray(img_art.astype('uint8'), 'RGB')
                    art_hash = ih.phash(img_art, hash_size=32, highfreq_factor=4)
                    card_pool['hash_diff'] = card_pool['art_hash'] - art_hash
                    min_cards = card_pool[card_pool['hash_diff'] == min(card_pool['hash_diff'])]
                    guttersnipe = card_pool[card_pool['name'] == 'Cyclonic Rift']
                    diff = guttersnipe['art_hash'] - art_hash
                    print(diff)
                    card_name = min_cards.iloc[0]['name']
                    #print(min_cards.iloc[0]['name'], min_cards.iloc[0]['hash_diff'])
                    '''
                    img_card = Image.fromarray(img_warp.astype('uint8'), 'RGB')
                    card_hash = ih.phash(img_card, hash_size=32, highfreq_factor=4)
                    card_pool['hash_diff'] = card_pool['card_hash'] - card_hash
                    min_cards = card_pool[card_pool['hash_diff'] == min(card_pool['hash_diff'])]
                    card_name = min_cards.iloc[0]['name']
                    hash_diff = min_cards.iloc[0]['hash_diff']
                    #guttersnipe = card_pool[card_pool['name'] == 'Cyclonic Rift']
                    #diff = guttersnipe['card_hash'] - card_hash
                    #print(diff)
                    #img_thresh, img_dilate, img_contour = find_card(img_snip)
                    #img_concat = np.concatenate((img_snip, img_contour), axis=1)
                    cv2.rectangle(img_warp, (22, 47), (294, 249), (0, 255, 0), 2)
                    cv2.putText(img_warp, card_name + ', ' + str(hash_diff), (0, 50),
                                cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
                    cv2.imshow('card#%d' % i, img_warp)
                else:
                    cv2.imshow('card#%d' % i, np.zeros((1, 1), dtype=np.uint8))
        # Use the YOLO model to identify each cards annonymously
        obj_list, card_name_list, img_result = detect_frame(net, classes, frame, thresh_conf=thresh_conf,
                                                            thresh_nms=thresh_nms, in_dim=in_dim, out_path=None,
                                                            display=display, debug=debug)
        if debug:
            max_num_obj = max(max_num_obj, len(obj_list))
            for i in range(len(obj_list), max_num_obj):
                cv2.imshow('card#%d' % i, np.zeros((1, 1), dtype=np.uint8))
        if display:
            cv2.imshow('result', img_result)
            #if len(obj_list) > 0:
            #    cv2.waitKey(0)
        elapsed_ms = (time.time() - start_time) * 1000
        print('Elapsed time: %.2f ms' % elapsed_ms)
        if out_path is not None:
            vid_writer.write(frame.astype(np.uint8))
            vid_writer.write(img_result.astype(np.uint8))
        cv2.waitKey(1)
    if out_path is not None:
@@ -399,7 +364,7 @@
    #cfg_path = 'cfg/tiny_yolo_10.cfg'
    #class_path = "data/obj_10.names"
    weight_path = 'weights/second_general/tiny_yolo_final.weights'
    cfg_path = 'cfg/tiny_yolo.cfg'
    cfg_path = 'cfg/tiny_yolo_old.cfg'
    class_path = 'data/obj.names'
    out_dir = 'out'
    if not os.path.isfile(test_path):
@@ -440,10 +405,27 @@
        detect_frame(net, classes, img, out_path=out_path, thresh_conf=thresh_conf, thresh_nms=thresh_nms)
    else:
        capture = cv2.VideoCapture(0)
        detect_video(net, classes, capture, out_path=out_path, thresh_conf=thresh_conf, thresh_nms=thresh_nms)
        detect_video(net, classes, capture, out_path=out_path, thresh_conf=thresh_conf, thresh_nms=thresh_nms,
                     display=False, debug=False)
        capture.release()
    pass
if __name__ == '__main__':
    # Regenerating card_pool.pck (reference only, not executed):
    #   frames = [fetch_data.load_all_cards_text('%s/csv/%s.csv' % (transform_data.data_dir, s))
    #             for s in fetch_data.all_set_list]
    #   card_pool = pd.concat(frames)
    #   card_pool.reset_index(drop=True, inplace=True)
    #   card_pool.drop('Unnamed: 0', axis=1, inplace=True, errors='ignore')
    #   card_pool = calc_image_hashes(card_pool, save_to='card_pool.pck')
    # or, for one set only:
    #   card_pool = calc_image_hashes(fetch_data.load_all_cards_text(
    #       '%s/csv/%s.csv' % (transform_data.data_dir, 'rtr')))
    #
    # Bound at module scope on purpose: detect_frame/detect_video look card_pool up as a global.
    card_pool = pd.read_pickle('card_pool.pck')
    main()
setup_train.py
@@ -7,7 +7,7 @@
def main():
    random.seed()
    data_list = []
    for subdir in glob('%s/train/*_10' % transform_data.data_dir):
    for subdir in glob('%s/train/*_update' % transform_data.data_dir):
        for data in glob(subdir + "/*.jpg"):
            data_list.append(os.path.abspath(data))
    random.shuffle(data_list)
@@ -15,10 +15,10 @@
    test_ratio = 0.1
    test_list = data_list[:int(test_ratio * len(data_list))]
    train_list = data_list[int(test_ratio * len(data_list)):]
    with open('%s/train_10.txt' % transform_data.darknet_dir, 'w') as train_txt:
    with open('%s/train.txt' % transform_data.darknet_dir, 'w') as train_txt:
        for data in train_list:
            train_txt.write(data + '\n')
    with open('%s/test_10.txt' % transform_data.darknet_dir, 'w') as test_txt:
    with open('%s/test.txt' % transform_data.darknet_dir, 'w') as test_txt:
        for data in test_list:
            test_txt.write(data + '\n')
    return
test.txt
Diff too large
train.txt
Diff too large
transform_data.py
@@ -15,7 +15,7 @@
card_mask = cv2.imread('data/mask.png')
data_dir = os.path.abspath('/media/win10/data')
darknet_dir = os.path.abspath('darknet')
darknet_dir = os.path.abspath('.')
def key_pts_to_yolo(key_pts, w_img, h_img):
@@ -347,7 +347,8 @@
                coords_in_gen = [card.coordinate_in_generator(key_pt[0], key_pt[1]) for key_pt in ext_obj.key_pts]
                obj_yolo_info = key_pts_to_yolo(coords_in_gen, self.width, self.height)
                if ext_obj.label == 'card':
                    class_id = self.class_ids[card.info['name']]
                    #class_id = self.class_ids[card.info['name']]
                    class_id = 0
                    out_txt.write(str(class_id) + ' %.6f %.6f %.6f %.6f\n' % obj_yolo_info)
                    pass
                elif ext_obj.label[:ext_obj.label.find[':']] == 'mana_symbol':
@@ -499,17 +500,15 @@
    #bg_images = [cv2.imread('data/frilly_0007.jpg')]
    background = generate_data.Backgrounds(images=bg_images)
    #card_pool = pd.DataFrame()
    #for set_name in fetch_data.all_set_list:
    #    df = fetch_data.load_all_cards_text('%s/csv/%s.csv' % (data_dir, set_name))
    #    card_pool = card_pool.append(df)
    card_pool = fetch_data.load_all_cards_text('%s/csv/custom.csv' % data_dir)
    card_pool = pd.DataFrame()
    for set_name in fetch_data.all_set_list:
        df = fetch_data.load_all_cards_text('%s/csv/%s.csv' % (data_dir, set_name))
        card_pool = card_pool.append(df)
    class_ids = {}
    with open('%s/obj.names' % data_dir) as names_file:
        class_name_list = names_file.read().splitlines()
        for i in range(len(class_name_list)):
            class_ids[class_name_list[i]] = i
    print(class_ids)
    num_gen = 60000
    num_iter = 1
@@ -545,15 +544,15 @@
            if i % 3 == 0:
                generator.generate_non_obstructive()
                generator.export_training_data(visibility=0.0, out_name='%s/train/non_obstructive_10/%s%d'
                generator.export_training_data(visibility=0.0, out_name='%s/train/non_obstructive_update/%s%d'
                                                                        % (data_dir, out_name, j), aug=seq)
            elif i % 3 == 1:
                generator.generate_horizontal_span(theta=random.uniform(-math.pi, math.pi))
                generator.export_training_data(visibility=0.0, out_name='%s/train/horizontal_span_10/%s%d'
                generator.export_training_data(visibility=0.0, out_name='%s/train/horizontal_span_update/%s%d'
                                                                        % (data_dir, out_name, j), aug=seq)
            else:
                generator.generate_vertical_span(theta=random.uniform(-math.pi, math.pi))
                generator.export_training_data(visibility=0.0, out_name='%s/train/vertical_span_10/%s%d'
                generator.export_training_data(visibility=0.0, out_name='%s/train/vertical_span_update/%s%d'
                                                                        % (data_dir, out_name, j), aug=seq)
            #generator.generate_horizontal_span(theta=random.uniform(-math.pi, math.pi))