from collections import defaultdict

from tqdm import tqdm
import nltk
import pickle


class Model:
    """Center-word n-gram model: fills a one-word gap using the
    surrounding left and right context."""

    def __init__(self, UNK_token='<UNK>', n=3):
        self.n = n
        self.UNK_token = UNK_token

        # (left_context, right_context) -> {token id: count}.
        # defaultdict(int).copy is used as the inner factory instead of a
        # lambda, which pickle could not serialize.
        self.ngrams = defaultdict(defaultdict(int).copy)
        # (left_context, right_context) -> total count across all tokens.
        self.contexts = defaultdict(int)

        # Word <-> integer token id mappings; id 0 is reserved for UNK.
        self.tokenizer = {UNK_token: 0}
        self.reverse_tokenizer = {0: UNK_token}
        self._tokenizer_index = 1
        self.vocab = set()

        # Position of the gap word inside an n-gram: for n=3 this is 1,
        # leaving one word of context on each side.
        self.n_split = self.n // 2

    def train_tokenizer(self, corpus: list) -> None:
        """Assign an integer token id to every distinct word in the corpus."""
        for word in tqdm(corpus):
            if word not in self.vocab:
                self.vocab.add(word)
                self.tokenizer[word] = self._tokenizer_index
                self.reverse_tokenizer[self._tokenizer_index] = word
                self._tokenizer_index += 1

    def tokenize(self, corpus: list, verbose=False) -> list[int]:
        """Map words to token ids, falling back to the UNK id for unknown words."""
        result = []

        for word in tqdm(corpus) if verbose else corpus:
            if word in self.vocab:
                result.append(self.tokenizer[word])
            else:
                result.append(self.tokenizer[self.UNK_token])

        return result

    def train(self, corpus: list) -> None:
        print("Training tokenizer")
        self.train_tokenizer(corpus)

        print("Tokenizing corpus")
        corpus = self.tokenize(corpus, verbose=True)

        print("Saving n-grams")
        unk_id = self.tokenizer[self.UNK_token]
        n_grams = list(nltk.ngrams(corpus, self.n))
        for gram in tqdm(n_grams):
            left_context = gram[:self.n_split]
            right_context = gram[self.n_split + 1:]
            word = gram[self.n_split]

            # The corpus is already tokenized at this point, so skip on the
            # UNK id; comparing against the UNK string could never match an
            # integer token.
            if word == unk_id:
                continue

            self.ngrams[(left_context, right_context)][word] += 1
            self.contexts[(left_context, right_context)] += 1

    def get_conditional_probability_for_word(self, left_context: list, right_context: list, word: int) -> float:
        """Maximum-likelihood estimate of P(word | left_context, right_context)."""
        left_context = tuple(left_context[-self.n_split:])
        right_context = tuple(right_context[:self.n_split])

        # Use .get() so looking up an unseen context does not insert an
        # empty entry into the defaultdict as a side effect.
        total_count = self.contexts.get((left_context, right_context), 0)

        if total_count == 0:
            return 0.0

        word_count = self.ngrams[(left_context, right_context)][word]
        return word_count / total_count

    def get_probabilities(self, left_context: list, right_context: list) -> list:
        """Return up to the ten most probable (token id, probability) pairs."""
        left_context = tuple(left_context[-self.n_split:])
        right_context = tuple(right_context[:self.n_split])

        words = list(self.ngrams[(left_context, right_context)].keys())
        probs = []

        for word in words:
            prob = self.get_conditional_probability_for_word(left_context, right_context, word)
            probs.append((word, prob))

        # Sort by probability (x[1]), not by token id (x[0]), so the most
        # likely candidates come first.
        return sorted(probs, reverse=True, key=lambda x: x[1])[:10]

    def fill_gap(self, left_context: list, right_context: list) -> list:
        """Suggest (word, probability) candidates for the gap between two contexts."""
        left_context = self.tokenize(left_context)
        right_context = self.tokenize(right_context)

        result = []
        for token, prob in self.get_probabilities(left_context, right_context):
            result.append((self.reverse_tokenizer[token], prob))

        return result

    def save(self, output_path: str) -> None:
        """Pickle the entire model to output_path (a file path, not a directory)."""
        with open(output_path, 'wb') as f:
            pickle.dump(self, f)

    @staticmethod
    def load(model_path: str) -> 'Model':
        """Restore a model previously written by save()."""
        with open(model_path, 'rb') as f:
            return pickle.load(f)
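

# Minimal usage sketch (an assumption for illustration, not part of the
# original module): the corpus is a pre-tokenized list of words, e.g. from
# a whitespace split or nltk.word_tokenize over raw text.
if __name__ == '__main__':
    corpus = "the cat sat on the mat the dog sat on the rug".split()

    model = Model(n=3)
    model.train(corpus)

    # Predict the word between 'sat' and 'the'; with this toy corpus the
    # top candidate should be 'on'.
    print(model.fill_gap(['sat'], ['the']))

    # Round-trip the model through pickle ('model.pkl' is a hypothetical path).
    model.save('model.pkl')
    print(Model.load('model.pkl').fill_gap(['sat'], ['the']))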