from __future__ import unicode_literals, print_function, division
from io import open
import unicodedata
import string
import re
import random
import pickle

import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F

from lang import *

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

MAX_LENGTH = 300


# Turn a Unicode string to plain ASCII, thanks to
# https://stackoverflow.com/a/518232/2809427
def unicodeToAscii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
    )


# Keep only pairs in which both sides are shorter than MAX_LENGTH tokens.
# `tokenizer` is expected to be provided by the `lang` module (star import above).
def filterPair(p):
    return len(tokenizer.tokenize(p[0])) < MAX_LENGTH and \
        len(tokenizer.tokenize(p[1])) < MAX_LENGTH


def filterPairs(pairs):
    return [pair for pair in pairs if filterPair(pair)]


# Insert a space before sentence-final punctuation so it is tokenized separately
def normalizeString(s):
    s = re.sub(r"([.!?])", r" \1", s)
    return s


def readLangs(lang1, lang2, reverse=False):
    print("Reading lines...")

    lines_pl = []
    lines_en = []

    # Read the source (Polish) and target (English) files line by line
    with open('train/in.tsv', 'r', encoding='utf-8') as pl:
        for line in pl:
            lines_pl.append(line.rstrip())

    with open('train/expected.tsv', 'r', encoding='utf-8') as en:
        for line in en:
            lines_en.append(line.rstrip())

    # Normalize every line and build [English, Polish] pairs, i.e. [lang2, lang1]
    pairs = []
    for p, e in zip(lines_pl, lines_en):
        pl_norm = normalizeString(p)
        en_norm = normalizeString(e)
        pairs.append([en_norm, pl_norm])

    # Since the pairs start out as [lang2, lang1], reversing them yields
    # [lang1, lang2]; the Lang objects are named to match the pair order.
    if reverse:
        pairs = [list(reversed(p)) for p in pairs]
        input_lang = Lang(lang1)
        output_lang = Lang(lang2)
    else:
        input_lang = Lang(lang2)
        output_lang = Lang(lang1)

    return input_lang, output_lang, pairs


def prepareData(lang1, lang2, reverse=False):
    input_lang, output_lang, pairs = readLangs(lang1, lang2, reverse)
    print("Read %s sentence pairs" % len(pairs))
    pairs = filterPairs(pairs)
    print("Trimmed to %s sentence pairs" % len(pairs))
    print("Counting words...")
    for pair in pairs:
        input_lang.addSentence(pair[0])
        output_lang.addSentence(pair[1])
    print("Counted words:")
    print(input_lang.name, input_lang.n_words)
    print(output_lang.name, output_lang.n_words)
    return input_lang, output_lang, pairs


input_lang, output_lang, pairs = prepareData('pl', 'eng', True)
print(random.choice(pairs))

# Persist the processed pairs and both vocabularies for the training script
with open('data/pairs.pkl', 'wb+') as p:
    pickle.dump(pairs, p, protocol=pickle.HIGHEST_PROTOCOL)

with open('data/pl_lang.pkl', 'wb+') as p:
    pickle.dump(input_lang, p, protocol=pickle.HIGHEST_PROTOCOL)

with open('data/en_lang.pkl', 'wb+') as p:
    pickle.dump(output_lang, p, protocol=pickle.HIGHEST_PROTOCOL)
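
# NOTE: `Lang` and `tokenizer` come from the star import of `lang` and are not
# defined in this file. The sketch below is only an assumption about the
# interface this script relies on (`Lang(name)`, `addSentence`, `name`,
# `n_words`, `tokenizer.tokenize`), modeled on the standard PyTorch seq2seq
# tutorial; it is not the actual contents of `lang.py`:
#
#     class Lang:
#         def __init__(self, name):
#             self.name = name
#             self.word2index = {}
#             self.word2count = {}
#             self.index2word = {0: "SOS", 1: "EOS"}
#             self.n_words = 2  # count SOS and EOS
#
#         def addSentence(self, sentence):
#             for word in tokenizer.tokenize(sentence):
#                 self.addWord(word)
#
#         def addWord(self, word):
#             if word not in self.word2index:
#                 self.word2index[word] = self.n_words
#                 self.word2count[word] = 1
#                 self.index2word[self.n_words] = word
#                 self.n_words += 1
#             else:
#                 self.word2count[word] += 1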