# challenging-america-word-ga.../src/model.py
from collections import defaultdict
import pickle

import nltk
from bidict import bidict
from tqdm import tqdm


class Model:
    def __init__(self, UNK_token='<UNK>', n=3):
        self.n = n
        self.UNK_token = UNK_token
        # Nested count table: (left_context, right_context) -> {token id: count}.
        # The bound .copy method of a defaultdict(int) serves as the inner factory.
        self.ngrams = defaultdict(defaultdict(int).copy)
        # Total number of observations for each (left_context, right_context) pair.
        self.contexts = defaultdict(int)
        # Two-way word <-> integer id mapping; id 0 is reserved for the UNK token.
        self.tokenizer = bidict({UNK_token: 0})
        self._tokenizer_index = 1
        self.vocab = set()
        # Number of context tokens kept on each side of the predicted word.
        self.n_split = self.n // 2
    def train_tokenizer(self, corpus: list) -> None:
        # Assign a new integer id to every previously unseen word.
        for word in tqdm(corpus):
            if word not in self.vocab:
                self.vocab.add(word)
                self.tokenizer[word] = self._tokenizer_index
                self._tokenizer_index += 1
    def tokenize(self, corpus: list, verbose=False) -> list[int]:
        # Map words to ids, falling back to the UNK id for out-of-vocabulary words.
        result = []
        for word in tqdm(corpus) if verbose else corpus:
            if word not in self.vocab:
                result.append(self.tokenizer[self.UNK_token])
            else:
                result.append(self.tokenizer[word])
        return result
    def process_gram(self, gram: tuple) -> None:
        # Split the n-gram into left context, center word, and right context.
        left_context = gram[:self.n_split]
        right_context = gram[self.n_split + 1:]
        word = gram[self.n_split]
        # Skip grams whose center word is the UNK token id.
        if word == self.tokenizer[self.UNK_token]:
            return
        self.ngrams[(left_context, right_context)][word] += 1
        self.contexts[(left_context, right_context)] += 1
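    # Example of the split (illustrative token ids, not taken from the repo):
    # with n = 3 a tokenized gram (12, 7, 45) gives n_split = 1, so
    # left_context = (12,), right_context = (45,) and the counted word is 7.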
    def train(self, corpus: list) -> None:
        print("Training tokenizer")
        self.train_tokenizer(corpus)
        print("Tokenizing corpus")
        corpus = self.tokenize(corpus, verbose=True)
        print("Saving n-grams")
        n_grams = list(nltk.ngrams(corpus, self.n))
        for gram in tqdm(n_grams):
            self.process_gram(gram)
    def get_conditional_probability_for_word(self, left_context: list, right_context: list, word: int) -> float:
        # Keep only the context window the model was trained with.
        left_context = tuple(left_context[-self.n_split:])
        right_context = tuple(right_context[:self.n_split])
        total_count = self.contexts[(left_context, right_context)]
        if total_count == 0:
            return 0.0
        else:
            word_count = self.ngrams[(left_context, right_context)][word]
            return word_count / total_count
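    # This is a plain maximum-likelihood estimate without smoothing:
    #   P(word | left, right) = count(left, word, right) / count(left, *, right)
    # so contexts never seen in training always get probability 0.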
    def get_probabilities(self, left_context: list, right_context: list) -> list:
        left_context = tuple(left_context[-self.n_split:])
        right_context = tuple(right_context[:self.n_split])
        words = list(self.ngrams[(left_context, right_context)].keys())
        probs = []
        for word in words:
            prob = self.get_conditional_probability_for_word(left_context, right_context, word)
            probs.append((word, prob))
        # Return the 10 most probable (token id, probability) pairs.
        return sorted(probs, reverse=True, key=lambda x: x[1])[:10]
    def fill_gap(self, left_context: list, right_context: list) -> list:
        # Contexts arrive as raw words; convert them to token ids first.
        left_context = self.tokenize(left_context)
        right_context = self.tokenize(right_context)
        result = []
        probabilities = self.get_probabilities(left_context, right_context)
        for token, probability in probabilities:
            # Map each candidate token id back to its word for the caller.
            word = self.tokenizer.inverse[token]
            result.append((word, probability))
        return result
    def save(self, output_dir: str) -> None:
        with open(output_dir, 'wb') as f:
            pickle.dump(self, f)
    @staticmethod
    def load(model_path: str) -> 'Model':
        with open(model_path, 'rb') as f:
            return pickle.load(f)
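

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original file): the toy corpus,
    # the gap query, and the "model.pkl" path below are made-up examples,
    # intended only to show how the class is meant to be called.
    corpus = "the cat sat on the mat and the dog sat on the rug".split()

    model = Model(n=3)
    model.train(corpus)

    # Predict the word in the gap "the dog _ on the".
    print(model.fill_gap(["the", "dog"], ["on", "the"]))

    # Round-trip the model through pickle and query it again.
    model.save("model.pkl")
    reloaded = Model.load("model.pkl")
    print(reloaded.fill_gap(["the", "dog"], ["on", "the"]))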