136 lines
4.2 KiB
Python
136 lines
4.2 KiB
Python
|
# -*- coding: utf-8 -*-
|
||
|
# Natural Language Toolkit: Language Model Unit Tests
|
||
|
#
|
||
|
# Copyright (C) 2001-2019 NLTK Project
|
||
|
# Author: Ilia Kurenkov <ilia.kurenkov@gmail.com>
|
||
|
# URL: <http://nltk.org/>
|
||
|
# For license information, see LICENSE.TXT
|
||
|
|
||
|
import unittest
|
||
|
|
||
|
import six
|
||
|
|
||
|
from nltk import FreqDist
|
||
|
from nltk.lm import NgramCounter
|
||
|
from nltk.util import everygrams
|
||
|
|
||
|
|
||
|
class NgramCounterTests(unittest.TestCase):
    """Tests for NgramCounter that only involve lookup, no modification.

    Two counters are built once for the whole class from the sentences
    "abcd" and "egdbe": one from everygrams up to length 2 and one from
    everygrams up to length 3.  Tests must leave these shared counters
    untouched; a test that needs to change a counter builds its own.
    """

    @staticmethod
    def _make_counter(max_len):
        """Return a fresh NgramCounter over the fixture sentences.

        :param max_len: longest n-gram order passed to ``everygrams``.
        """
        text = [list("abcd"), list("egdbe")]
        return NgramCounter(everygrams(sent, max_len=max_len) for sent in text)

    @classmethod
    def setUpClass(cls):
        cls.trigram_counter = cls._make_counter(3)
        cls.bigram_counter = cls._make_counter(2)

    def test_N(self):
        """N() reports the total number of n-grams stored across all orders."""
        # 9 unigrams + 7 bigrams = 16; the 5 trigrams bring the total to 21.
        self.assertEqual(self.bigram_counter.N(), 16)
        self.assertEqual(self.trigram_counter.N(), 21)

    def test_counter_len_changes_with_lookup(self):
        """Looking up an order that was never trained grows the counter."""
        # The lookup below changes the counter's length, so run it on a
        # private instance rather than on the fixture shared by every test.
        counter = self._make_counter(2)

        self.assertEqual(len(counter), 2)
        _ = counter[50]
        self.assertEqual(len(counter), 3)

    def test_ngram_order_access_unigrams(self):
        """Indexing with order 1 yields the same object as ``.unigrams``."""
        self.assertEqual(self.bigram_counter[1], self.bigram_counter.unigrams)

    def test_ngram_conditional_freqdist(self):
        """Orders 2 and 3 expose their contexts through ``conditions()``."""
        expected_trigram_contexts = [
            ("a", "b"),
            ("b", "c"),
            ("e", "g"),
            ("g", "d"),
            ("d", "b"),
        ]
        expected_bigram_contexts = [("a",), ("b",), ("d",), ("e",), ("c",), ("g",)]

        bigrams = self.trigram_counter[2]
        trigrams = self.trigram_counter[3]

        six.assertCountEqual(self, expected_bigram_contexts, bigrams.conditions())
        six.assertCountEqual(self, expected_trigram_contexts, trigrams.conditions())

    def test_bigram_counts_seen_ngrams(self):
        """Indexing with a context list gives counts of the words that follow it."""
        b_given_a_count = 1
        c_given_b_count = 1

        self.assertEqual(b_given_a_count, self.bigram_counter[["a"]]["b"])
        self.assertEqual(c_given_b_count, self.bigram_counter[["b"]]["c"])

    def test_bigram_counts_unseen_ngrams(self):
        """A word never seen after a known context has a count of zero."""
        z_given_b_count = 0

        self.assertEqual(z_given_b_count, self.bigram_counter[["b"]]["z"])

    def test_unigram_counts_seen_words(self):
        """Indexing with a plain string gives that word's unigram count."""
        # "b" occurs once in each of the two fixture sentences.
        expected_count_b = 2

        self.assertEqual(expected_count_b, self.bigram_counter["b"])

    def test_unigram_counts_completely_unseen_words(self):
        """A word absent from the training text has a unigram count of zero."""
        unseen_count = 0

        self.assertEqual(unseen_count, self.bigram_counter["z"])
|
||
|
|
||
|
|
||
|
class NgramCounterTrainingTests(unittest.TestCase):
    """Tests for how NgramCounter is populated from different training input."""

    def setUp(self):
        # Kept from the original fixture; none of the tests in this class
        # read it, each one constructs the counter it needs.
        self.counter = NgramCounter()

    def _assert_untrained(self, counter):
        """Check that ``counter`` holds no bigram order and no unigram counts."""
        self.assertNotIn(2, counter)
        self.assertEqual(counter[1], FreqDist())

    def test_empty_string(self):
        """An empty string as training text leaves the counter untrained."""
        self._assert_untrained(NgramCounter(""))

    def test_empty_list(self):
        """An empty list as training text leaves the counter untrained."""
        self._assert_untrained(NgramCounter([]))

    def test_None(self):
        """``None`` as training text leaves the counter untrained."""
        self._assert_untrained(NgramCounter(None))

    def test_train_on_unigrams(self):
        """Training on 1-tuples fills only the unigram counts."""
        words = list("abcd")
        counter = NgramCounter([[(w,) for w in words]])

        self.assertFalse(bool(counter[3]))
        self.assertFalse(bool(counter[2]))
        six.assertCountEqual(self, words, counter[1].keys())

    def test_train_on_illegal_sentences(self):
        """Sentences whose items are not n-gram tuples raise TypeError."""
        # Items are bare strings instead of tuples.
        str_sent = ["Check", "this", "out", "!"]
        # Items are lists instead of tuples.
        list_sent = [["Check", "this"], ["this", "out"], ["out", "!"]]

        with self.assertRaises(TypeError):
            NgramCounter([str_sent])

        with self.assertRaises(TypeError):
            NgramCounter([list_sent])

    def test_train_on_bigrams(self):
        """Training on 2-tuples fills only the bigram counts."""
        bigram_sent = [("a", "b"), ("c", "d")]
        counter = NgramCounter([bigram_sent])

        self.assertFalse(bool(counter[3]))
        # Also confirm the bigrams were actually recorded and that nothing
        # was counted as a unigram.
        six.assertCountEqual(self, [("a",), ("c",)], counter[2].keys())
        self.assertFalse(bool(counter[1]))

    def test_train_on_mix(self):
        """N-grams of mixed length in one sentence are filed under their own order."""
        mixed_sent = [("a", "b"), ("c", "d"), ("e", "f", "g"), ("h",)]
        counter = NgramCounter([mixed_sent])
        unigrams = ["h"]
        bigram_contexts = [("a",), ("c",)]
        trigram_contexts = [("e", "f")]

        six.assertCountEqual(self, unigrams, counter[1].keys())
        six.assertCountEqual(self, bigram_contexts, counter[2].keys())
        six.assertCountEqual(self, trigram_contexts, counter[3].keys())
|