434624
This commit is contained in:
parent
a7ec11ca27
commit
14d3dc0e04
4
.gitignore
vendored
4
.gitignore
vendored
@ -1,4 +1,4 @@
|
||||
|
||||
geval
|
||||
*~
|
||||
*.swp
|
||||
*.bak
|
||||
@ -6,3 +6,5 @@
|
||||
*.o
|
||||
.DS_Store
|
||||
.token
|
||||
*.pickle
|
||||
*.xz
|
10519
dev-0/out-embed-100.tsv
Normal file
10519
dev-0/out-embed-100.tsv
Normal file
File diff suppressed because it is too large
Load Diff
10519
dev-0/out-embed-500.tsv
Normal file
10519
dev-0/out-embed-500.tsv
Normal file
File diff suppressed because it is too large
Load Diff
@ -5,6 +5,9 @@ tags:
|
||||
params:
|
||||
epochs: 1
|
||||
vocab-size: 20000
|
||||
batch-size: 5000
|
||||
embed-size: 100
|
||||
topk: 150
|
||||
batch-size: 10000
|
||||
embed-size:
|
||||
- 100
|
||||
- 500
|
||||
- 1000
|
||||
topk: 10
|
||||
|
148
run.py
148
run.py
@ -1,5 +1,3 @@
|
||||
from itertools import islice
|
||||
import sys
|
||||
import lzma
|
||||
import regex as re
|
||||
from torchtext.vocab import build_vocab_from_iterator
|
||||
@ -16,10 +14,12 @@ from tqdm import tqdm
|
||||
|
||||
def get_words_from_line(line):
|
||||
line = line.rstrip()
|
||||
yield "<s>"
|
||||
for m in re.finditer(r"[\p{L}0-9\*]+|\p{P}+", line):
|
||||
yield m.group(0).lower()
|
||||
yield "</s>"
|
||||
line = line.split("\t")
|
||||
text = line[-2] + " " + line[-1]
|
||||
text = re.sub(r"\\\\+n", " ", text)
|
||||
text = re.sub('[^A-Za-z ]+', '', text)
|
||||
for t in text.split():
|
||||
yield t
|
||||
|
||||
|
||||
def get_word_lines_from_file(file_name):
|
||||
@ -64,25 +64,28 @@ class TrigramModel(nn.Module):
|
||||
def __init__(self, vocab_size, embedding_dim, hidden_dim):
|
||||
super(TrigramModel, self).__init__()
|
||||
self.embeddings = nn.Embedding(vocab_size, embedding_dim)
|
||||
self.linear1 = nn.Linear(embedding_dim, hidden_dim)
|
||||
self.linear2 = nn.Linear(hidden_dim, vocab_size)
|
||||
self.hidden = nn.Linear(embedding_dim * 2, hidden_dim)
|
||||
self.output = nn.Linear(hidden_dim, vocab_size)
|
||||
self.softmax = nn.Softmax()
|
||||
|
||||
def forward(self, x, y):
|
||||
x = self.embeddings(x)
|
||||
y = self.embeddings(y)
|
||||
z = self.linear1(x + y)
|
||||
z = self.linear2(z)
|
||||
z = self.hidden(torch.cat([x, y], dim=1))
|
||||
z = self.output(z)
|
||||
z = self.softmax(z)
|
||||
return z
|
||||
|
||||
|
||||
embed_size = 500
|
||||
vocab_size = 20000
|
||||
vocab_path = "vocabulary.pickle"
|
||||
if exists(vocab_path):
|
||||
print("Loading vocabulary from file...")
|
||||
with open(vocab_path, "rb") as fh:
|
||||
vocab = pickle.load(fh)
|
||||
else:
|
||||
print("Building vocabulary...")
|
||||
vocab = build_vocab_from_iterator(
|
||||
get_word_lines_from_file("train/in.tsv.xz"),
|
||||
max_tokens=vocab_size,
|
||||
@ -92,16 +95,32 @@ else:
|
||||
with open(vocab_path, "wb") as fh:
|
||||
pickle.dump(vocab, fh)
|
||||
|
||||
device = "cpu"
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
print("Using device:", device)
|
||||
dataset_path = 'train/dataset.pickle'
|
||||
if exists(dataset_path):
|
||||
print("Loading dataset from file...")
|
||||
with open(dataset_path, "rb") as fh:
|
||||
train_dataset = pickle.load(fh)
|
||||
else:
|
||||
print("Building dataset...")
|
||||
train_dataset = Trigrams("train/in.tsv.xz", vocab_size)
|
||||
model = TrigramModel(vocab_size, 100, 64).to(device)
|
||||
data = DataLoader(train_dataset, batch_size=5000)
|
||||
with open(dataset_path, "wb") as fh:
|
||||
pickle.dump(train_dataset, fh)
|
||||
|
||||
print("Building model...")
|
||||
model = TrigramModel(vocab_size, embed_size, 64).to(device)
|
||||
data = DataLoader(train_dataset, batch_size=10000)
|
||||
optimizer = torch.optim.Adam(model.parameters())
|
||||
criterion = torch.nn.NLLLoss()
|
||||
|
||||
print("Training model...")
|
||||
model.train()
|
||||
losses = []
|
||||
for epoch in tqdm(range(10)):
|
||||
step = 0
|
||||
max_steps = 1000
|
||||
|
||||
for x, y, z in tqdm(data):
|
||||
x = x.to(device)
|
||||
y = y.to(device)
|
||||
@ -110,10 +129,105 @@ for epoch in tqdm(range(10)):
|
||||
optimizer.zero_grad()
|
||||
ypredicted = model(x, z)
|
||||
loss = criterion(torch.log(ypredicted), y)
|
||||
losses.append(loss)
|
||||
losses.append(loss.item())
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
print(f"Epoch {epoch} loss:", loss.item())
|
||||
step += 1
|
||||
if step > max_steps:
|
||||
break
|
||||
|
||||
plt.plot(losses)
|
||||
torch.save(model.state_dict(), "model1.bin")
|
||||
plt.show()
|
||||
|
||||
torch.save(model.state_dict(), f"trigram_model-embed_{embed_size}.bin")
|
||||
|
||||
vocab_unique = set(train_dataset.vocab.get_stoi().keys())
|
||||
|
||||
output = []
|
||||
print('Predicting dev...')
|
||||
with lzma.open("dev-0/in.tsv.xz", encoding='utf8', mode="rt") as file:
|
||||
for line in tqdm(file):
|
||||
line = line.split("\t")
|
||||
|
||||
first_word = re.sub(r"\\\\+n", " ", line[-2]).split()[-1]
|
||||
first_word = re.sub('[^A-Za-z]+', '', first_word)
|
||||
|
||||
next_word = re.sub(r"\\\\+n", " ", line[-1]).split()[0]
|
||||
nenxt_word = re.sub('[^A-Za-z]+', '', next_word)
|
||||
|
||||
if first_word not in vocab_unique:
|
||||
word = "<unk>"
|
||||
if next_word not in vocab_unique:
|
||||
word = "<unk>"
|
||||
|
||||
first_word = torch.tensor(train_dataset.vocab.forward([first_word])).to(device)
|
||||
next_word = torch.tensor(train_dataset.vocab.forward([next_word])).to(device)
|
||||
|
||||
out = model(first_word, next_word)
|
||||
|
||||
top = torch.topk(out[0], 10)
|
||||
top_indices = top.indices.tolist()
|
||||
top_probs = top.values.tolist()
|
||||
unk_bonus = 1 - sum(top_probs)
|
||||
top_words = vocab.lookup_tokens(top_indices)
|
||||
top_zipped = list(zip(top_words, top_probs))
|
||||
|
||||
res = ""
|
||||
for w, p in top_zipped:
|
||||
if w == "<unk>":
|
||||
res += f":{(p + unk_bonus):.4f} "
|
||||
else:
|
||||
res += f"{w}:{p:.4f} "
|
||||
|
||||
res = res[:-1]
|
||||
res += "\n"
|
||||
output.append(res)
|
||||
|
||||
with open(f"dev-0/out-embed-{embed_size}.tsv", mode="w") as file:
|
||||
file.writelines(output)
|
||||
|
||||
|
||||
model.eval()
|
||||
|
||||
output = []
|
||||
print('Predicting test...')
|
||||
with lzma.open("test-A/in.tsv.xz", encoding='utf8', mode="rt") as file:
|
||||
for line in tqdm(file):
|
||||
line = line.split("\t")
|
||||
|
||||
first_word = re.sub(r"\\\\+n", " ", line[-2]).split()[-1]
|
||||
first_word = re.sub('[^A-Za-z]+', '', first_word)
|
||||
|
||||
next_word = re.sub(r"\\\\+n", " ", line[-1]).split()[0]
|
||||
next_word = re.sub('[^A-Za-z]+', '', next_word)
|
||||
|
||||
if first_word not in vocab_unique:
|
||||
word = "<unk>"
|
||||
if next_word not in vocab_unique:
|
||||
word = "<unk>"
|
||||
|
||||
first_word = torch.tensor(train_dataset.vocab.forward([first_word])).to(device)
|
||||
next_word = torch.tensor(train_dataset.vocab.forward([next_word])).to(device)
|
||||
|
||||
out = model(first_word, next_word)
|
||||
|
||||
top = torch.topk(out[0], 10)
|
||||
top_indices = top.indices.tolist()
|
||||
top_probs = top.values.tolist()
|
||||
unk_bonus = 1 - sum(top_probs)
|
||||
top_words = vocab.lookup_tokens(top_indices)
|
||||
top_zipped = list(zip(top_words, top_probs))
|
||||
|
||||
res = ""
|
||||
for w, p in top_zipped:
|
||||
if w == "<unk>":
|
||||
res += f":{(p + unk_bonus):.4f} "
|
||||
else:
|
||||
res += f"{w}:{p:.4f} "
|
||||
|
||||
res = res[:-1]
|
||||
res += "\n"
|
||||
output.append(res)
|
||||
|
||||
with open(f"test-A/out-embed-{embed_size}.tsv", mode="w") as file:
|
||||
file.writelines(output)
|
||||
|
7414
test-A/out-embed-100.tsv
Normal file
7414
test-A/out-embed-100.tsv
Normal file
File diff suppressed because it is too large
Load Diff
7414
test-A/out-embed-500.tsv
Normal file
7414
test-A/out-embed-500.tsv
Normal file
File diff suppressed because it is too large
Load Diff
BIN
trigram_model-50_steps-embed_100.bin
Normal file
BIN
trigram_model-50_steps-embed_100.bin
Normal file
Binary file not shown.
BIN
trigram_model-embed_100.bin
Normal file
BIN
trigram_model-embed_100.bin
Normal file
Binary file not shown.
BIN
trigram_model-embed_500.bin
Normal file
BIN
trigram_model-embed_500.bin
Normal file
Binary file not shown.
Loading…
Reference in New Issue
Block a user