critical fix for cbow, some new management scripting work
This commit is contained in:
parent
4c625c1e29
commit
afd89730ad
3 changed files with 92 additions and 0 deletions
|
@ -65,6 +65,11 @@ def makevector(vocabulary,vecs,sequence):
|
||||||
res = v
|
res = v
|
||||||
else:
|
else:
|
||||||
res = [x + y for x, y in zip(res,v)]
|
res = [x + y for x, y in zip(res,v)]
|
||||||
|
|
||||||
|
# bad things happen if we have a vector of only unknown words
|
||||||
|
if res is None:
|
||||||
|
return [0.0]*len(vecs[0])
|
||||||
|
|
||||||
length = math.sqrt(sum([res[i] * res[i] for i in range(0,len(res))]))
|
length = math.sqrt(sum([res[i] * res[i] for i in range(0,len(res))]))
|
||||||
for i in range(0,len(res)):
|
for i in range(0,len(res)):
|
||||||
res[i] /= length
|
res[i] /= length
|
||||||
|
|
|
@ -88,3 +88,4 @@ if __name__ == '__main__':
|
||||||
seed = int(args.seed)
|
seed = int(args.seed)
|
||||||
main(args.rnndir, args.cpdir, float(args.temperature), int(args.count),
|
main(args.rnndir, args.cpdir, float(args.temperature), int(args.count),
|
||||||
seed=seed, ident=args.ident, verbose = args.verbose)
|
seed=seed, ident=args.ident, verbose = args.verbose)
|
||||||
|
exit(0)
|
||||||
|
|
86
scripts/collect_checkpoints.py
Executable file
86
scripts/collect_checkpoints.py
Executable file
|
@ -0,0 +1,86 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
import sys
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
def identify_checkpoints(basedir, ident):
|
||||||
|
cp_infos = []
|
||||||
|
for path in os.listdir(cpdir):
|
||||||
|
fullpath = os.path.join(basedir, path)
|
||||||
|
if not os.path.isfile(fullpath):
|
||||||
|
continue
|
||||||
|
if not (name[:13] == 'lm_lstm_epoch' and name[-4:] == '.txt'):
|
||||||
|
continue
|
||||||
|
if not ident in path:
|
||||||
|
continue
|
||||||
|
# attempt super hacky parsing
|
||||||
|
inner = path[13:-4]
|
||||||
|
halves = inner.split('_')
|
||||||
|
if not len(halves) == 2:
|
||||||
|
continue
|
||||||
|
parts = halves[1].split('.')
|
||||||
|
if not len(parts) == 6:
|
||||||
|
continue
|
||||||
|
# lm_lstm_epoch[25.00_0.3859.t7.output.1.0].txt
|
||||||
|
if not parts[3] == ident:
|
||||||
|
continue
|
||||||
|
epoch = halves[0]
|
||||||
|
vloss = '.'.join([parts[0], parts[1]])
|
||||||
|
temp = '.'.join([parts[4], parts[5]])
|
||||||
|
cpname = 'lm_lstm_epoch' + epoch + '_' + vloss + '.t7'
|
||||||
|
cp_infos += [(fullpath, os.path.join(basedir, cpname),
|
||||||
|
(epoch, vloss, temp))]
|
||||||
|
return cp_infos
|
||||||
|
|
||||||
|
def process_dir(basedir, targetdir, ident, copy_cp = False, verbose = False):
|
||||||
|
cp_infos = identify_checkpoints(basedir, ident)
|
||||||
|
for (dpath, cpath, (epoch, vloss, temp)) in cp_infos:
|
||||||
|
if verbose:
|
||||||
|
print('found dumpfile ' + dpath)
|
||||||
|
dname = basedir + '_epoch' + epoch + '_' + vloss + '.' + ident + '.' + temp + '.txt'
|
||||||
|
cname = basedir + '_epoch' + epoch + '_' + vloss + '.t7'
|
||||||
|
tdpath = os.path.join(targetdir, dname)
|
||||||
|
tcpath = os.path.join(targetdir, cname)
|
||||||
|
if verbose:
|
||||||
|
print('cp ' + dpath + ' ' + tdpath)
|
||||||
|
#shutil.copy(dpath, tdpath)
|
||||||
|
if copy_cp:
|
||||||
|
if os.path.isfile('cpath'):
|
||||||
|
if verbose:
|
||||||
|
print('cp ' + cpath + ' ' + tcpath)
|
||||||
|
#shutil.copy(cpath, tcpath)
|
||||||
|
|
||||||
|
if copy_cp and len(cp_infos) > 0:
|
||||||
|
cmdpath = os.path.join(basedir, 'command.txt')
|
||||||
|
tcmdpath = os.path.join(targetdir, basedir + '.command')
|
||||||
|
if os.path.isfile('cpath'):
|
||||||
|
if verbose:
|
||||||
|
print('cp ' + cmdpath + ' ' + tcmdpath)
|
||||||
|
#shutil.copy(cmdpath, tcmdpath)
|
||||||
|
|
||||||
|
for path in os.listdir(basedir):
|
||||||
|
fullpath = os.path.join(basedir, path)
|
||||||
|
if os.path.isdir(fullpath):
|
||||||
|
process_dir(fullpath, targetdir, ident, copy_cp=copy_cp, verbose=verbose)
|
||||||
|
|
||||||
|
def main(basedir, targetdir, ident = 'output', copy_cp = False, verbose = False):
|
||||||
|
process_dir(basedir, targetdir, ident, copy_cp=copy_cp, verbose=verbose)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
import argparse
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
|
||||||
|
parser.add_argument('basedir', #nargs='?'. default=None,
|
||||||
|
help='base rnn directory, must contain sample.lua')
|
||||||
|
parser.add_argument('targetdir', #nargs='?', default=None,
|
||||||
|
help='checkpoint directory, all subdirectories will be processed')
|
||||||
|
parser.add_argument('-c', '--copy_cp', action='store_true',
|
||||||
|
help='copy checkpoints used to generate the output files')
|
||||||
|
parser.add_argument('-i', '--ident', action='store', default='output',
|
||||||
|
help='identifier to look for to determine checkpoints')
|
||||||
|
parser.add_argument('-v', '--verbose', action='store_true',
|
||||||
|
help='verbose output')
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
main(args.basedir, args.targetdir, ident=args.ident, copy_cp=args.copy_cp, verbose=args.verbose)
|
||||||
|
exit(0)
|
Loading…
Reference in a new issue