mtgencode/scripts/autosample.py
2015-12-02 11:01:45 -08:00

90 lines
3.5 KiB
Python
Executable file

#!/usr/bin/env python
import sys
import os
import subprocess
import random
def extract_cp_name(name):
    """Parse a checkpoint filename into (epoch, validation loss).

    Expected form: "lm_lstm_epoch<epoch>_<vloss>.t7", e.g.
    "lm_lstm_epoch50.00_0.1870.t7" -> (50.0, 0.187).

    Args:
        name: bare filename (no directory component).

    Returns:
        A tuple (epoch, vloss) of floats, or None if name is not of the
        expected form.
    """
    prefix = 'lm_lstm_epoch'
    suffix = '.t7'
    if not (name.startswith(prefix) and name.endswith(suffix)):
        return None
    fields = name[len(prefix):-len(suffix)].split('_')
    # A name with the right prefix/suffix but a malformed middle (wrong number
    # of fields, non-numeric values) is treated as "not a checkpoint" instead
    # of raising, so one stray file cannot abort the whole directory scan.
    if len(fields) != 2:
        return None
    try:
        return (float(fields[0]), float(fields[1]))
    except ValueError:
        return None
def sample(cp, temp, count, seed = None, ident = 'output'):
    """Sample text from checkpoint cp by running `th sample.lua`.

    Output goes to '<cp>.<ident>.<temp>.txt'. Its first line records the
    command that produced it, and the sampled text follows. An existing
    output file is never overwritten. sample.lua is looked up relative to
    the current directory (main() chdirs into rnndir before calling this).

    Args:
        cp: path to the checkpoint file.
        temp: sampling temperature, passed as -temperature.
        count: number of characters to sample, passed as -length.
        seed: RNG seed for sample.lua; a random one is drawn if None.
        ident: tag included in the output filename.

    Returns:
        False if the output file already existed (skipped) or th exited with
        a nonzero status; True if sampling completed successfully.

    Raises:
        OSError: if th cannot be launched (the header-only output file is
            removed first so a rerun will not skip this checkpoint).
    """
    if seed is None:
        seed = random.randint(-1000000000, 1000000000)
    outfile = cp + '.' + ident + '.' + str(temp) + '.txt'
    if os.path.exists(outfile):
        print(outfile + ' already exists, skipping')
        return False
    # Argument list, no shell: checkpoint paths containing spaces or shell
    # metacharacters reach th verbatim instead of being reinterpreted.
    args = ['th', 'sample.lua', cp,
            '-temperature', str(temp),
            '-length', str(count),
            '-seed', str(seed)]
    # Human-readable record of the run; same text the old shell version wrote
    # as the header line (and echoed to the terminal via tee).
    cmd = ' '.join(args) + ' >> ' + outfile
    print(cmd)
    with open(outfile, 'w') as f:
        f.write(cmd + '\n')
        # Flush so the header lands before anything the child writes to the fd.
        f.flush()
        try:
            status = subprocess.call(args, stdout=f)
        except OSError:
            # th could not be started; don't leave a header-only file behind,
            # or the next run would think this checkpoint was already sampled.
            f.close()
            os.remove(outfile)
            raise
    if status != 0:
        print('sampling failed (exit status ' + str(status) + ') for ' + cp)
        return False
    return True
def find_best_cp(cpdir):
    """Return the path of the lowest-validation-loss checkpoint in cpdir.

    Only regular files directly inside cpdir whose names parse with
    extract_cp_name are considered; subdirectories are not searched here.
    On a tie, the file seen first in os.listdir order wins.

    Returns:
        The joined path of the best checkpoint, or None if there is none.
    """
    candidates = []
    for entry in os.listdir(cpdir):
        candidate = os.path.join(cpdir, entry)
        if not os.path.isfile(candidate):
            continue
        parsed = extract_cp_name(entry)
        if parsed is None:
            continue
        candidates.append((parsed[1], candidate))
    if not candidates:
        return None
    # min keeps the earliest entry among equal keys, like a strict '<' scan.
    return min(candidates, key=lambda pair: pair[0])[1]
def process_dir(cpdir, temp, count, seed = None, ident = 'output', verbose = False):
    """Sample once per directory across a checkpoint tree.

    In cpdir itself, the checkpoint chosen by find_best_cp (if any) is
    sampled. Then every subdirectory is processed the same way, recursively,
    in os.listdir order.

    Args:
        cpdir: directory to process.
        temp, count, seed, ident: passed unchanged to sample().
        verbose: if true, print each directory as it is visited.
    """
    if verbose:
        print('processing ' + cpdir)

    winner = find_best_cp(cpdir)
    if winner is not None:
        sample(winner, temp, count, seed=seed, ident=ident)

    # The listing is taken after sampling, so it can include the output file
    # sample() just wrote; that is a plain file and is skipped by isdir below.
    children = [os.path.join(cpdir, entry) for entry in os.listdir(cpdir)]
    for child in children:
        if not os.path.isdir(child):
            continue
        process_dir(child, temp, count,
                    seed=seed, ident=ident, verbose=verbose)
def main(rnndir, cpdir, temp, count, seed = None, ident = 'output', verbose = False):
    """Validate the directories, chdir into rnndir, and process the cpdir tree.

    Args:
        rnndir: directory containing sample.lua; the process chdirs here so
            that th can find the script.
        cpdir: root checkpoint directory, interpreted relative to the current
            working directory at call time (not relative to rnndir).
        temp, count, seed, ident, verbose: forwarded to process_dir.

    Raises:
        ValueError: if rnndir is not a directory or lacks sample.lua, or if
            cpdir is not a directory.
    """
    if not os.path.isdir(rnndir):
        raise ValueError('bad rnndir: ' + rnndir)
    if not os.path.isfile(os.path.join(rnndir, 'sample.lua')):
        # Fail up front instead of running th once per checkpoint and leaving
        # a failed output file behind for each one.
        raise ValueError('bad rnndir, no sample.lua: ' + rnndir)
    if not os.path.isdir(cpdir):
        raise ValueError('bad cpdir: ' + cpdir)
    # cpdir was validated against the current directory; make it absolute
    # before chdir so the tree we walk is the one we just checked.
    cpdir = os.path.abspath(cpdir)
    os.chdir(rnndir)
    process_dir(cpdir, temp, count, seed=seed, ident=ident, verbose=verbose)
if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('rnndir',
                        help='base rnn directory, must contain sample.lua')
    parser.add_argument('cpdir',
                        help='checkpoint directory, all subdirectories will be processed')
    # type= makes argparse validate numeric options and report bad values as
    # a usage error rather than a traceback from a later float()/int() call.
    parser.add_argument('-t', '--temperature', action='store', type=float, default=1.0,
                        help='sampling temperature')
    parser.add_argument('-c', '--count', action='store', type=int, default=1000000,
                        help='number of characters to sample each time')
    parser.add_argument('-s', '--seed', action='store', type=int, default=None,
                        help='fixed seed; if not present, a random seed will be used')
    parser.add_argument('-i', '--ident', action='store', default='output',
                        help='identifier to include in the output filenames')
    parser.add_argument('-v', '--verbose', action='store_true',
                        help='verbose output')
    args = parser.parse_args()
    main(args.rnndir, args.cpdir, args.temperature, args.count,
         seed=args.seed, ident=args.ident, verbose=args.verbose)