wmt-2017-cs-en/prepare.py
2021-01-27 04:01:04 +01:00

91 lines
2.8 KiB
Python

from __future__ import unicode_literals, print_function, division
from io import open
import unicodedata
import string
import re
import random
import pickle
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from lang import *
# Torch device: CUDA when torch reports it as available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Exclusive upper bound on the token count of each sentence in a pair;
# filterPair rejects a pair when either side reaches this many tokens.
MAX_LENGTH = 300
# Turn a Unicode string to plain ASCII, thanks to
# https://stackoverflow.com/a/518232/2809427
def unicodeToAscii(s):
    """Return ``s`` with combining marks removed.

    The string is decomposed with NFD normalization and every character whose
    Unicode category is ``Mn`` is dropped. Characters that have no
    decomposition are kept unchanged, so the result is not guaranteed to be
    pure ASCII despite the function's name.
    """
    decomposed = unicodedata.normalize('NFD', s)
    kept = [ch for ch in decomposed if unicodedata.category(ch) != 'Mn']
    return ''.join(kept)
# Keep only pairs whose two sentences both tokenize to fewer than MAX_LENGTH tokens
def filterPair(p):
    """Return True when both sentences of pair ``p`` are short enough.

    ``p[0]`` and ``p[1]`` are tokenized with ``tokenizer.tokenize`` and each
    token count must be strictly below ``MAX_LENGTH``. The second sentence is
    only tokenized when the first one passes.

    NOTE(review): ``tokenizer`` is not defined in this file; presumably it is
    provided by ``from lang import *`` -- confirm.
    """
    first_len = len(tokenizer.tokenize(p[0]))
    if first_len >= MAX_LENGTH:
        return False
    second_len = len(tokenizer.tokenize(p[1]))
    return second_len < MAX_LENGTH
def filterPairs(pairs):
    """Return a new list holding the pairs accepted by ``filterPair``, in order."""
    return list(filter(filterPair, pairs))
def normalizeString(s):
    """Return ``s`` with a space inserted before every '.', '!' and '?'.

    No other character is changed; in particular the text is neither
    lowercased nor stripped.
    """
    punctuation = r"([.!?])"
    spaced = r" \1"
    return re.sub(punctuation, spaced, s)
def _read_stripped_lines(path):
    """Return the lines of the UTF-8 text file ``path`` without trailing whitespace."""
    with open(path, 'r', encoding='utf-8') as handle:
        return [line.rstrip() for line in handle]


def readLangs(lang1, lang2, reverse=False,
              in_path='train/in.tsv', expected_path='train/expected.tsv'):
    """Read the parallel corpus and build sentence pairs plus two empty Lang objects.

    Parameters:
        lang1: name given to the Lang for the first element of each pair
            before any reversal (the ``expected_path`` sentence).
        lang2: name given to the Lang for the second element of each pair
            before any reversal (the ``in_path`` sentence).
        reverse: when true, every pair is flipped to ``[in, expected]`` and
            the input/output Lang names are swapped accordingly.
        in_path: file with one sentence per line.
        expected_path: file with the line-aligned counterpart sentences.

    Returns:
        ``(input_lang, output_lang, pairs)`` where ``pairs`` is a list of
        two-element lists of normalized sentences. The Lang objects are only
        constructed here; no sentences are added to them.

    If the two files have a different number of lines, a warning is printed
    and the surplus lines of the longer file are ignored.
    """
    print("Reading lines...")

    in_lines = _read_stripped_lines(in_path)
    expected_lines = _read_stripped_lines(expected_path)

    # zip() stops at the shorter input, which would otherwise hide a
    # truncated or misaligned corpus without any indication.
    if len(in_lines) != len(expected_lines):
        print("Warning: %s has %d lines but %s has %d lines; "
              "extra lines are ignored"
              % (in_path, len(in_lines), expected_path, len(expected_lines)))

    # Each pair is [expected sentence, in sentence] until reversed below.
    pairs = [
        [normalizeString(expected_line), normalizeString(in_line)]
        for in_line, expected_line in zip(in_lines, expected_lines)
    ]

    if reverse:
        pairs = [list(reversed(pair)) for pair in pairs]
        input_lang = Lang(lang2)
        output_lang = Lang(lang1)
    else:
        input_lang = Lang(lang1)
        output_lang = Lang(lang2)
    return input_lang, output_lang, pairs
def prepareData(lang1, lang2, reverse=False):
    """Load the sentence pairs, drop over-long ones and count their words.

    The pairs and the two Lang objects come from ``readLangs``; the pairs are
    then filtered with ``filterPairs``. For every remaining pair the first
    sentence is added to the input Lang and the second to the output Lang.
    Progress and the final word counts are printed.

    Returns ``(input_lang, output_lang, pairs)`` with the filtered pairs.
    """
    src_vocab, tgt_vocab, sentence_pairs = readLangs(lang1, lang2, reverse)
    print("Read %s sentence pairs" % len(sentence_pairs))

    sentence_pairs = filterPairs(sentence_pairs)
    print("Trimmed to %s sentence pairs" % len(sentence_pairs))

    print("Counting words...")
    for entry in sentence_pairs:
        src_vocab.addSentence(entry[0])
        tgt_vocab.addSentence(entry[1])

    print("Counted words:")
    for vocab in (src_vocab, tgt_vocab):
        print(vocab.name, vocab.n_words)
    return src_vocab, tgt_vocab, sentence_pairs
# Build the vocabularies and filtered pairs, then pickle them under data/.
#
# readLangs builds each pair as [expected.tsv sentence, in.tsv sentence] and,
# with reverse=True, flips it to [in.tsv, expected.tsv] while naming the input
# Lang after its second argument. The labels are therefore passed as
# ('eng', 'pl') so that the Lang fed with the in.tsv sentences is the one
# named 'pl' and saved as pl_lang.pkl; with ('pl', 'eng') the two names were
# attached to the opposite vocabularies. The pair order itself is unchanged.
input_lang, output_lang, pairs = prepareData('eng', 'pl', True)
# Print one random pair as a quick sanity check of the result.
print(random.choice(pairs))
with open('data/pairs.pkl', 'wb+') as p:
    pickle.dump(pairs, p, protocol=pickle.HIGHEST_PROTOCOL)
with open('data/pl_lang.pkl', 'wb+') as p:
    pickle.dump(input_lang, p, protocol=pickle.HIGHEST_PROTOCOL)
with open('data/en_lang.pkl', 'wb+') as p:
    pickle.dump(output_lang, p, protocol=pickle.HIGHEST_PROTOCOL)