diff --git a/lib/cbow.py b/lib/cbow.py
index 8899bae..8c89a19 100644
--- a/lib/cbow.py
+++ b/lib/cbow.py
@@ -65,6 +65,11 @@ def makevector(vocabulary,vecs,sequence):
             res = v
         else:
             res = [x + y for x, y in zip(res,v)]
+
+    # bad things happen if we have a vector of only unknown words
+    if res is None:
+        return [0.0]*len(vecs[0])
+
     length = math.sqrt(sum([res[i] * res[i] for i in range(0,len(res))]))
     for i in range(0,len(res)):
         res[i] /= length
diff --git a/scripts/autosample.py b/scripts/autosample.py
index 02785d3..6587874 100755
--- a/scripts/autosample.py
+++ b/scripts/autosample.py
@@ -88,3 +88,4 @@ if __name__ == '__main__':
     seed = int(args.seed)
     main(args.rnndir, args.cpdir, float(args.temperature), int(args.count),
          seed=seed, ident=args.ident, verbose = args.verbose)
+    exit(0)
diff --git a/scripts/collect_checkpoints.py b/scripts/collect_checkpoints.py
new file mode 100755
index 0000000..7c5f287
--- /dev/null
+++ b/scripts/collect_checkpoints.py
@@ -0,0 +1,86 @@
+#!/usr/bin/env python
+import sys
+import os
+import shutil
+
+def identify_checkpoints(basedir, ident):
+    cp_infos = []
+    for path in os.listdir(basedir):
+        fullpath = os.path.join(basedir, path)
+        if not os.path.isfile(fullpath):
+            continue
+        if not (path[:13] == 'lm_lstm_epoch' and path[-4:] == '.txt'):
+            continue
+        if not ident in path:
+            continue
+        # attempt super hacky parsing
+        inner = path[13:-4]
+        halves = inner.split('_')
+        if not len(halves) == 2:
+            continue
+        parts = halves[1].split('.')
+        if not len(parts) == 6:
+            continue
+        # lm_lstm_epoch[25.00_0.3859.t7.output.1.0].txt
+        if not parts[3] == ident:
+            continue
+        epoch = halves[0]
+        vloss = '.'.join([parts[0], parts[1]])
+        temp = '.'.join([parts[4], parts[5]])
+        cpname = 'lm_lstm_epoch' + epoch + '_' + vloss + '.t7'
+        cp_infos += [(fullpath, os.path.join(basedir, cpname),
+                      (epoch, vloss, temp))]
+    return cp_infos
+
+def process_dir(basedir, targetdir, ident, copy_cp = False, verbose = False):
+    cp_infos = identify_checkpoints(basedir, ident)
+    for (dpath, cpath, (epoch, vloss, temp)) in cp_infos:
+        if verbose:
+            print('found dumpfile ' + dpath)
+        dname = basedir + '_epoch' + epoch + '_' + vloss + '.' + ident + '.' + temp + '.txt'
+        cname = basedir + '_epoch' + epoch + '_' + vloss + '.t7'
+        tdpath = os.path.join(targetdir, dname)
+        tcpath = os.path.join(targetdir, cname)
+        if verbose:
+            print('cp ' + dpath + ' ' + tdpath)
+        shutil.copy(dpath, tdpath)
+        if copy_cp:
+            if os.path.isfile(cpath):
+                if verbose:
+                    print('cp ' + cpath + ' ' + tcpath)
+                shutil.copy(cpath, tcpath)
+
+    if copy_cp and len(cp_infos) > 0:
+        cmdpath = os.path.join(basedir, 'command.txt')
+        tcmdpath = os.path.join(targetdir, basedir + '.command')
+        if os.path.isfile(cmdpath):
+            if verbose:
+                print('cp ' + cmdpath + ' ' + tcmdpath)
+            shutil.copy(cmdpath, tcmdpath)
+
+    for path in os.listdir(basedir):
+        fullpath = os.path.join(basedir, path)
+        if os.path.isdir(fullpath):
+            process_dir(fullpath, targetdir, ident, copy_cp=copy_cp, verbose=verbose)
+
+def main(basedir, targetdir, ident = 'output', copy_cp = False, verbose = False):
+    process_dir(basedir, targetdir, ident, copy_cp=copy_cp, verbose=verbose)
+
+if __name__ == '__main__':
+    import argparse
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument('basedir', #nargs='?', default=None,
+                        help='checkpoint directory, all subdirectories will be processed')
+    parser.add_argument('targetdir', #nargs='?', default=None,
+                        help='directory to collect output dumps and checkpoints into')
+    parser.add_argument('-c', '--copy_cp', action='store_true',
+                        help='copy checkpoints used to generate the output files')
+    parser.add_argument('-i', '--ident', action='store', default='output',
+                        help='identifier to look for to determine checkpoints')
+    parser.add_argument('-v', '--verbose', action='store_true',
+                        help='verbose output')
+
+    args = parser.parse_args()
+    main(args.basedir, args.targetdir, ident=args.ident, copy_cp=args.copy_cp, verbose=args.verbose)
+    exit(0)
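
Note on the lib/cbow.py hunk: if every word in a sequence is missing from the vocabulary, res is never assigned, and the normalization step then fails on None; the added guard returns a zero vector of the embedding width instead. The snippet below is a minimal standalone sketch of that behavior, not the real makevector (only the lines visible in the hunk are known); the vocabulary and vectors are made up for illustration.

# Minimal sketch, not the actual lib/cbow.py: sum the vectors of known words,
# then L2-normalize; the guard mirrors the one added in the hunk above.
import math

def makevector_sketch(vocabulary, vecs, sequence):
    res = None
    for word in sequence:
        if word not in vocabulary:
            continue
        v = vecs[vocabulary[word]]
        if res is None:
            res = list(v)  # copy so later arithmetic cannot mutate vecs
        else:
            res = [x + y for x, y in zip(res, v)]
    # bad things happen if we have a vector of only unknown words:
    # without this guard, res is still None and iterating it below raises a TypeError
    if res is None:
        return [0.0] * len(vecs[0])
    length = math.sqrt(sum(x * x for x in res))
    return [x / length for x in res]

vocabulary = {'island': 0, 'forest': 1}   # toy data for illustration
vecs = [[1.0, 0.0], [0.0, 1.0]]
print(makevector_sketch(vocabulary, vecs, ['island', 'forest']))  # [0.707..., 0.707...]
print(makevector_sketch(vocabulary, vecs, ['xyzzy']))             # [0.0, 0.0]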
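
Note on scripts/collect_checkpoints.py: identify_checkpoints() recovers the epoch, validation loss, and sampling temperature by slicing the dump filename, as hinted by the inline comment lm_lstm_epoch[25.00_0.3859.t7.output.1.0].txt. The sketch below walks through that parse on a made-up filename; parse_dump_name is a hypothetical helper written only for illustration, not part of the script.

# Hypothetical helper illustrating the filename convention the script expects;
# it mirrors the slicing in identify_checkpoints() but is not part of the diff.
def parse_dump_name(name, ident='output'):
    assert name.startswith('lm_lstm_epoch') and name.endswith('.txt')
    inner = name[13:-4]                # '25.00_0.3859.t7.output.1.0'
    epoch, rest = inner.split('_')     # '25.00' and '0.3859.t7.output.1.0'
    parts = rest.split('.')            # ['0', '3859', 't7', 'output', '1', '0']
    assert len(parts) == 6 and parts[3] == ident
    vloss = parts[0] + '.' + parts[1]  # '0.3859'
    temp = parts[4] + '.' + parts[5]   # '1.0'
    cpname = 'lm_lstm_epoch' + epoch + '_' + vloss + '.t7'
    return epoch, vloss, temp, cpname

print(parse_dump_name('lm_lstm_epoch25.00_0.3859.t7.output.1.0.txt'))
# ('25.00', '0.3859', '1.0', 'lm_lstm_epoch25.00_0.3859.t7')

With the argparse setup above, an invocation looks roughly like scripts/collect_checkpoints.py checkpoints/ collected/ -c -v (paths are examples), and every subdirectory of basedir is processed recursively.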