425 lines
15 KiB
Python
425 lines
15 KiB
Python
# -*- coding: utf-8 -*-
|
|
# Natural Language Toolkit: Transformation-based learning
|
|
#
|
|
# Copyright (C) 2001-2019 NLTK Project
|
|
# Author: Marcus Uneson <marcus.uneson@gmail.com>
|
|
# based on previous (nltk2) version by
|
|
# Christopher Maloof, Edward Loper, Steven Bird
|
|
# URL: <http://nltk.org/>
|
|
# For license information, see LICENSE.TXT
|
|
|
|
from __future__ import print_function, absolute_import, division
|
|
import os
|
|
import pickle
|
|
|
|
import random
|
|
import time
|
|
|
|
from nltk.corpus import treebank
|
|
|
|
from nltk.tbl import error_list, Template
|
|
from nltk.tag.brill import Word, Pos
|
|
from nltk.tag import BrillTaggerTrainer, RegexpTagger, UnigramTagger
|
|
|
|
|
|
def demo():
    """
    Run the Brill tagger demo with every option at its default value.

    The comments in the source, and the docstrings of the more specific
    demo_* functions, describe what can be varied.
    """
    postag()
|
|
|
|
|
|
def demo_repr_rule_format():
    """
    Show learned rules printed via repr(Rule).

    Compare with str(Rule) and Rule.format("verbose").
    """
    rule_format = "repr"
    postag(ruleformat=rule_format)
|
|
|
|
|
|
def demo_str_rule_format():
    """
    Exemplify str(Rule) (see also repr(Rule) and Rule.format("verbose"))
    """
    postag(ruleformat="str")
|
|
|
|
|
|
def demo_verbose_rule_format():
    """
    Show learned rules printed via Rule.format("verbose").
    """
    rule_format = "verbose"
    postag(ruleformat=rule_format)
|
|
|
|
|
|
def demo_multiposition_feature():
    """
    Demonstrate a feature that looks at several positions at once.

    A feature of a template takes a list of positions relative to the
    current word. The positions are combined with logical OR: for example,
    Pos([-1, 1]) holds for a value V whenever V occurs one step to the left
    and/or one step to the right.

    A contiguous range can also be written in the 2-argument form with
    inclusive end points: Pos(-3, -1) is equivalent to the positions used
    below.
    """
    positions = [-3, -2, -1]
    template = Template(Pos(positions))
    postag(templates=[template])
|
|
|
|
|
|
def demo_multifeature_template():
    """
    Demonstrate a template made of more than one feature.
    """
    current_word = Word([0])
    two_previous_tags = Pos([-2, -1])
    postag(templates=[Template(current_word, two_previous_tags)])
|
|
|
|
|
|
def demo_template_statistics():
    """
    Print aggregate statistics for each template.

    Templates that are rarely used are candidates for removal; heavily used
    ones may be worth refining.

    Removing unused templates mainly saves time and/or space. Training is
    basically O(T) in the number of templates T, and that also holds for
    memory usage, which will often be the limiting factor.
    """
    postag(template_stats=True, incremental_stats=True)
|
|
|
|
|
|
def demo_generated_templates():
    """
    Template.expand and Feature.expand are class methods that make it easy
    to generate large numbers of templates. See their documentation for
    details.

    Note: training with 500 templates can easily fill all available memory,
    even on relatively small corpora.
    """
    # Word features over positions -1..1 (span lengths 1-2, zero included);
    # Pos features over positions -2..1 (span lengths 1-2, zero excluded).
    wordtpls = Word.expand([-1, 0, 1], [1, 2], excludezero=False)
    tagtpls = Pos.expand([-2, -1, 0, 1], [1, 2], excludezero=True)
    templates = list(Template.expand([wordtpls, tagtpls], combinations=(1, 3)))
    print(
        "Generated {0} templates for transformation-based learning".format(
            len(templates)
        )
    )
    postag(templates=templates, incremental_stats=True, template_stats=True)
|
|
|
|
|
|
def demo_learning_curve():
    """
    Plot a learning curve showing how much each individual rule contributes
    to tagging accuracy.

    Note: requires matplotlib
    """
    postag(
        learning_curve_output="learningcurve.png",
        separate_baseline_data=True,
        incremental_stats=True,
    )
|
|
|
|
|
|
def demo_error_analysis():
    """
    After tagging the test data, write every mistagged word, with its
    context, to a file.
    """
    errors_file = "errors.txt"
    postag(error_output=errors_file)
|
|
|
|
|
|
def demo_serialize_tagger():
    """
    Save the learned tagger to a file with pickle, load it back, and check
    that the round trip worked.
    """
    pickle_file = "tagger.pcl"
    postag(serialize_output=pickle_file)
|
|
|
|
|
|
def demo_high_accuracy_rules():
    """
    Keep only rules with high accuracy. This may cost a little tagging
    performance, but the resulting rules are often more interesting for a
    human to read.
    """
    postag(min_score=10, min_acc=0.96, num_sents=3000)
|
|
|
|
|
|
def postag(
    templates=None,
    tagged_data=None,
    num_sents=1000,
    max_rules=300,
    min_score=3,
    min_acc=None,
    train=0.8,
    trace=3,
    randomize=False,
    ruleformat="str",
    incremental_stats=False,
    template_stats=False,
    error_output=None,
    serialize_output=None,
    learning_curve_output=None,
    learning_curve_take=300,
    baseline_backoff_tagger=None,
    separate_baseline_data=False,
    cache_baseline_tagger=None,
):
    """
    Brill Tagger Demonstration

    :param templates: the rule templates to learn from (defaults to brill24())
    :type templates: list of Template

    :param tagged_data: the tagged sentences to train and test on (defaults to
        the treebank corpus)
    :type tagged_data: sequence of lists of (word, tag) tuples

    :param num_sents: how many sentences of training and testing data to use
    :type num_sents: C{int}

    :param max_rules: maximum number of rule instances to create
    :type max_rules: C{int}

    :param min_score: the minimum score for a rule in order for it to be considered
    :type min_score: C{int}

    :param min_acc: the minimum accuracy for a rule in order for it to be considered
    :type min_acc: C{float}

    :param train: the fraction of the corpus to be used for training (1=all)
    :type train: C{float}

    :param trace: the level of diagnostic tracing output to produce (0-4)
    :type trace: C{int}

    :param randomize: whether the training data should be a random subset of the corpus
    :type randomize: C{bool}

    :param ruleformat: rule output format, one of "str", "repr", "verbose"
    :type ruleformat: C{str}

    :param incremental_stats: if true, will tag incrementally and collect stats for each rule (rather slow)
    :type incremental_stats: C{bool}

    :param template_stats: if true, will print per-template statistics collected in training and (optionally) testing
    :type template_stats: C{bool}

    :param error_output: the file where errors will be saved
    :type error_output: C{string}

    :param serialize_output: the file where the learned tbl tagger will be saved
    :type serialize_output: C{string}

    :param learning_curve_output: filename of plot of learning curve(s) (train and also test, if available);
        only used when incremental_stats is true
    :type learning_curve_output: C{string}

    :param learning_curve_take: how many rules plotted
    :type learning_curve_take: C{int}

    :param baseline_backoff_tagger: the backoff tagger for the baseline unigram tagger
        (defaults to REGEXP_TAGGER)
    :type baseline_backoff_tagger: tagger

    :param separate_baseline_data: use a fraction of the training data exclusively for training baseline
    :type separate_baseline_data: C{bool}

    :param cache_baseline_tagger: cache baseline tagger to this file (only interesting as a temporary workaround to get
        deterministic output from the baseline unigram tagger between python versions)
    :type cache_baseline_tagger: C{string}

    Note on separate_baseline_data: if False (the default), reuse training data both for baseline and rule
    learner. This is fast and fine for a demo, but is likely to generalize worse on unseen data.
    Also cannot be sensibly used for learning curves on training data (the baseline will be artificially high).
    """
    import io

    # defaults
    baseline_backoff_tagger = baseline_backoff_tagger or REGEXP_TAGGER
    if templates is None:
        from nltk.tag.brill import brill24

        # some pre-built template sets taken from typical systems or publications are
        # available. Print a list with nltk.tag.brill.describe_template_sets()
        # for instance:
        templates = brill24()
    (training_data, baseline_data, gold_data, testing_data) = _demo_prepare_data(
        tagged_data, train, num_sents, randomize, separate_baseline_data
    )

    # creating (or reloading from cache) a baseline tagger (unigram tagger)
    # this is just a mechanism for getting deterministic output from the baseline between
    # python versions
    if cache_baseline_tagger:
        if not os.path.exists(cache_baseline_tagger):
            baseline_tagger = UnigramTagger(
                baseline_data, backoff=baseline_backoff_tagger
            )
            # pickle writes bytes, so the file must be opened in binary mode
            with open(cache_baseline_tagger, "wb") as cache_file:
                pickle.dump(baseline_tagger, cache_file)
            print(
                "Trained baseline tagger, pickled it to {0}".format(
                    cache_baseline_tagger
                )
            )
        # NOTE: pickle.load executes arbitrary code from the file; only point
        # cache_baseline_tagger at a cache you created yourself.
        with open(cache_baseline_tagger, "rb") as cache_file:
            baseline_tagger = pickle.load(cache_file)
        print("Reloaded pickled tagger from {0}".format(cache_baseline_tagger))
    else:
        baseline_tagger = UnigramTagger(baseline_data, backoff=baseline_backoff_tagger)
        print("Trained baseline tagger")
    if gold_data:
        print(
            " Accuracy on test set: {0:0.4f}".format(
                baseline_tagger.evaluate(gold_data)
            )
        )

    # creating a Brill tagger
    tbrill = time.time()
    trainer = BrillTaggerTrainer(
        baseline_tagger, templates, trace, ruleformat=ruleformat
    )
    print("Training tbl tagger...")
    brill_tagger = trainer.train(training_data, max_rules, min_score, min_acc)
    print("Trained tbl tagger in {0:0.2f} seconds".format(time.time() - tbrill))
    if gold_data:
        print(" Accuracy on test set: %.4f" % brill_tagger.evaluate(gold_data))

    # printing the learned rules, if learned silently
    if trace == 1:
        print("\nLearned rules: ")
        for (ruleno, rule) in enumerate(brill_tagger.rules(), 1):
            print("{0:4d} {1:s}".format(ruleno, rule.format(ruleformat)))

    # printing template statistics (optionally including comparison with the training data)
    # note: if not separate_baseline_data, then baseline accuracy will be artificially high
    if incremental_stats:
        print(
            "Incrementally tagging the test data, collecting individual rule statistics"
        )
        (taggedtest, teststats) = brill_tagger.batch_tag_incremental(
            testing_data, gold_data
        )
        print(" Rule statistics collected")
        if not separate_baseline_data:
            print(
                "WARNING: train_stats asked for separate_baseline_data=True; the baseline "
                "will be artificially high"
            )
        trainstats = brill_tagger.train_stats()
        if template_stats:
            brill_tagger.print_template_statistics(teststats)
        if learning_curve_output:
            _demo_plot(
                learning_curve_output, teststats, trainstats, take=learning_curve_take
            )
            print("Wrote plot of learning curve to {0}".format(learning_curve_output))
    else:
        print("Tagging the test data")
        taggedtest = brill_tagger.tag_sents(testing_data)
        if template_stats:
            brill_tagger.print_template_statistics()

    # writing error analysis to file
    if error_output is not None:
        # Write text through an explicitly utf-8 encoded text stream. The old code
        # encoded the joined errors to bytes and then added a str, which raises
        # TypeError on Python 3. io.open accepts `encoding` on Python 2 as well.
        with io.open(error_output, "w", encoding="utf-8") as f:
            f.write(u"Errors for Brill Tagger %r\n\n" % serialize_output)
            f.write(u"\n".join(error_list(gold_data, taggedtest)) + u"\n")
        print("Wrote tagger errors including context to {0}".format(error_output))

    # serializing the tagger to a pickle file and reloading (just to see it works)
    if serialize_output is not None:
        taggedtest = brill_tagger.tag_sents(testing_data)
        with open(serialize_output, "wb") as pickle_file:
            pickle.dump(brill_tagger, pickle_file)
        print("Wrote pickled tagger to {0}".format(serialize_output))
        with open(serialize_output, "rb") as pickle_file:
            brill_tagger_reloaded = pickle.load(pickle_file)
        print("Reloaded pickled tagger from {0}".format(serialize_output))
        # Tag with the *reloaded* tagger. Using brill_tagger here would make the
        # comparison below trivially succeed and validate nothing.
        taggedtest_reloaded = brill_tagger_reloaded.tag_sents(testing_data)
        if taggedtest == taggedtest_reloaded:
            print("Reloaded tagger tried on test set, results identical")
        else:
            print("PROBLEM: Reloaded tagger gave different results on test set")
|
|
|
|
|
|
def _demo_prepare_data(
    tagged_data, train, num_sents, randomize, separate_baseline_data
):
    """
    Split tagged sentences into training, baseline, gold and testing sets.

    :param tagged_data: tagged sentences (each a list of (word, tag) pairs),
        or None to load the treebank corpus
    :param train: the proportion of the first num_sents sentences used for
        training; the rest is reserved for testing
    :param num_sents: how many sentences to use; None, or a value larger than
        the corpus, means all of them
    :param randomize: if true, shuffle the sentences before splitting, using a
        seed derived from the corpus size so the order is reproducible
    :param separate_baseline_data: if true, the first third of the training
        sentences is reserved for the baseline tagger; otherwise the baseline
        reuses the training sentences
    :return: (training_data, baseline_data, gold_data, testing_data), where
        testing_data is gold_data with the tags stripped
    """
    if tagged_data is None:
        print("Loading tagged data from treebank... ")
        tagged_data = treebank.tagged_sents()
    if num_sents is None or len(tagged_data) <= num_sents:
        num_sents = len(tagged_data)
    if randomize:
        # Copy first: corpus views such as treebank.tagged_sents() do not support
        # item assignment, and shuffling in place would also mutate the
        # caller's list. A private Random seeded the same way produces the same
        # order as random.seed() + random.shuffle() without resetting the
        # process-wide generator.
        tagged_data = list(tagged_data)
        random.Random(len(tagged_data)).shuffle(tagged_data)
    cutoff = int(num_sents * train)
    training_data = tagged_data[:cutoff]
    gold_data = tagged_data[cutoff:num_sents]
    testing_data = [[t[0] for t in sent] for sent in gold_data]
    if not separate_baseline_data:
        baseline_data = training_data
    else:
        bl_cutoff = len(training_data) // 3
        (baseline_data, training_data) = (
            training_data[:bl_cutoff],
            training_data[bl_cutoff:],
        )
    (trainseqs, traintokens) = corpus_size(training_data)
    (testseqs, testtokens) = corpus_size(testing_data)
    (bltrainseqs, bltraintokens) = corpus_size(baseline_data)
    print("Read testing data ({0:d} sents/{1:d} wds)".format(testseqs, testtokens))
    print("Read training data ({0:d} sents/{1:d} wds)".format(trainseqs, traintokens))
    print(
        "Read baseline data ({0:d} sents/{1:d} wds) {2:s}".format(
            bltrainseqs,
            bltraintokens,
            "" if separate_baseline_data else "[reused the training set]",
        )
    )
    return (training_data, baseline_data, gold_data, testing_data)
|
|
|
|
|
|
def _demo_plot(learning_curve_output, teststats, trainstats=None, take=None):
|
|
testcurve = [teststats['initialerrors']]
|
|
for rulescore in teststats['rulescores']:
|
|
testcurve.append(testcurve[-1] - rulescore)
|
|
testcurve = [1 - x / teststats['tokencount'] for x in testcurve[:take]]
|
|
|
|
traincurve = [trainstats['initialerrors']]
|
|
for rulescore in trainstats['rulescores']:
|
|
traincurve.append(traincurve[-1] - rulescore)
|
|
traincurve = [1 - x / trainstats['tokencount'] for x in traincurve[:take]]
|
|
|
|
import matplotlib.pyplot as plt
|
|
|
|
r = list(range(len(testcurve)))
|
|
plt.plot(r, testcurve, r, traincurve)
|
|
plt.axis([None, None, None, 1.0])
|
|
plt.savefig(learning_curve_output)
|
|
|
|
|
|
# Cardinal numbers: optional sign, digits, optional comma-separated thousands
# groups, optional decimal part. The decimal point is escaped: the previous
# unescaped "." matched any character, so e.g. "12x34" was tagged CD (and
# "1,000" matched only by accident, while "1,000,000" did not).
_CARDINAL_NUMBER_PATTERN = r'^-?[0-9]+(,[0-9]{3})*(\.[0-9]+)?$'

NN_CD_TAGGER = RegexpTagger([(_CARDINAL_NUMBER_PATTERN, 'CD'), (r'.*', 'NN')])

REGEXP_TAGGER = RegexpTagger(
    [
        (_CARDINAL_NUMBER_PATTERN, 'CD'),  # cardinal numbers
        (r'(The|the|A|a|An|an)$', 'AT'),  # articles
        (r'.*able$', 'JJ'),  # adjectives
        (r'.*ness$', 'NN'),  # nouns formed from adjectives
        (r'.*ly$', 'RB'),  # adverbs
        (r'.*s$', 'NNS'),  # plural nouns
        (r'.*ing$', 'VBG'),  # gerunds
        (r'.*ed$', 'VBD'),  # past tense verbs
        (r'.*', 'NN'),  # nouns (default)
    ]
)
|
|
|
|
|
|
def corpus_size(seqs):
    """
    Return a pair (number of sequences, total length of all sequences).
    """
    count = len(seqs)
    total = 0
    for seq in seqs:
        total += len(seq)
    return count, total
|
|
|
|
|
|
if __name__ == '__main__':
    # Running this module as a script runs the learning-curve demo, which
    # writes learningcurve.png (demo_learning_curve notes it requires matplotlib).
    demo_learning_curve()
|