This commit is contained in:
Jakub Kaczmarek 2023-05-10 02:52:02 +02:00
parent a7ec11ca27
commit 14d3dc0e04
10 changed files with 36016 additions and 31 deletions

4
.gitignore vendored
View File

@ -1,4 +1,4 @@
geval
*~
*.swp
*.bak
@ -6,3 +6,5 @@
*.o
.DS_Store
.token
*.pickle
*.xz

10519
dev-0/out-embed-100.tsv Normal file

File diff suppressed because it is too large Load Diff

10519
dev-0/out-embed-500.tsv Normal file

File diff suppressed because it is too large Load Diff

View File

@ -5,6 +5,9 @@ tags:
params:
epochs: 1
vocab-size: 20000
batch-size: 5000
embed-size: 100
topk: 150
batch-size: 10000
embed-size:
- 100
- 500
- 1000
topk: 10

152
run.py
View File

@ -1,5 +1,3 @@
from itertools import islice
import sys
import lzma
import regex as re
from torchtext.vocab import build_vocab_from_iterator
@ -16,10 +14,12 @@ from tqdm import tqdm
def get_words_from_line(line):
line = line.rstrip()
yield "<s>"
for m in re.finditer(r"[\p{L}0-9\*]+|\p{P}+", line):
yield m.group(0).lower()
yield "</s>"
line = line.split("\t")
text = line[-2] + " " + line[-1]
text = re.sub(r"\\\\+n", " ", text)
text = re.sub('[^A-Za-z ]+', '', text)
for t in text.split():
yield t
def get_word_lines_from_file(file_name):
@ -64,25 +64,28 @@ class TrigramModel(nn.Module):
def __init__(self, vocab_size, embedding_dim, hidden_dim):
super(TrigramModel, self).__init__()
self.embeddings = nn.Embedding(vocab_size, embedding_dim)
self.linear1 = nn.Linear(embedding_dim, hidden_dim)
self.linear2 = nn.Linear(hidden_dim, vocab_size)
self.hidden = nn.Linear(embedding_dim * 2, hidden_dim)
self.output = nn.Linear(hidden_dim, vocab_size)
self.softmax = nn.Softmax()
def forward(self, x, y):
x = self.embeddings(x)
y = self.embeddings(y)
z = self.linear1(x + y)
z = self.linear2(z)
z = self.hidden(torch.cat([x, y], dim=1))
z = self.output(z)
z = self.softmax(z)
return z
embed_size = 500
vocab_size = 20000
vocab_path = "vocabulary.pickle"
if exists(vocab_path):
print("Loading vocabulary from file...")
with open(vocab_path, "rb") as fh:
vocab = pickle.load(fh)
else:
print("Building vocabulary...")
vocab = build_vocab_from_iterator(
get_word_lines_from_file("train/in.tsv.xz"),
max_tokens=vocab_size,
@ -92,17 +95,33 @@ else:
with open(vocab_path, "wb") as fh:
pickle.dump(vocab, fh)
device = "cpu"
train_dataset = Trigrams("train/in.tsv.xz", vocab_size)
model = TrigramModel(vocab_size, 100, 64).to(device)
data = DataLoader(train_dataset, batch_size=5000)
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)
dataset_path = 'train/dataset.pickle'
if exists(dataset_path):
print("Loading dataset from file...")
with open(dataset_path, "rb") as fh:
train_dataset = pickle.load(fh)
else:
print("Building dataset...")
train_dataset = Trigrams("train/in.tsv.xz", vocab_size)
with open(dataset_path, "wb") as fh:
pickle.dump(train_dataset, fh)
print("Building model...")
model = TrigramModel(vocab_size, embed_size, 64).to(device)
data = DataLoader(train_dataset, batch_size=10000)
optimizer = torch.optim.Adam(model.parameters())
criterion = torch.nn.NLLLoss()
print("Training model...")
model.train()
losses = []
for epoch in tqdm(range(10)):
for x, y, z in tqdm(data):
step = 0
max_steps = 1000
for x, y, z in tqdm(data):
x = x.to(device)
y = y.to(device)
z = z.to(device)
@ -110,10 +129,105 @@ for epoch in tqdm(range(10)):
optimizer.zero_grad()
ypredicted = model(x, z)
loss = criterion(torch.log(ypredicted), y)
losses.append(loss)
losses.append(loss.item())
loss.backward()
optimizer.step()
print(f"Epoch {epoch} loss:", loss.item())
step += 1
if step > max_steps:
break
plt.plot(losses)
torch.save(model.state_dict(), "model1.bin")
plt.show()
torch.save(model.state_dict(), f"trigram_model-embed_{embed_size}.bin")
vocab_unique = set(train_dataset.vocab.get_stoi().keys())
output = []
print('Predicting dev...')
with lzma.open("dev-0/in.tsv.xz", encoding='utf8', mode="rt") as file:
for line in tqdm(file):
line = line.split("\t")
first_word = re.sub(r"\\\\+n", " ", line[-2]).split()[-1]
first_word = re.sub('[^A-Za-z]+', '', first_word)
next_word = re.sub(r"\\\\+n", " ", line[-1]).split()[0]
nenxt_word = re.sub('[^A-Za-z]+', '', next_word)
if first_word not in vocab_unique:
word = "<unk>"
if next_word not in vocab_unique:
word = "<unk>"
first_word = torch.tensor(train_dataset.vocab.forward([first_word])).to(device)
next_word = torch.tensor(train_dataset.vocab.forward([next_word])).to(device)
out = model(first_word, next_word)
top = torch.topk(out[0], 10)
top_indices = top.indices.tolist()
top_probs = top.values.tolist()
unk_bonus = 1 - sum(top_probs)
top_words = vocab.lookup_tokens(top_indices)
top_zipped = list(zip(top_words, top_probs))
res = ""
for w, p in top_zipped:
if w == "<unk>":
res += f":{(p + unk_bonus):.4f} "
else:
res += f"{w}:{p:.4f} "
res = res[:-1]
res += "\n"
output.append(res)
with open(f"dev-0/out-embed-{embed_size}.tsv", mode="w") as file:
file.writelines(output)
model.eval()
output = []
print('Predicting test...')
with lzma.open("test-A/in.tsv.xz", encoding='utf8', mode="rt") as file:
for line in tqdm(file):
line = line.split("\t")
first_word = re.sub(r"\\\\+n", " ", line[-2]).split()[-1]
first_word = re.sub('[^A-Za-z]+', '', first_word)
next_word = re.sub(r"\\\\+n", " ", line[-1]).split()[0]
next_word = re.sub('[^A-Za-z]+', '', next_word)
if first_word not in vocab_unique:
word = "<unk>"
if next_word not in vocab_unique:
word = "<unk>"
first_word = torch.tensor(train_dataset.vocab.forward([first_word])).to(device)
next_word = torch.tensor(train_dataset.vocab.forward([next_word])).to(device)
out = model(first_word, next_word)
top = torch.topk(out[0], 10)
top_indices = top.indices.tolist()
top_probs = top.values.tolist()
unk_bonus = 1 - sum(top_probs)
top_words = vocab.lookup_tokens(top_indices)
top_zipped = list(zip(top_words, top_probs))
res = ""
for w, p in top_zipped:
if w == "<unk>":
res += f":{(p + unk_bonus):.4f} "
else:
res += f"{w}:{p:.4f} "
res = res[:-1]
res += "\n"
output.append(res)
with open(f"test-A/out-embed-{embed_size}.tsv", mode="w") as file:
file.writelines(output)

7414
test-A/out-embed-100.tsv Normal file

File diff suppressed because it is too large Load Diff

7414
test-A/out-embed-500.tsv Normal file

File diff suppressed because it is too large Load Diff

Binary file not shown.

BIN
trigram_model-embed_100.bin Normal file

Binary file not shown.

BIN
trigram_model-embed_500.bin Normal file

Binary file not shown.