51 lines
1.3 KiB
Python
51 lines
1.3 KiB
Python
|
# -*- coding: utf-8 -*-
|
||
|
"""
|
||
|
Unit tests for nltk.classify. See also: nltk/test/classify.doctest
|
||
|
"""
|
||
|
from __future__ import absolute_import
|
||
|
from nose import SkipTest
|
||
|
from nltk import classify
|
||
|
|
||
|
# Labelled training data: (featureset, label) pairs over the binary
# features a, b, c and the two labels 'x' and 'y'.  Some featuresets
# deliberately occur more than once, with conflicting labels.
TRAIN = [
    (dict(a=1, b=1, c=1), 'y'),
    (dict(a=1, b=1, c=1), 'x'),
    (dict(a=1, b=1, c=0), 'y'),
    (dict(a=0, b=1, c=1), 'x'),
    (dict(a=0, b=1, c=1), 'y'),
    (dict(a=0, b=0, c=1), 'y'),
    (dict(a=0, b=1, c=0), 'x'),
    (dict(a=0, b=0, c=0), 'x'),
    (dict(a=0, b=1, c=1), 'y'),
]
|
||
|
|
||
|
# Unlabelled featuresets to classify; each comment records how often the
# featureset occurs in TRAIN and with which labels.
TEST = [
    dict(a=1, b=0, c=1),  # unseen
    dict(a=1, b=0, c=0),  # unseen
    dict(a=0, b=1, c=1),  # seen 3 times, labels=y,y,x
    dict(a=0, b=1, c=0),  # seen 1 time, label=x
]
|
||
|
|
||
|
# Expected (P('x'), P('y')) pairs, one per entry of TEST and in the same
# order; compared against the classifier output with a tolerance of 1e-2.
RESULTS = [
    (0.16, 0.84),
    (0.46, 0.54),
    (0.41, 0.59),
    (0.76, 0.24),
]
|
||
|
|
||
|
|
||
|
def assert_classifier_correct(algorithm):
    """
    Train a ``MaxentClassifier`` on ``TRAIN`` with the given algorithm and
    check its label probabilities for every featureset in ``TEST`` against
    the expected values in ``RESULTS``.

    :param algorithm: name of the training algorithm, passed straight
        through to ``classify.MaxentClassifier.train``.
    :raises SkipTest: if training raises ``LookupError`` or
        ``AttributeError``; the exception text becomes the skip message.
        (Presumably this means the external trainer is not available --
        TODO confirm against the NLTK implementation.)
    :raises AssertionError: if a predicted probability for 'x' or 'y'
        differs from its expected value by 1e-2 or more, or if the
        fixtures ``RESULTS`` and ``TEST`` have different lengths.
    """
    try:
        classifier = classify.MaxentClassifier.train(
            TRAIN, algorithm, trace=0, max_iter=1000
        )
    except (LookupError, AttributeError) as e:
        raise SkipTest(str(e))

    # zip() stops at the shorter sequence, so a fixture mismatch would
    # otherwise silently drop checks instead of failing.
    assert len(RESULTS) == len(TEST), (len(RESULTS), len(TEST))

    for (px, py), featureset in zip(RESULTS, TEST):
        pdist = classifier.prob_classify(featureset)
        # On failure the assertion message shows (actual, expected).
        assert abs(pdist.prob('x') - px) < 1e-2, (pdist.prob('x'), px)
        assert abs(pdist.prob('y') - py) < 1e-2, (pdist.prob('y'), py)
|
||
|
|
||
|
|
||
|
def test_megam():
    """Check the classifier trained with the 'MEGAM' algorithm."""
    assert_classifier_correct('MEGAM')
|
||
|
|
||
|
|
||
|
def test_tadm():
    """Check the classifier trained with the 'TADM' algorithm."""
    assert_classifier_correct('TADM')
|