first version of something that torch-rnn can use with streaming
This commit is contained in:
parent
d4b5ef2104
commit
0132345ebe
5 changed files with 203 additions and 3 deletions
1
data/mtgvocab.json
Normal file
1
data/mtgvocab.json
Normal file
|
@ -0,0 +1 @@
|
||||||
|
{"idx_to_token": {"1": "\n", "2": " ", "3": "\"", "4": "%", "5": "&", "6": "'", "7": "*", "8": "+", "9": ",", "10": "-", "11": ".", "12": "/", "13": "0", "14": "1", "15": "2", "16": "3", "17": "4", "18": "5", "19": "6", "20": "7", "21": "8", "22": "9", "23": ":", "24": "=", "25": "@", "26": "A", "27": "B", "28": "C", "29": "E", "30": "G", "31": "L", "32": "N", "33": "O", "34": "P", "35": "Q", "36": "R", "37": "S", "38": "T", "39": "U", "40": "W", "41": "X", "42": "Y", "43": "[", "44": "\\", "45": "]", "46": "^", "47": "a", "48": "b", "49": "c", "50": "d", "51": "e", "52": "f", "53": "g", "54": "h", "55": "i", "56": "j", "57": "k", "58": "l", "59": "m", "60": "n", "61": "o", "62": "p", "63": "q", "64": "r", "65": "s", "66": "t", "67": "u", "68": "v", "69": "w", "70": "x", "71": "y", "72": "z", "73": "{", "74": "|", "75": "}", "76": "~"}, "token_to_idx": {"\n": 1, " ": 2, "\"": 3, "%": 4, "'": 6, "&": 5, "+": 8, "*": 7, "-": 10, ",": 9, "/": 12, ".": 11, "1": 14, "0": 13, "3": 16, "2": 15, "5": 18, "4": 17, "7": 20, "6": 19, "9": 22, "8": 21, ":": 23, "=": 24, "A": 26, "@": 25, "C": 28, "B": 27, "E": 29, "G": 30, "L": 31, "O": 33, "N": 32, "Q": 35, "P": 34, "S": 37, "R": 36, "U": 39, "T": 38, "W": 40, "Y": 42, "X": 41, "[": 43, "]": 45, "\\": 44, "^": 46, "a": 47, "c": 49, "b": 48, "e": 51, "d": 50, "g": 53, "f": 52, "i": 55, "h": 54, "k": 57, "j": 56, "m": 59, "l": 58, "o": 61, "n": 60, "q": 63, "p": 62, "s": 65, "r": 64, "u": 67, "t": 66, "w": 69, "v": 68, "y": 71, "x": 70, "{": 73, "z": 72, "}": 75, "|": 74, "~": 76}}
|
|
@ -565,9 +565,9 @@ class Card:
|
||||||
# the NN representation, use str() or format() for output intended for human
|
# the NN representation, use str() or format() for output intended for human
|
||||||
# readers.
|
# readers.
|
||||||
|
|
||||||
def encode(self, fmt_ordered = fmt_ordered_default, fmt_labeled = None,
|
def encode(self, fmt_ordered = fmt_ordered_default, fmt_labeled = fmt_labeled_default,
|
||||||
fieldsep = utils.fieldsep, randomize_fields = False, randomize_mana = False,
|
fieldsep = utils.fieldsep, initial_sep = True, final_sep = True,
|
||||||
initial_sep = True, final_sep = True):
|
randomize_fields = False, randomize_mana = False, randomize_lines = False):
|
||||||
outfields = []
|
outfields = []
|
||||||
|
|
||||||
for field in fmt_ordered:
|
for field in fmt_ordered:
|
||||||
|
@ -581,6 +581,8 @@ class Card:
|
||||||
outfield_str = outfield.encode(randomize = randomize_mana)
|
outfield_str = outfield.encode(randomize = randomize_mana)
|
||||||
elif isinstance(outfield, Manatext):
|
elif isinstance(outfield, Manatext):
|
||||||
outfield_str = outfield.encode(randomize = randomize_mana)
|
outfield_str = outfield.encode(randomize = randomize_mana)
|
||||||
|
if randomize_lines:
|
||||||
|
outfield_str = transforms.randomize_lines(outfield_str)
|
||||||
else:
|
else:
|
||||||
outfield_str = outfield
|
outfield_str = outfield
|
||||||
else:
|
else:
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
# transform passes used to encode / decode cards
|
# transform passes used to encode / decode cards
|
||||||
import re
|
import re
|
||||||
|
import random
|
||||||
|
|
||||||
# These could probably use a little love... They tend to hardcode in lots
|
# These could probably use a little love... They tend to hardcode in lots
|
||||||
# of things very specific to the mtgjson format.
|
# of things very specific to the mtgjson format.
|
||||||
|
@ -482,6 +483,7 @@ def text_pass_11_linetrans(s):
|
||||||
alllines = prelines + keylines + mainlines + postlines
|
alllines = prelines + keylines + mainlines + postlines
|
||||||
return utils.newline.join(alllines)
|
return utils.newline.join(alllines)
|
||||||
|
|
||||||
|
|
||||||
# randomize the order of the lines
|
# randomize the order of the lines
|
||||||
# not a text pass, intended to be invoked dynamically when encoding a card
|
# not a text pass, intended to be invoked dynamically when encoding a card
|
||||||
# call this on fully encoded text, with mana symbols expanded
|
# call this on fully encoded text, with mana symbols expanded
|
||||||
|
@ -491,6 +493,7 @@ def separate_lines(text):
|
||||||
return [],[],[],[],[]
|
return [],[],[],[],[]
|
||||||
|
|
||||||
preline_search = ['equip', 'fortify', 'enchant ', 'bestow']
|
preline_search = ['equip', 'fortify', 'enchant ', 'bestow']
|
||||||
|
# probably could use optimization with a regex
|
||||||
costline_search = [
|
costline_search = [
|
||||||
'multikicker', 'kicker', 'suspend', 'echo', 'awaken',
|
'multikicker', 'kicker', 'suspend', 'echo', 'awaken',
|
||||||
'buyback', 'dash', 'entwine', 'evoke', 'flashback',
|
'buyback', 'dash', 'entwine', 'evoke', 'flashback',
|
||||||
|
@ -537,6 +540,48 @@ def separate_lines(text):
|
||||||
|
|
||||||
return prelines, keylines, mainlines, costlines, postlines
|
return prelines, keylines, mainlines, costlines, postlines
|
||||||
|
|
||||||
|
choice_re = re.compile(re.escape(utils.choice_open_delimiter) + r'.*' +
|
||||||
|
re.escape(utils.choice_close_delimiter))
|
||||||
|
choice_divider = ' ' + utils.bullet_marker + ' '
|
||||||
|
def randomize_choice(line):
|
||||||
|
choices = re.findall(choice_re, line)
|
||||||
|
if len(choices) < 1:
|
||||||
|
return line
|
||||||
|
new_line = line
|
||||||
|
for choice in choices:
|
||||||
|
parts = choice[1:-1].split(choice_divider)
|
||||||
|
if len(parts) < 3:
|
||||||
|
continue
|
||||||
|
choiceparts = parts[1:]
|
||||||
|
random.shuffle(choiceparts)
|
||||||
|
new_line = new_line.replace(choice,
|
||||||
|
utils.choice_open_delimiter +
|
||||||
|
choice_divider.join(parts[:1] + choiceparts) +
|
||||||
|
utils.choice_close_delimiter,
|
||||||
|
1)
|
||||||
|
return new_line
|
||||||
|
|
||||||
|
|
||||||
|
def randomize_lines(text):
|
||||||
|
if text == '' or 'level up' in text:
|
||||||
|
return text
|
||||||
|
|
||||||
|
prelines, keylines, mainlines, costlines, postlines = separate_lines(text)
|
||||||
|
random.shuffle(prelines)
|
||||||
|
random.shuffle(keylines)
|
||||||
|
new_mainlines = []
|
||||||
|
for line in mainlines:
|
||||||
|
if line.endswith(utils.choice_close_delimiter):
|
||||||
|
new_mainlines.append(randomize_choice(line))
|
||||||
|
# elif utils.choice_open_delimiter in line or utils.choice_close_delimiter in line:
|
||||||
|
# print(line)
|
||||||
|
else:
|
||||||
|
new_mainlines.append(line)
|
||||||
|
random.shuffle(new_mainlines)
|
||||||
|
random.shuffle(costlines)
|
||||||
|
#random.shuffle(postlines) # only one kind ever (countertype)
|
||||||
|
return utils.newline.join(prelines+keylines+new_mainlines+costlines+postlines)
|
||||||
|
|
||||||
|
|
||||||
# Text unpasses, for decoding. All assume the text inside a Manatext, so don't do anything
|
# Text unpasses, for decoding. All assume the text inside a Manatext, so don't do anything
|
||||||
# weird with the mana cost symbol.
|
# weird with the mana cost symbol.
|
||||||
|
|
|
@ -2,11 +2,13 @@
|
||||||
import sys
|
import sys
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
|
import json
|
||||||
|
|
||||||
libdir = os.path.join(os.path.dirname(os.path.realpath(__file__)), '../lib')
|
libdir = os.path.join(os.path.dirname(os.path.realpath(__file__)), '../lib')
|
||||||
sys.path.append(libdir)
|
sys.path.append(libdir)
|
||||||
import utils
|
import utils
|
||||||
import jdecode
|
import jdecode
|
||||||
|
import cardlib
|
||||||
import transforms
|
import transforms
|
||||||
|
|
||||||
def check_lines(fname):
|
def check_lines(fname):
|
||||||
|
@ -122,6 +124,28 @@ def check_vocab(fname):
|
||||||
print(card.encode())
|
print(card.encode())
|
||||||
break
|
break
|
||||||
|
|
||||||
|
def check_characters(fname, vname):
|
||||||
|
cards = jdecode.mtg_open_file(fname, verbose=True, linetrans=True)
|
||||||
|
|
||||||
|
tokens = {c for c in utils.cardsep}
|
||||||
|
for card in cards:
|
||||||
|
for c in card.encode():
|
||||||
|
tokens.add(c)
|
||||||
|
|
||||||
|
token_to_idx = {tok:i+1 for i, tok in enumerate(sorted(tokens))}
|
||||||
|
idx_to_token = {i+1:tok for i, tok in enumerate(sorted(tokens))}
|
||||||
|
|
||||||
|
print('Vocabulary: ({:d} symbols)'.format(len(token_to_idx)))
|
||||||
|
for token in sorted(token_to_idx):
|
||||||
|
print('{:8s} : {:4d}'.format(repr(token), token_to_idx[token]))
|
||||||
|
|
||||||
|
# compliant with torch-rnn
|
||||||
|
if vname:
|
||||||
|
json_data = {'token_to_idx':token_to_idx, 'idx_to_token':idx_to_token}
|
||||||
|
print('writing vocabulary to {:s}'.format(vname))
|
||||||
|
with open(vname, 'w') as f:
|
||||||
|
json.dump(json_data, f)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
import argparse
|
import argparse
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
@ -132,11 +156,17 @@ if __name__ == '__main__':
|
||||||
help='show behavior of line separation')
|
help='show behavior of line separation')
|
||||||
parser.add_argument('-vocab', action='store_true',
|
parser.add_argument('-vocab', action='store_true',
|
||||||
help='show vocabulary counts from encoded card text')
|
help='show vocabulary counts from encoded card text')
|
||||||
|
parser.add_argument('-chars', action='store_true',
|
||||||
|
help='generate and display vocabulary of characters used in encoding')
|
||||||
|
parser.add_argument('--vocab_name', default=None,
|
||||||
|
help='json file to write vocabulary to')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
if args.lines:
|
if args.lines:
|
||||||
check_lines(args.infile)
|
check_lines(args.infile)
|
||||||
if args.vocab:
|
if args.vocab:
|
||||||
check_vocab(args.infile)
|
check_vocab(args.infile)
|
||||||
|
if args.chars:
|
||||||
|
check_characters(args.infile, args.vocab_name)
|
||||||
|
|
||||||
exit(0)
|
exit(0)
|
||||||
|
|
122
scripts/streamcards.py
Executable file
122
scripts/streamcards.py
Executable file
|
@ -0,0 +1,122 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# -- STOLEN FROM torch-rnn/scripts/streamfile.py -- #
|
||||||
|
|
||||||
|
import os
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
import signal
|
||||||
|
import traceback
|
||||||
|
import psutil
|
||||||
|
|
||||||
|
# correctly setting up a stream that won't get orphaned and left clutting the operating
|
||||||
|
# system proceeds in 3 parts:
|
||||||
|
# 1) invoke install_suicide_handlers() to ensure correct behavior on interrupt
|
||||||
|
# 2) get threads by invoking spawn_stream_threads
|
||||||
|
# 3) invoke wait_and_kill_self_noreturn(threads)
|
||||||
|
# or, use the handy wrapper that does it for you
|
||||||
|
|
||||||
|
def spawn_stream_threads(fds, runthread, mkargs):
|
||||||
|
threads = []
|
||||||
|
for i, fd in enumerate(fds):
|
||||||
|
stream_thread = threading.Thread(target=runthread, args=mkargs(i, fd))
|
||||||
|
stream_thread.daemon = True
|
||||||
|
stream_thread.start()
|
||||||
|
threads.append(stream_thread)
|
||||||
|
return threads
|
||||||
|
|
||||||
|
def force_kill_self_noreturn():
|
||||||
|
# We have a strange issue here, which is that our threads will refuse to die
|
||||||
|
# to a normal exit() or sys.exit() because they're all blocked in write() calls
|
||||||
|
# on full pipes; the simplest workaround seems to be to ask the OS to terminate us.
|
||||||
|
# This kinda works, but...
|
||||||
|
#os.kill(os.getpid(), signal.SIGTERM)
|
||||||
|
# psutil might have useful features like checking if the pid has been reused before killing it.
|
||||||
|
# Also we might have child processes like l2e luajits to think about.
|
||||||
|
me = psutil.Process(os.getpid())
|
||||||
|
for child in me.children(recursive=True):
|
||||||
|
child.terminate()
|
||||||
|
me.terminate()
|
||||||
|
|
||||||
|
def handler_kill_self(signum, frame):
|
||||||
|
if signum != signal.SIGQUIT:
|
||||||
|
traceback.print_stack(frame)
|
||||||
|
print('caught signal {:d} - streamer sending SIGTERM to self'.format(signum))
|
||||||
|
force_kill_self_noreturn()
|
||||||
|
|
||||||
|
def install_suicide_handlers():
|
||||||
|
for sig in [signal.SIGHUP, signal.SIGINT, signal.SIGQUIT]:
|
||||||
|
signal.signal(sig, handler_kill_self)
|
||||||
|
|
||||||
|
def wait_and_kill_self_noreturn(threads):
|
||||||
|
running = True
|
||||||
|
while running:
|
||||||
|
running = False
|
||||||
|
for thread in threads:
|
||||||
|
if thread.is_alive():
|
||||||
|
running = True
|
||||||
|
if(os.getppid() <= 1):
|
||||||
|
# exit if parent process died (and we were reparented to init)
|
||||||
|
break
|
||||||
|
time.sleep(1)
|
||||||
|
force_kill_self_noreturn()
|
||||||
|
|
||||||
|
def streaming_noreturn(fds, write_stream, mkargs):
|
||||||
|
install_suicide_handlers()
|
||||||
|
threads = spawn_stream_threads(fds, write_stream, mkargs)
|
||||||
|
wait_and_kill_self_noreturn(threads)
|
||||||
|
assert False, 'should not return from streaming'
|
||||||
|
|
||||||
|
# -- END STOLEN FROM torch-rnn/scripts/streamfile.py -- #
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import random
|
||||||
|
|
||||||
|
libdir = os.path.join(os.path.dirname(os.path.realpath(__file__)), '../lib')
|
||||||
|
sys.path.append(libdir)
|
||||||
|
import utils
|
||||||
|
import jdecode
|
||||||
|
import transforms
|
||||||
|
|
||||||
|
def main(args):
|
||||||
|
fds = args.fds
|
||||||
|
fname = args.fname
|
||||||
|
block_size = args.block_size
|
||||||
|
main_seed = args.seed if args.seed != 0 else None
|
||||||
|
|
||||||
|
# simple default encoding for now, will add more options with the curriculum
|
||||||
|
# learning feature
|
||||||
|
|
||||||
|
cards = jdecode.mtg_open_file(fname, verbose=True, linetrans=True)
|
||||||
|
|
||||||
|
def write_stream(i, fd):
|
||||||
|
local_random = random.Random(main_seed)
|
||||||
|
local_random.jumpahead(i)
|
||||||
|
local_cards = [card for card in cards]
|
||||||
|
with open('/proc/self/fd/'+str(fd), 'wt') as f:
|
||||||
|
while True:
|
||||||
|
local_random.shuffle(local_cards)
|
||||||
|
for card in local_cards:
|
||||||
|
f.write(card.encode(randomize_mana=True, randomize_lines=True))
|
||||||
|
f.write(utils.cardsep)
|
||||||
|
|
||||||
|
def mkargs(i, fd):
|
||||||
|
return i, fd
|
||||||
|
|
||||||
|
streaming_noreturn(fds, write_stream, mkargs)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('fds', type=int, nargs='+',
|
||||||
|
help='file descriptors to write streams to')
|
||||||
|
parser.add_argument('-f', '--fname', default=os.path.join(libdir, '../data/output.txt'),
|
||||||
|
help='file to read cards from')
|
||||||
|
parser.add_argument('-n', '--block_size', type=int, default=10000,
|
||||||
|
help='number of characters each stream should read/write at a time')
|
||||||
|
parser.add_argument('-s', '--seed', type=int, default=0,
|
||||||
|
help='random seed')
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
main(args)
|
Loading…
Reference in a new issue