# -*- coding: utf-8 -*-
# Natural Language Toolkit: Language Model Unit Tests
#
# Copyright (C) 2001-2019 NLTK Project
# Author: Ilia Kurenkov <ilia.kurenkov@gmail.com>
# URL: <http://nltk.org/>
# For license information, see LICENSE.TXT

import unittest

from nltk.lm.preprocessing import padded_everygram_pipeline


class TestPreprocessing(unittest.TestCase):
    def test_padded_everygram_pipeline(self):
        expected_train = [
            [
                ("<s>",),
                ("a",),
                ("b",),
                ("c",),
                ("</s>",),
                ("<s>", "a"),
                ("a", "b"),
                ("b", "c"),
                ("c", "</s>"),
            ]
        ]
        expected_vocab = ["<s>", "a", "b", "c", "</s>"]
        train_data, vocab_data = padded_everygram_pipeline(2, [["a", "b", "c"]])
        self.assertEqual([list(sent) for sent in train_data], expected_train)
        self.assertEqual(list(vocab_data), expected_vocab)
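

# --- Illustrative usage sketch (not part of the original NLTK test) ---------
# The two lazy iterators produced by padded_everygram_pipeline are normally
# fed straight into an nltk.lm language model such as MLE. The helper below
# is only a sketch reusing the same toy sentence as the test above; it is
# defined but never called, so it does not affect the test run.


def _demo_mle_usage():
    from nltk.lm import MLE

    # Rebuild the pipeline output here, since each iterator can be consumed only once.
    train, vocab = padded_everygram_pipeline(2, [["a", "b", "c"]])
    lm = MLE(2)  # bigram maximum-likelihood model
    lm.fit(train, vocab)  # builds counts and vocabulary from the iterators
    return lm.score("b", ["a"])  # P("b" | "a") under the fitted model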