# Data preparation script: reads the parallel pl/en training corpus,
# normalizes sentence pairs, builds Lang vocabularies, and pickles the results.
from __future__ import unicode_literals, print_function, division

import random
import re
import string
import unicodedata
from io import open

import pickle

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim

from lang import *
# Select the compute device once at import time: CUDA when available, else CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Sentences with this many tokens or more (per side) are dropped by filterPair.
MAX_LENGTH = 300
# Turn a Unicode string to plain ASCII, thanks to
# https://stackoverflow.com/a/518232/2809427
def unicodeToAscii(s):
    """Return *s* with every combining mark (Unicode category 'Mn') removed.

    NFD decomposition splits accented characters into a base letter plus
    combining marks; dropping the marks leaves the plain base letters.
    """
    decomposed = unicodedata.normalize('NFD', s)
    base_chars = [ch for ch in decomposed if unicodedata.category(ch) != 'Mn']
    return ''.join(base_chars)
# Length filtering: keep only pairs whose sides tokenize to fewer than
# MAX_LENGTH tokens each.
def filterPair(p):
    """Return True when both sides of pair *p* tokenize to fewer than
    MAX_LENGTH tokens.

    NOTE(review): relies on a module-level ``tokenizer`` (presumably exported
    by ``lang``'s star import) — confirm it is in scope before this runs.
    """
    # Generator keeps the original short-circuit: the second side is only
    # tokenized when the first one passes.
    return all(len(tokenizer.tokenize(sentence)) < MAX_LENGTH
               for sentence in (p[0], p[1]))
def filterPairs(pairs):
    """Return a new list containing only the pairs accepted by filterPair."""
    kept = []
    for candidate in pairs:
        if filterPair(candidate):
            kept.append(candidate)
    return kept
def normalizeString(s):
    """Insert a space before each sentence punctuation mark ('.', '!', '?')."""
    return re.sub(r"([.!?])", r" \1", s)
def readLangs(lang1, lang2, reverse=False):
    """Read the parallel train corpus and build Lang holders plus sentence pairs.

    Reads 'train/in.tsv' and 'train/expected.tsv' line by line, normalizes
    each sentence, and zips them into pairs (truncated to the shorter file).
    With reverse=True each pair is flipped and the Lang labels are swapped.

    Returns (input_lang, output_lang, pairs).

    NOTE(review): pairs are stored as [expected-side, in-side] before the
    optional flip — confirm this ordering matches how input_lang/output_lang
    are consumed downstream.
    """
    print("Reading lines...")

    def _read_lines(path):
        # One right-stripped line per corpus sentence.
        with open(path, 'r', encoding='utf-8') as handle:
            return [row.rstrip() for row in handle]

    lines_pl = _read_lines('train/in.tsv')
    lines_en = _read_lines('train/expected.tsv')

    # Normalize and pair up: [expected-side sentence, in-side sentence].
    pairs = [[normalizeString(expected), normalizeString(source)]
             for source, expected in zip(lines_pl, lines_en)]

    if reverse:
        pairs = [pair[::-1] for pair in pairs]
        input_lang, output_lang = Lang(lang2), Lang(lang1)
    else:
        input_lang, output_lang = Lang(lang1), Lang(lang2)

    return input_lang, output_lang, pairs
def prepareData(lang1, lang2, reverse=False):
    """Load the corpus, filter pairs by length, and populate both vocabularies.

    Returns (input_lang, output_lang, pairs) with every surviving sentence
    added to its Lang's word counts.
    """
    input_lang, output_lang, pairs = readLangs(lang1, lang2, reverse)
    print("Read %s sentence pairs" % len(pairs))

    pairs = filterPairs(pairs)
    print("Trimmed to %s sentence pairs" % len(pairs))

    print("Counting words...")
    for input_sentence, output_sentence in pairs:
        input_lang.addSentence(input_sentence)
        output_lang.addSentence(output_sentence)

    print("Counted words:")
    print(input_lang.name, input_lang.n_words)
    print(output_lang.name, output_lang.n_words)
    return input_lang, output_lang, pairs
# Build the dataset at import time and show one random sample as a sanity check.
input_lang, output_lang, pairs = prepareData('pl', 'eng', True)
print(random.choice(pairs))

# Persist the processed pairs and both vocabularies for the training script.
for _path, _obj in (('data/pairs.pkl', pairs),
                    ('data/pl_lang.pkl', input_lang),
                    ('data/en_lang.pkl', output_lang)):
    with open(_path, 'wb+') as p:
        pickle.dump(_obj, p, protocol=pickle.HIGHEST_PROTOCOL)