Constantin Wenger
2019-06-13 32dd89caead0dff1c8f23c3535cd357f814bb9a9
opencv_dnn.py
@@ -1,3 +1,4 @@
import argparse
import ast
import collections
import cv2
@@ -8,7 +9,7 @@
import pandas as pd
from PIL import Image
import time
from multiprocessing import Pool
from config import Config
import fetch_data
@@ -21,20 +22,13 @@
https://github.com/hj3yoo/mtg_card_detector/tree/dea64611730c84a59c711c61f7f80948f82bcd31 
"""
def calc_image_hashes(card_pool, save_to=None, hash_size=32, highfreq_factor=4):
    """
    Calculate perceptual hash (pHash) value for each cards in the database, then store them if needed
    :param card_pool: pandas dataframe containing all card information
    :param save_to: path for the pickle file to be saved
    :param hash_size: param for pHash algorithm
    :param highfreq_factor: param for pHash algorithm
    :return: pandas dataframe
    """
    # Since some double-faced cards may result in two different cards, create a new dataframe to store the result
def do_calc(args):
    card_pool = args[0]
    hash_size = args[1]
    new_pool = pd.DataFrame(columns=list(card_pool.columns.values))
    new_pool['card_hash'] = np.NaN
    #new_pool['art_hash'] = np.NaN
    for hs in hash_size:
        new_pool['card_hash_%d' % hs] = np.NaN
        #new_pool['art_hash_%d' % hs] = np.NaN
    for ind, card_info in card_pool.iterrows():
        if ind % 100 == 0:
            print('Calculating hashes: %dth card' % ind)
@@ -68,20 +62,38 @@
                print('WARNING: card %s is not found!' % img_name)
            # Compute value of the card's perceptual hash, then store it to the database
            '''
            img_art = Image.fromarray(card_img[121:580, 63:685])  # For 745*1040 size card image
            art_hash = ih.phash(img_art, hash_size=hash_size, highfreq_factor=highfreq_factor)
            card_info['art_hash'] = art_hash
            '''
            #img_art = Image.fromarray(card_img[121:580, 63:685])  # For 745*1040 size card image
            img_card = Image.fromarray(card_img)
            card_hash = ih.phash(img_card, hash_size=hash_size, highfreq_factor=highfreq_factor)
            card_info['card_hash'] = card_hash
            for hs in hash_size:
                card_hash = ih.phash(img_card, hash_size=hs)
                card_info['card_hash_%d' % hs] = card_hash
                #art_hash = ih.phash(img_art, hash_size=hs)
                #card_info['art_hash_%d' % hs] = art_hash
            new_pool.loc[0 if new_pool.empty else new_pool.index.max() + 1] = card_info
    return new_pool
    # Remove uselesss fields, then pickle it if needed
    new_pool = new_pool[['artist', 'border_color', 'collector_number', 'color_identity', 'colors', 'flavor_text',
                         'image_uris', 'mana_cost', 'legalities', 'name', 'oracle_text', 'rarity', 'type_line',
                         'set', 'set_name', 'power', 'toughness', 'art_hash', 'card_hash']]
def calc_image_hashes(card_pool, save_to=None, hash_size=None):
    """
    Calculate perceptual hash (pHash) value for each cards in the database, then store them if needed
    :param card_pool: pandas dataframe containing all card information
    :param save_to: path for the pickle file to be saved
    :param hash_size: param for pHash algorithm
    :return: pandas dataframe
    """
    if hash_size is None:
        hash_size = [16, 32]
    elif isinstance(hash_size, int):
        hash_size = [hash_size]
    num_cores = 15
    num_partitions = 60
    pool = Pool(num_cores)
    df_split = np.array_split(card_pool, num_partitions)
    new_pool = pd.concat(pool.map(do_calc, [(split, hash_size) for split in df_split]))
    pool.close()
    pool.join()
    # Since some double-faced cards may result in two different cards, create a new dataframe to store the result
    if save_to is not None:
        new_pool.to_pickle(save_to)
    return new_pool
@@ -166,72 +178,6 @@
    return warped
'''
# The following functions are only used in conjunction with YOLO, and is deprecated:
# - get_outputs_names()
# - post_process()
# - draw_pred()
# Get the names of the output layers
def get_outputs_names(net):
    # Get the names of all the layers in the network
    layers_names = net.getLayerNames()
    # Get the names of the output layers, i.e. the layers with unconnected outputs
    return [layers_names[i[0] - 1] for i in net.getUnconnectedOutLayers()]
# Remove the bounding boxes with low confidence using non-maxima suppression
# https://www.learnopencv.com/deep-learning-based-object-detection-using-yolov3-with-opencv-python-c/
def post_process(frame, outs, thresh_conf, thresh_nms):
    frame_height = frame.shape[0]
    frame_width = frame.shape[1]
    # Scan through all the bounding boxes output from the network and keep only the
    # ones with high confidence scores. Assign the box's class label as the class with the highest score.
    class_ids = []
    confidences = []
    boxes = []
    for out in outs:
        for detection in out:
            scores = detection[5:]
            class_id = np.argmax(scores)
            confidence = scores[class_id]
            if confidence > thresh_conf:
                center_x = int(detection[0] * frame_width)
                center_y = int(detection[1] * frame_height)
                width = int(detection[2] * frame_width)
                height = int(detection[3] * frame_height)
                left = int(center_x - width / 2)
                top = int(center_y - height / 2)
                class_ids.append(class_id)
                confidences.append(float(confidence))
                boxes.append([left, top, width, height])
    # Perform non maximum suppression to eliminate redundant overlapping boxes with lower confidences.
    indices = [ind[0] for ind in cv2.dnn.NMSBoxes(boxes, confidences, thresh_conf, thresh_nms)]
    ret = [[class_ids[i], confidences[i], boxes[i]] for i in indices]
    return ret
# Draw the predicted bounding box
def draw_pred(frame, class_id, classes, conf, left, top, right, bottom):
    # Draw a bounding box.
    cv2.rectangle(frame, (left, top), (right, bottom), (0, 0, 255))
    label = '%.2f' % conf
    # Get the label for the class name and its confidence
    if classes:
        assert (class_id < len(classes))
        label = '%s:%s' % (classes[class_id], label)
    # Display the label at the top of the bounding box
    label_size, base_line = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)
    top = max(top, label_size[1])
    cv2.putText(frame, label, (left, top), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255))
'''
def remove_glare(img):
    """
    Reduce the effect of glaring in the image
@@ -282,7 +228,7 @@
    img_erode = cv2.erode(img_dilate, kernel, iterations=1)
    # Find the contour
    _, cnts, hier = cv2.findContours(img_erode, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
    cnts, hier = cv2.findContours(img_erode, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
    if len(cnts) == 0:
        #print('no contours')
        return []
@@ -350,7 +296,9 @@
        if os.path.exists(img_name):
            card_img = cv2.imread(img_name)
        else:
            card_img = np.ones((h_card, w_card))
            card_img = np.ones((h_card, w_card, 3)) * 255
            cv2.putText(card_img, 'X', ((w_card - int(txt_scale * 25)) // 2, (h_card + int(txt_scale * 25)) // 2),
                        cv2.FONT_HERSHEY_SIMPLEX, txt_scale, (0, 0, 0), 2)
        # Insert the card image, card name, and confidence bar to the graph
        img_graph[y_anchor:y_anchor + h_card, x_anchor:x_anchor + w_card] = card_img
@@ -369,14 +317,13 @@
    return img_graph
def detect_frame(img, card_pool, hash_size=32, highfreq_factor=4, size_thresh=10000,
def detect_frame(img, card_pool, hash_size=32, size_thresh=10000,
                 out_path=None, display=True, debug=False):
    """
    Identify all cards in the input frame, display or save the frame if needed
    :param img: input frame
    :param card_pool: pandas dataframe of all card's information
    :param hash_size: param for pHash algorithm
    :param highfreq_factor: param for pHash algorithm
    :param size_thresh: threshold for size (in pixel) of the contour to be a candidate
    :param out_path: path to save the result
    :param display: flag for displaying the result
@@ -402,13 +349,14 @@
        '''
        img_art = img_warp[47:249, 22:294]
        img_art = Image.fromarray(img_art.astype('uint8'), 'RGB')
        art_hash = ih.phash(img_art, hash_size=hash_size, highfreq_factor=highfreq_factor).hash.flatten()
        art_hash = ih.phash(img_art, hash_size=hash_size).hash.flatten()
        card_pool['hash_diff'] = card_pool['art_hash'].apply(lambda x: np.count_nonzero(x != art_hash))
        '''
        img_card = Image.fromarray(img_warp.astype('uint8'), 'RGB')
        # the stored values of hashes in the dataframe is pre-emptively flattened already to minimize computation time
        card_hash = ih.phash(img_card, hash_size=hash_size, highfreq_factor=highfreq_factor).hash.flatten()
        card_pool['hash_diff'] = card_pool['card_hash'].apply(lambda x: np.count_nonzero(x != card_hash))
        card_hash = ih.phash(img_card, hash_size=hash_size).hash.flatten()
        card_pool['hash_diff'] = card_pool['card_hash_%d' % hash_size]
        card_pool['hash_diff'] = card_pool['hash_diff'].apply(lambda x: np.count_nonzero(x != card_hash))
        min_card = card_pool[card_pool['hash_diff'] == min(card_pool['hash_diff'])].iloc[0]
        card_name = min_card['name']
        card_set = min_card['set']
@@ -417,11 +365,12 @@
        # Render the result, and display them if needed
        cv2.drawContours(img_result, [cnt], -1, (0, 255, 0), 2)
        cv2.putText(img_result, card_name, (pts[0][0], pts[0][1]), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)
        cv2.putText(img_result, card_name, (min(pts[0][0], pts[1][0]), min(pts[0][1], pts[1][1])),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2)
        if debug:
            # cv2.rectangle(img_warp, (22, 47), (294, 249), (0, 255, 0), 2)
            cv2.putText(img_warp, card_name + ', ' + str(hash_diff), (0, 50),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
            cv2.putText(img_warp, card_name + ':' + card_set + ', ' + str(hash_diff), (0, 20),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 255), 1)
            cv2.imshow('card#%d' % i, img_warp)
    if display:
        cv2.imshow('Result', img_result)
@@ -432,14 +381,13 @@
    return det_cards, img_result
def detect_video(capture, card_pool, hash_size=32, highfreq_factor=4, size_thresh=10000,
def detect_video(capture, card_pool, hash_size=32, size_thresh=10000,
                 out_path=None, display=True, show_graph=True, debug=False):
    """
    Identify all cards in the continuous video stream, display or save the result if needed
    :param capture: input video stream
    :param card_pool: pandas dataframe of all card's information
    :param hash_size: param for pHash algorithm
    :param highfreq_factor: param for pHash algorithm
    :param size_thresh: threshold for size (in pixel) of the contour to be a candidate
    :param out_path: path to save the result
    :param display: flag for displaying the result
@@ -471,8 +419,8 @@
                cv2.waitKey(0)
                break
            # Detect all cards from the current frame
            det_cards, img_result = detect_frame(frame, card_pool, hash_size=hash_size, highfreq_factor=highfreq_factor,
                                                 size_thresh=size_thresh, out_path=None, display=False, debug=debug)
            det_cards, img_result = detect_frame(frame, card_pool, hash_size=hash_size, size_thresh=size_thresh,
                                                 out_path=None, display=False, debug=debug)
            if show_graph:
                # If the card was already detected in the previous frame, append 1 to the list
                # If the card previously detected was not found in this trame, append 0 to the list
@@ -528,18 +476,15 @@
        cv2.destroyAllWindows()
def main():
def main(args):
    # Specify paths for all necessary files
    #test_path = os.path.abspath('test_file/test4.mp4')
    test_path = None
    out_dir = 'out'
    hash_size = 32
    highfreq_factor = 4
    pck_path = os.path.abspath('card_pool_%d_%d.pck' % (hash_size, highfreq_factor))
    hash_sizes = {16, 32}
    hash_sizes.add(args.hash_size)
    pck_path = os.path.abspath('card_pool.pck')
    if os.path.isfile(pck_path):
        card_pool = pd.read_pickle(pck_path)
    else:
        print('Warning: pickle for card database %s is not found!' % pck_path)
        # Merge database for all cards, then calculate pHash values of each, store them
        df_list = []
        for set_name in Config.all_set_list:
@@ -549,44 +494,72 @@
        card_pool = pd.concat(df_list, sort=True)
        card_pool.reset_index(drop=True, inplace=True)
        card_pool.drop('Unnamed: 0', axis=1, inplace=True, errors='ignore')
        card_pool = calc_image_hashes(card_pool, save_to=pck_path, hash_size=hash_sizes)
    ch_key = 'card_hash_%d' % args.hash_size
    if ch_key not in card_pool.columns:
        # we did not generate this hash_size yet
        print('We need to add hash_size=%d' % (args.hash_size,))
        card_pool = calc_image_hashes(card_pool, save_to=pck_path, hash_size=[args.hash_size])
        card_pool = calc_image_hashes(card_pool, save_to=pck_path, hash_size=hash_size, highfreq_factor=highfreq_factor)
    card_pool = card_pool[['name', 'set', 'collector_number', 'card_hash']]
    card_pool = card_pool[['name', 'set', 'collector_number', ch_key]]
    # Processing time is almost linear to the size of the database
    # Program can be much faster if the search scope for the card can be reduced
    card_pool = card_pool[card_pool['set'].isin(Config.set_2003_list)]
    # ImageHash is basically just one numpy.ndarray with (hash_size)^2 number of bits. pre-emptively flattening it
    # significantly increases speed for subtracting hashes in the future.
    card_pool['card_hash'] = card_pool['card_hash'].apply(lambda x: x.hash.flatten())
    card_pool[ch_key] = card_pool[ch_key].apply(lambda x: x.hash.flatten())
    # If the test file isn't given, use webcam to capture video
    if test_path is None:
    if args.in_path is None:
        capture = cv2.VideoCapture(0)
        detect_video(capture, card_pool, out_path='%s/result.avi' % out_dir, display=True, show_graph=True, debug=False)
        detect_video(capture, card_pool, hash_size=args.hash_size, out_path='%s/result.avi' % args.out_path,
                     display=args.display, show_graph=args.show_graph, debug=args.debug)
        capture.release()
    else:
        # Save the detection result if out_dir is provided
        if out_dir is None or out_dir == '':
        # Save the detection result if args.out_path is provided
        if args.out_path is None:
            out_path = None
        else:
            f_name = os.path.split(test_path)[1]
            out_path = '%s/%s.avi' % (out_dir, f_name[:f_name.find('.')])
            f_name = os.path.split(args.in_path)[1]
            out_path = '%s/%s.avi' % (args.out_path, f_name[:f_name.find('.')])
        if not os.path.isfile(test_path):
            print('The test file %s doesn\'t exist!' % os.path.abspath(test_path))
        if not os.path.isfile(args.in_path):
            print('The test file %s doesn\'t exist!' % os.path.abspath(args.in_path))
            return
        # Check if test file is image or video
        test_ext = test_path[test_path.find('.') + 1:]
        test_ext = args.in_path[args.in_path.find('.') + 1:]
        if test_ext in ['jpg', 'jpeg', 'bmp', 'png', 'tiff']:
            # Test file is an image
            img = cv2.imread(test_path)
            detect_frame(img, card_pool, out_path=out_path)
            img = cv2.imread(args.in_path)
            detect_frame(img, card_pool, hash_size=args.hash_size, out_path=out_path, display=args.display,
                         debug=args.debug)
        else:
            # Test file is a video
            capture = cv2.VideoCapture(test_path)
            detect_video(capture, card_pool, out_path=out_path, display=True, show_graph=True, debug=False)
            capture = cv2.VideoCapture(args.in_path)
            detect_video(capture, card_pool, hash_size=args.hash_size, out_path=out_path, display=args.display,
                         show_graph=args.show_graph, debug=args.debug)
            capture.release()
    pass
if __name__ == '__main__':
    main()
    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--in', dest='in_path', help='Path of the input file. For webcam, leave it blank',
                        type=str)
    parser.add_argument('-o', '--out', dest='out_path', help='Path of the output directory to save the result',
                        type=str)
    parser.add_argument('-hs', '--hash_size', dest='hash_size',
                        help='Size of the hash for pHash algorithm', type=int, default=16)
    parser.add_argument('-dsp', '--display', dest='display', help='Display the result', action='store_true',
                        default=False)
    parser.add_argument('-dbg', '--debug', dest='debug', help='Enable debug mode', action='store_true', default=False)
    parser.add_argument('-gph', '--show_graph', dest='show_graph', help='Display the graph for video output',
                        action='store_true', default=False)
    args = parser.parse_args()
    if not args.display and args.out_path is None:
        # Then why the heck are you running this thing in the first place?
        print('The program isn\'t displaying nor saving any output file. Please change the setting and try again.')
        exit()
    main(args)