minor fixes and analysis script to collect data for a checkpoint
This commit is contained in:
parent
7c8ebe7015
commit
00159593bb
3 changed files with 158 additions and 3 deletions
154
scripts/analysis.py
Executable file
154
scripts/analysis.py
Executable file
|
@ -0,0 +1,154 @@
|
|||
#!/usr/bin/env python
|
||||
import sys
|
||||
import os
|
||||
import re
|
||||
from collections import OrderedDict
|
||||
|
||||
# scipy is kinda necessary
|
||||
import scipy
|
||||
import scipy.stats
|
||||
import numpy as np
|
||||
|
||||
def gmean_nonzero(l):
|
||||
filtered = [x for x in l if x != 0]
|
||||
return scipy.stats.gmean(filtered)
|
||||
|
||||
libdir = os.path.join(os.path.dirname(os.path.realpath(__file__)), '../lib')
|
||||
sys.path.append(libdir)
|
||||
datadir = os.path.join(os.path.dirname(os.path.realpath(__file__)), '../data')
|
||||
import jdecode
|
||||
|
||||
import validate
|
||||
import ngrams
|
||||
|
||||
def annotate_values(values):
|
||||
for k in values:
|
||||
(total, good, bad) = values[k]
|
||||
values[k] = OrderedDict([('total', total), ('good', good), ('bad', bad)])
|
||||
return values
|
||||
|
||||
def print_statistics(stats, ident = 0):
|
||||
for k in stats:
|
||||
if isinstance(stats[k], OrderedDict):
|
||||
print(' ' * ident + str(k) + ':')
|
||||
print_statistics(stats[k], ident=ident+2)
|
||||
elif isinstance(stats[k], dict):
|
||||
print(' ' * ident + str(k) + ': <dict with ' + str(len(stats[k])) + ' entries>')
|
||||
elif isinstance(stats[k], list):
|
||||
print(' ' * ident + str(k) + ': <list with ' + str(len(stats[k])) + ' entries>')
|
||||
else:
|
||||
print(' ' * ident + str(k) + ': ' + str(stats[k]))
|
||||
|
||||
def get_statistics(fname, lm = None, sep = False, verbose=False):
|
||||
stats = OrderedDict()
|
||||
cards = jdecode.mtg_open_file(fname, verbose=verbose)
|
||||
stats['cards'] = cards
|
||||
|
||||
# unpack the name of the checkpoint - terrible and hacky
|
||||
try:
|
||||
final_name = os.path.basename(fname)
|
||||
halves = final_name.split('_epoch')
|
||||
cp_name = halves[0]
|
||||
cp_info = halves[1][:-4]
|
||||
info_halves = cp_info.split('_')
|
||||
cp_epoch = float(info_halves[0])
|
||||
fragments = info_halves[1].split('.')
|
||||
cp_vloss = float('.'.join(fragments[:2]))
|
||||
cp_temp = float('.'.join(fragments[-2:]))
|
||||
cp_ident = '.'.join(fragments[2:-2])
|
||||
stats['cp'] = OrderedDict([('name', cp_name),
|
||||
('epoch', cp_epoch),
|
||||
('vloss', cp_vloss),
|
||||
('temp', cp_temp),
|
||||
('ident', cp_ident)])
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
# validate
|
||||
((total_all, total_good, total_bad, total_uncovered),
|
||||
values) = validate.process_props(cards)
|
||||
|
||||
stats['props'] = annotate_values(values)
|
||||
stats['props']['overall'] = OrderedDict([('total', total_all),
|
||||
('good', total_good),
|
||||
('bad', total_bad),
|
||||
('uncovered', total_uncovered)])
|
||||
|
||||
# distances
|
||||
distfname = fname + '.dist'
|
||||
if os.path.isfile(distfname):
|
||||
name_dupes = 0
|
||||
card_dupes = 0
|
||||
with open(distfname, 'rt') as f:
|
||||
distlines = f.read().split('\n')
|
||||
dists = OrderedDict([('name', []), ('cbow', [])])
|
||||
for line in distlines:
|
||||
fields = line.split('|')
|
||||
if len(fields) < 4:
|
||||
continue
|
||||
idx = int(fields[0])
|
||||
name = str(fields[1])
|
||||
ndist = float(fields[2])
|
||||
cdist = float(fields[3])
|
||||
dists['name'] += [ndist]
|
||||
dists['cbow'] += [cdist]
|
||||
if ndist == 1.0:
|
||||
name_dupes += 1
|
||||
if cdist == 1.0:
|
||||
card_dupes += 1
|
||||
|
||||
dists['name_mean'] = np.mean(dists['name'])
|
||||
dists['cbow_mean'] = np.mean(dists['cbow'])
|
||||
dists['name_geomean'] = gmean_nonzero(dists['name'])
|
||||
dists['cbow_geomean'] = gmean_nonzero(dists['cbow'])
|
||||
stats['dists'] = dists
|
||||
|
||||
# n-grams
|
||||
if not lm is None:
|
||||
ngram = OrderedDict([('perp', []), ('perp_per', [])])
|
||||
for card in cards:
|
||||
if len(card.text.text) == 0:
|
||||
perp = 0.0
|
||||
perp_per = 0.0
|
||||
elif sep:
|
||||
vtexts = [line.vectorize().split() for line in card.text_lines
|
||||
if len(line.vectorize().split()) > 0]
|
||||
perps = [lm.perplexity(vtext) for vtext in vtexts]
|
||||
perps_per = [perps[i] / float(len(vtexts[i])) for i in range(0, len(vtexts))]
|
||||
perp = gmean_nonzero(perps)
|
||||
perp_per = gmean_nonzero(perps_per)
|
||||
else:
|
||||
vtext = card.text.vectorize().split()
|
||||
perp = lm.perplexity(vtext)
|
||||
perp_per = perp / float(len(vtext))
|
||||
|
||||
ngram['perp'] += [perp]
|
||||
ngram['perp_per'] += [perp_per]
|
||||
|
||||
ngram['perp_mean'] = np.mean(ngram['perp'])
|
||||
ngram['perp_per_mean'] = np.mean(ngram['perp_per'])
|
||||
ngram['perp_geomean'] = gmean_nonzero(ngram['perp'])
|
||||
ngram['perp_per_geomean'] = gmean_nonzero(ngram['perp_per'])
|
||||
stats['ngram'] = ngram
|
||||
|
||||
print_statistics(stats)
|
||||
|
||||
|
||||
def main(infile, verbose = False):
|
||||
lm = ngrams.build_ngram_model(jdecode.mtg_open_file(str(os.path.join(datadir, 'output.txt'))),
|
||||
3, separate_lines=True, verbose=True)
|
||||
get_statistics(infile, lm=lm, sep=True, verbose=verbose)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument('infile', #nargs='?'. default=None,
|
||||
help='encoded card file or json corpus to process')
|
||||
parser.add_argument('-v', '--verbose', action='store_true',
|
||||
help='verbose output')
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args.infile, verbose=args.verbose)
|
||||
exit(0)
|
|
@ -49,7 +49,7 @@ def build_ngram_model(cards, n, separate_lines = True, verbose = False):
|
|||
lang = extract_language(cards, separate_lines=separate_lines)
|
||||
if verbose:
|
||||
print('found ' + str(len(lang)) + ' sentences')
|
||||
lm = model.NgramModel(n, lang)
|
||||
lm = model.NgramModel(n, lang, pad_left=True, pad_right=True)
|
||||
if verbose:
|
||||
print(lm)
|
||||
return lm
|
||||
|
|
|
@ -326,8 +326,9 @@ def check_shuffle(card):
|
|||
def check_quotes(card):
|
||||
retval = None
|
||||
for line in card.text_lines:
|
||||
# NOTE: the '" pattern in the training set is actually incorrect
|
||||
quotes = len(re.findall(re.escape('"'), line.text))
|
||||
# HACK: the '" pattern in the training set is actually incorrect
|
||||
quotes += len(re.findall(re.escape('\'"'), line.text))
|
||||
if quotes > 0:
|
||||
thisval = quotes % 2 == 0
|
||||
if retval is None:
|
||||
|
@ -399,7 +400,7 @@ def process_props(cards, dump = False, uncovered = False):
|
|||
return ((total_all, total_good, total_bad, total_uncovered),
|
||||
values)
|
||||
|
||||
def main(fname, oname = None, verbose = True, dump = False):
|
||||
def main(fname, oname = None, verbose = False, dump = False):
|
||||
# may need to set special arguments here
|
||||
cards = jdecode.mtg_open_file(fname, verbose=verbose)
|
||||
|
||||
|
|
Loading…
Reference in a new issue