Compare commits
2 Commits
Author | SHA1 | Date
---|---|---
 | 21dad8fc72 |
 | 2e7c5b13c0 |
2 | .gitignore (vendored) | Normal file
@@ -0,0 +1,2 @@
*.dict
data
1600 | dev-0/out.tsv | File diff suppressed because it is too large
24 | lang.py | Normal file
@@ -0,0 +1,24 @@
from nltk.tokenize import RegexpTokenizer

SOS_token = 2
PAD_token = 0


class Lang:
    """Vocabulary wrapper: maps words to indexes and back."""

    def __init__(self, name):
        self.name = name
        self.word2index = {}
        self.word2count = {}
        self.index2word = {0: "PAD", 1: "UNK", 2: "SOS"}
        # Start counting after the three reserved entries; the original value of 2
        # would let the first real word overwrite the SOS slot.
        self.n_words = 3

    def addSentence(self, sentence):
        for word in sentence.split():
            self.addWord(word)

    def addWord(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.n_words
            self.word2count[word] = 1
            self.index2word[self.n_words] = word
            self.n_words += 1
        else:
            self.word2count[word] += 1
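A minimal usage sketch of the Lang class (illustrative only, not part of the commit; the indexes follow the reserved PAD/UNK/SOS slots above):

lang = Lang("en")
lang.addSentence("hello world hello")
print(lang.word2index["hello"])   # 3, the first index after PAD/UNK/SOS
print(lang.word2count["hello"])   # 2
print(lang.index2word[4])         # "world"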
38 | lstm_model.py | Normal file
@@ -0,0 +1,38 @@
import torch
from torch import nn
import torch.nn.functional as F

device = 'cuda'


class EncoderRNN(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(EncoderRNN, self).__init__()
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(input_size, hidden_size)
        self.lstm = nn.LSTM(hidden_size, hidden_size)

    def forward(self, input, hidden):
        embedded = self.embedding(input)
        output = embedded
        output, hidden = self.lstm(output, hidden)
        return output, hidden

    def initHidden(self):
        # LSTM state is a (hidden, cell) tuple, each of shape (1, 1, hidden_size).
        return (torch.zeros(1, 1, self.hidden_size, device=device),
                torch.zeros(1, 1, self.hidden_size, device=device))


class DecoderRNN(nn.Module):
    def __init__(self, hidden_size, output_size):
        super(DecoderRNN, self).__init__()
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(output_size, hidden_size)
        self.lstm = nn.LSTM(hidden_size, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        output = self.embedding(input)
        output = F.relu(output)
        output, hidden = self.lstm(output, hidden)
        # output[0] takes the single time step, then projects to vocabulary log-probabilities.
        output = self.softmax(self.out(output[0]))
        return output, hidden
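A shape sketch for these two modules (illustrative only; the vocabulary sizes 1000 and 1200 are made up, and a CUDA device is assumed because the module hardcodes device = 'cuda'):

import torch
from lstm_model import EncoderRNN, DecoderRNN

enc = EncoderRNN(input_size=1000, hidden_size=256).to('cuda')
dec = DecoderRNN(hidden_size=256, output_size=1200).to('cuda')

src = torch.randint(0, 1000, (5, 1), device='cuda')   # (seq_len, batch=1) of token ids
enc_out, enc_hidden = enc(src, enc.initHidden())       # enc_out: (5, 1, 256)

step_in = torch.tensor([[2]], device='cuda')           # SOS_token
logits, dec_hidden = dec(step_in, enc_hidden)          # logits: (1, 1200) log-probabilities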
130 | model_train.py | Normal file
@@ -0,0 +1,130 @@
from lang import SOS_token
import torch
import random
import math
import time
from torch import nn, optim
from torch.nn.utils.rnn import pad_sequence
import pickle


MAX_LENGTH = 25
device = 'cuda'
teacher_forcing_ratio = 0.8


with open('data/pairs.pkl', 'rb') as input_file:
    pairs = pickle.load(input_file)

with open('data/pl_lang.pkl', 'rb') as input_file:
    input_lang = pickle.load(input_file)

with open('data/en_lang.pkl', 'rb') as out_file:
    output_lang = pickle.load(out_file)


def indexesFromSentence(lang, sentence):
    # `sentence` is expected to be a list of tokens; unknown words map to UNK (index 1).
    return [lang.word2index[word] if word in lang.word2index else 1 for word in sentence]


def tensorFromSentence(lang, sentence):
    indexes = indexesFromSentence(lang, sentence)
    indexes.append(0)  # terminate with the PAD index, used as end-of-sequence
    out = torch.tensor(indexes, device=device).view(-1, 1)
    return out


def tensorsFromPair(pair):
    input_tensor = tensorFromSentence(input_lang, pair[0])
    target_tensor = tensorFromSentence(output_lang, pair[1])
    return (input_tensor, target_tensor)


def asMinutes(s):
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)


def timeSince(since, percent):
    now = time.time()
    s = now - since
    es = s / percent
    rs = es - s
    return '%s (- %s)' % (asMinutes(s), asMinutes(rs))


def train(input_tensor, target_tensor, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion, max_length=MAX_LENGTH):
    encoder_hidden = encoder.initHidden()
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    input_length = input_tensor.size(0)
    target_length = target_tensor.size(0)

    encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)
    loss = 0
    encoder_output, encoder_hidden = encoder(input_tensor, encoder_hidden)

    decoder_input = torch.tensor([[SOS_token]], device=device)
    decoder_hidden = encoder_hidden
    use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False

    if use_teacher_forcing:
        # Teacher forcing: feed the ground-truth token as the next decoder input.
        for di in range(target_length):
            decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
            loss += criterion(decoder_output, target_tensor[di])
            decoder_input = target_tensor[di].unsqueeze(0)

    else:
        # No teacher forcing: feed back the decoder's own prediction.
        for di in range(target_length):
            decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
            topv, topi = decoder_output.topk(1)
            decoder_input = topi.transpose(0, 1).detach()
            loss += criterion(decoder_output, target_tensor[di])

    loss.backward()

    encoder_optimizer.step()
    decoder_optimizer.step()

    return loss.item() / target_length


def trainIters(encoder, decoder, n_iters, print_every=10, plot_every=100, learning_rate=0.01):
    start = time.time()
    print_loss_total = 0  # Reset every print_every
    plot_loss_total = 0  # Reset every plot_every

    encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
    decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate)
    criterion = nn.NLLLoss()
    pairs_in = pairs[:10000]
    for iter in range(1, n_iters + 1):
        try:
            for idx, training_pair in enumerate(pairs_in):
                input_ = training_pair[0]
                target_ = training_pair[1]
                input_ = input_.split()
                input_ = input_[::-1]  # reverse the source sentence
                target_ = target_.split()

                if len(input_) > 1 and len(target_) > 1:
                    input_tensor = tensorFromSentence(input_lang, input_)
                    target_tensor = tensorFromSentence(output_lang, target_)

                    loss = train(input_tensor, target_tensor, encoder,
                                 decoder, encoder_optimizer, decoder_optimizer, criterion)
                    print_loss_total += loss
                    plot_loss_total += loss
                    print(idx / len(pairs_in), end='\r')

            if iter % print_every == 0:
                print_loss_avg = print_loss_total / print_every
                print_loss_total = 0
                print('%s (%d %d%%) %.4f' % (timeSince(start, iter / n_iters),
                                             iter, iter / n_iters * 100, print_loss_avg))
        except KeyboardInterrupt:
            # Save a checkpoint if training is interrupted, then continue with the next pass.
            torch.save(encoder.state_dict(), 'encoder.dict')
            torch.save(decoder.state_dict(), 'decoder.dict')

    # Save the final weights once all iterations are done.
    torch.save(encoder.state_dict(), 'encoder.dict')
    torch.save(decoder.state_dict(), 'decoder.dict')
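For reference, a sketch of how one pair flows through the helpers above (run in the context of this module; the sentence pair is invented and the module's CUDA device is assumed):

pair = ("ala ma kota .", "ala has a cat .")
src_tokens = pair[0].split()[::-1]   # reversed source, exactly as trainIters does
tgt_tokens = pair[1].split()
src = tensorFromSentence(input_lang, src_tokens)   # shape (len+1, 1), ends with index 0 (PAD/EOS)
tgt = tensorFromSentence(output_lang, tgt_tokens)
# loss = train(src, tgt, encoder, decoder, encoder_optimizer, decoder_optimizer, nn.NLLLoss())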
56 | predict.py | Normal file
@@ -0,0 +1,56 @@
from model_train import tensorFromSentence, SOS_token, MAX_LENGTH, device
import pickle
from lstm_model import EncoderRNN, DecoderRNN
import sys
import torch


with open('data/pl_lang.pkl', 'rb') as input_file:
    input_lang = pickle.load(input_file)

with open('data/en_lang.pkl', 'rb') as out_file:
    output_lang = pickle.load(out_file)


def evaluate(encoder, decoder, sentence, max_length=MAX_LENGTH):
    with torch.no_grad():
        input_tensor = tensorFromSentence(input_lang, sentence)
        encoder_hidden = encoder.initHidden()

        encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)

        encoder_output, encoder_hidden = encoder(input_tensor, encoder_hidden)

        encoder_outputs = encoder_output

        decoder_input = torch.tensor([[SOS_token]], device=device)
        decoder_hidden = encoder_hidden

        decoded_words = []

        for di in range(max_length):
            decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden)
            topv, topi = decoder_output.topk(1)
            decoded_words.append(topi)
            decoder_input = topi.transpose(0, 1)
        out = torch.stack(decoded_words)
        return out


hidden_size = 256
encoder = EncoderRNN(input_lang.n_words, hidden_size).to(device)
decoder = DecoderRNN(hidden_size, output_lang.n_words).to(device)
encoder.load_state_dict(torch.load('encoder.dict'))
decoder.load_state_dict(torch.load('decoder.dict'))
encoder.eval()
decoder.eval()
for line in sys.stdin:
    # Tokenize and reverse the source line so it matches the preprocessing used in trainIters;
    # passing the raw string would make indexesFromSentence iterate over characters.
    tokens = line.rstrip().split()[::-1]
    dec_words = evaluate(encoder, decoder, tokens, MAX_LENGTH)
    dec_words = dec_words.transpose(0, 1)
    for sen in dec_words:
        out = []
        for idx in sen:
            if idx == 0:
                break
            out.append(output_lang.index2word[idx.item()])
        print(' '.join(out))
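predict.py is written as a stdin-to-stdout filter: it loads encoder.dict and decoder.dict, reads one source sentence per input line, and prints one translated line, stopping at the first PAD/EOS index. Presumably it was piped over the dev-0 and test-A inputs to produce the out.tsv files added in this commit, although the exact invocation is not shown here.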
90 | prepare.py | Normal file
@@ -0,0 +1,90 @@
from __future__ import unicode_literals, print_function, division
from io import open
import unicodedata
import string
import re
import random
import pickle
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from nltk.tokenize import RegexpTokenizer
from lang import *
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MAX_LENGTH = 25

# Word-level tokenizer used by filterPair below; the regex pattern is an assumption,
# since the commit only imports RegexpTokenizer without instantiating it anywhere.
tokenizer = RegexpTokenizer(r"\w+")


# Turn a Unicode string to plain ASCII, thanks to
# https://stackoverflow.com/a/518232/2809427
def unicodeToAscii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
    )


# Keep only pairs whose two sides are both shorter than MAX_LENGTH tokens
def filterPair(p):
    return len(tokenizer.tokenize(p[0])) < MAX_LENGTH and \
        len(tokenizer.tokenize(p[1])) < MAX_LENGTH


def filterPairs(pairs):
    return [pair for pair in pairs if filterPair(pair)]


# Separate sentence-final punctuation from the preceding word
def normalizeString(s):
    s = re.sub(r"([.!?])", r" \1", s)
    return s


def readLangs(lang1, lang2, reverse=False):
    print("Reading lines...")
    lines_pl = []
    lines_en = []
    # Read the two files and split into lines
    with open('train/in.tsv', 'r', encoding='utf-8') as pl:
        for line in pl:
            line = line.rstrip()
            lines_pl.append(line)
    with open('train/expected.tsv', 'r', encoding='utf-8') as en:
        for line in en:
            line = line.rstrip()
            lines_en.append(line)

    # Combine the two files into pairs and normalize
    pairs = []
    for p, e in zip(lines_pl, lines_en):
        pl_s = normalizeString(p)
        pl_e = normalizeString(e)
        pairs.append([pl_e, pl_s])
    if reverse:
        pairs = [list(reversed(p)) for p in pairs]
        input_lang = Lang(lang2)
        output_lang = Lang(lang1)
    else:
        input_lang = Lang(lang1)
        output_lang = Lang(lang2)

    return input_lang, output_lang, pairs


def prepareData(lang1, lang2, reverse=False):
    input_lang, output_lang, pairs = readLangs(lang1, lang2, reverse)
    print("Read %s sentence pairs" % len(pairs))
    pairs = filterPairs(pairs)
    print("Trimmed to %s sentence pairs" % len(pairs))
    print("Counting words...")
    for pair in pairs:
        input_lang.addSentence(pair[0])
        output_lang.addSentence(pair[1])
    print("Counted words:")
    print(input_lang.name, input_lang.n_words)
    print(output_lang.name, output_lang.n_words)
    return input_lang, output_lang, pairs


input_lang, output_lang, pairs = prepareData('pl', 'eng', True)
print(random.choice(pairs))
with open('data/pairs.pkl', 'wb+') as p:
    pickle.dump(pairs, p, protocol=pickle.HIGHEST_PROTOCOL)
with open('data/pl_lang.pkl', 'wb+') as p:
    pickle.dump(input_lang, p, protocol=pickle.HIGHEST_PROTOCOL)
with open('data/en_lang.pkl', 'wb+') as p:
    pickle.dump(output_lang, p, protocol=pickle.HIGHEST_PROTOCOL)
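A quick check of the preprocessing helpers (illustrative; the sample sentence is invented):

s = normalizeString("To jest kot.")
print(s)                                    # "To jest kot ."
print(filterPair([s, "This is a cat ."]))   # True, both sides have fewer than MAX_LENGTH tokens

Taken together, the pipeline appears to be: prepare.py pickles the vocabularies and sentence pairs under data/, train.py fits the encoder/decoder via model_train.trainIters and saves encoder.dict and decoder.dict, and predict.py loads those weights for inference.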
(deleted file)
@@ -1,30 +0,0 @@
# -*- coding: utf-8 -*-
from transformers import MarianTokenizer, MarianMTModel
import sys
from typing import List
from numba import jit


@jit
def count():
    data = {}
    for doc_id, line in enumerate(sys.stdin):
        data[doc_id] = line.rstrip()
    return data


def translate(data):
    for key in data.keys():
        batch = tok.prepare_seq2seq_batch(src_texts=[data[key]])
        gen = model.generate(**batch)
        translate = tok.batch_decode(gen, skip_special_tokens=True)
        print(translate[0])


if __name__ == "__main__":
    src = 'pl'  # source language
    trg = 'en'  # target language
    mname = f'Helsinki-NLP/opus-mt-{src}-{trg}'

    # print('Data ready!')
    model = MarianMTModel.from_pretrained(mname)
    tok = MarianTokenizer.from_pretrained(mname)
    data = count()
    translate(data)
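The deleted script above was a pretrained baseline: it read source lines from stdin and translated them with the Helsinki-NLP/opus-mt-pl-en MarianMT model from the transformers library. In this compare view it is removed in favour of the hand-written LSTM encoder/decoder files added above.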
3600 | test-A/out.tsv | File diff suppressed because it is too large
7 | train.py | Normal file
@@ -0,0 +1,7 @@
from lstm_model import EncoderRNN, DecoderRNN
from model_train import *

hidden_size = 256
encoder1 = EncoderRNN(input_lang.n_words, hidden_size).to(device)
attn_decoder1 = DecoderRNN(hidden_size, output_lang.n_words).to(device)

trainIters(encoder1, attn_decoder1, 5, print_every=1)
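As configured, train.py runs five passes over the first 10000 training pairs (pairs_in = pairs[:10000] in trainIters) and reports the average loss once per pass; despite its name, attn_decoder1 is the plain DecoderRNN from lstm_model.py, which has no attention mechanism.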
632600 | train/expected.tsv | Normal file (diff suppressed because it is too large)
632600 | train/in.tsv | Normal file (diff suppressed because it is too large)