GRU attention with bpe

This commit is contained in:
SzamanFL 2021-02-11 19:13:05 +01:00
parent ba5d69080f
commit eba47142be
2 changed files with 1648 additions and 0 deletions

448
gru_attention.ipynb Normal file
View File

@ -0,0 +1,448 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "gru_attention.ipynb",
"provenance": [],
"collapsed_sections": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "code",
"metadata": {
"id": "oinaxXuLWqvW"
},
"source": [
"! pip install bpe"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "0x-5p_va6YMa"
},
"source": [
"import torch \r\n",
"import re\r\n",
"import random\r\n",
"import pandas\r\n",
"import numpy\r\n",
"from torch.autograd import Variable\r\n",
"import torch.nn as nn\r\n",
"import time\r\n",
"import math\r\n",
"from torch import optim\r\n",
"import torch.nn.functional as F\r\n",
"from bpe import Encoder\r\n",
"\r\n",
"\r\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\r\n",
"SOS_token = 0\r\n",
"EOS_token = -1"
],
"execution_count": 5,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "_FAYpub6Icy6"
},
"source": [
"\r\n",
"class EncoderRNN(nn.Module):\r\n",
" def __init__(self, input_size, hidden_size):\r\n",
" super(EncoderRNN, self).__init__()\r\n",
" self.hidden_size = hidden_size\r\n",
"\r\n",
" self.embedding = nn.Embedding(input_size, hidden_size)\r\n",
" self.gru = nn.GRU(hidden_size, hidden_size)\r\n",
"\r\n",
" def forward(self, input, hidden):\r\n",
" embedded = self.embedding(input).view(1, 1, -1)\r\n",
" output = embedded\r\n",
" output, hidden = self.gru(output, hidden)\r\n",
" return output, hidden\r\n",
"\r\n",
" def initHidden(self):\r\n",
" return torch.zeros(1, 1, self.hidden_size, device=device)\r\n",
"\r\n",
"class DecoderRNN(nn.Module):\r\n",
" def __init__(self, hidden_size, output_size):\r\n",
" super(DecoderRNN, self).__init__()\r\n",
" self.hidden_size = hidden_size\r\n",
"\r\n",
" self.embedding = nn.Embedding(output_size, hidden_size)\r\n",
" self.gru = nn.GRU(hidden_size, hidden_size)\r\n",
" self.out = nn.Linear(hidden_size, output_size)\r\n",
" self.softmax = nn.LogSoftmax(dim=1)\r\n",
"\r\n",
" def forward(self, input, hidden):\r\n",
" output = self.embedding(input).view(1, 1, -1)\r\n",
" output = F.relu(output)\r\n",
" output, hidden = self.gru(output, hidden)\r\n",
" output = self.softmax(self.out(output[0]))\r\n",
" return output, hidden\r\n",
"\r\n",
" def initHidden(self):\r\n",
" return torch.zeros(1, 1, self.hidden_size, device=device)"
],
"execution_count": 6,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "Rf550t9_ec3g"
},
"source": [
"vocab_size = 1500\r\n",
"bpe_encoder_pl = Encoder(vocab_size=vocab_size, pct_bpe=0.5)\r\n",
"bpe_encoder_en = Encoder(vocab_size=vocab_size, pct_bpe=0.5)"
],
"execution_count": 7,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "uJObUcjOe_SM"
},
"source": [
"MAX_LENGTH = 80\r\n",
"\r\n",
"\r\n",
"def filter_pair(p):\r\n",
" return len(p[0]) < MAX_LENGTH and \\\r\n",
" len(p[1]) < MAX_LENGTH and \\\r\n",
" len(p[0]) > 0 and \\\r\n",
" len(p[1]) > 0\r\n",
"\r\n",
"\r\n",
"def filter_pairs(pairs):\r\n",
" return [pair for pair in pairs if filter_pair(pair)]\r\n",
"\r\n",
"\r\n",
"def normalize_string(s):\r\n",
" s = s.lower().strip()\r\n",
" s = re.sub(r\"([.!?~])\", r\" \\1\", s)\r\n",
" return s\r\n",
"\r\n",
"\r\n",
"def sentence_to_codes(s, bpe_coder):\r\n",
" s = normalize_string(s)\r\n",
" #s += \" ___\"\r\n",
" c = next(bpe_coder.transform([s]))\r\n",
" #c.append(EOS_token)\r\n",
" return c\r\n",
"\r\n",
"\r\n",
"def read_langs(in_f, exp_f, lines=150):\r\n",
" print(\"Reading lines...\")\r\n",
"\r\n",
" # Read the file and split into lines\r\n",
" linesIn = open(in_f).read().strip().split('\\n')[:lines]\r\n",
" linesOut = open(exp_f).read().strip().split('\\n')[:lines]\r\n",
" #for i, (line_in, line_out) in enumerate(zip(linesIn, linesOut)):\r\n",
" # linesIn[i] += normalize_string(line_in) \r\n",
" # linesOut[i] += normalize_string(line_out) + \" ~\"\r\n",
" bpe_encoder_pl.fit(linesIn)\r\n",
" bpe_encoder_en.fit(linesOut)\r\n",
" # Split every line into pairs and normalize\r\n",
" pairs = [[sentence_to_codes(a, bpe_encoder_pl),sentence_to_codes(b, bpe_encoder_en)] for a,b in zip(linesIn,linesOut)]\r\n",
"\r\n",
" pairs = filter_pairs(pairs)\r\n",
" print(\"Pairs created\")\r\n",
" return pairs"
],
"execution_count": 8,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "POCQzFXTmnPx"
},
"source": [
"code_pairs = read_langs('train/in.tsv', 'train/expected.tsv', 2500)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "VwsDASIQCyOz"
},
"source": [
"#code_pairs[0]"
],
"execution_count": 18,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "ZmamapSQw1S5"
},
"source": [
"#bpe_encoder_en.bpe_vocab"
],
"execution_count": 19,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "PVfbcwHenhEB"
},
"source": [
"teacher_forcing_ratio = 0.95\r\n",
"\r\n",
"def train(input_tensor, target_tensor, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion, max_length=MAX_LENGTH):\r\n",
" encoder_hidden = encoder.initHidden()\r\n",
"\r\n",
" encoder_optimizer.zero_grad()\r\n",
" decoder_optimizer.zero_grad()\r\n",
"\r\n",
" input_length = input_tensor.size(0)\r\n",
" target_length = target_tensor.size(0)\r\n",
"\r\n",
" encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)\r\n",
"\r\n",
" loss = 0\r\n",
"\r\n",
" for ei in range(input_length):\r\n",
" encoder_output, encoder_hidden = encoder(\r\n",
" input_tensor[ei], encoder_hidden)\r\n",
" encoder_outputs[ei] = encoder_output[0, 0]\r\n",
"\r\n",
" decoder_input = torch.tensor([[SOS_token]], device=device)\r\n",
"\r\n",
" decoder_hidden = encoder_hidden\r\n",
"\r\n",
" use_teacher_forcing = True if random.random() < teacher_forcing_ratio else False\r\n",
"\r\n",
" last = 500\r\n",
" if use_teacher_forcing:\r\n",
" # Teacher forcing: Feed the target as the next input\r\n",
" for di in range(target_length):\r\n",
" decoder_output, decoder_hidden = decoder(\r\n",
" decoder_input, decoder_hidden)\r\n",
" loss += criterion(decoder_output, target_tensor[di])\r\n",
" decoder_input = target_tensor[di] # Teacher forcing\r\n",
"\r\n",
" else:\r\n",
" # Without teacher forcing: use its own predictions as the next input\r\n",
" for di in range(target_length):\r\n",
" decoder_output, decoder_hidden = decoder(\r\n",
" decoder_input, decoder_hidden)\r\n",
" topv, topi = decoder_output.topk(1)\r\n",
" decoder_input = topi.squeeze().detach() # detach from history as input\r\n",
"\r\n",
" loss += criterion(decoder_output, target_tensor[di])\r\n",
" #if decoder_input.item() == EOS_token:\r\n",
" # break\r\n",
" #print(loss)\r\n",
" try:\r\n",
" loss.backward()\r\n",
" except AttributeError:\r\n",
" print(f\"loss: {loss}\")\r\n",
" print(f\"input_tensor: {input_tensor}\")\r\n",
" print(f\"target_tensor: {target_tensor}\")\r\n",
" encoder_optimizer.step()\r\n",
" decoder_optimizer.step()\r\n",
" \r\n",
"\r\n",
" return loss.item() / target_length"
],
"execution_count": 20,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "nItV2ibhr7HA"
},
"source": [
"def list_to_tensor(l):\r\n",
" return torch.tensor(l, dtype=torch.long, device=device).view(-1, 1)\r\n",
"\r\n",
"def pairs_to_tensor(pair):\r\n",
" in_tensor = list_to_tensor(pair[0])\r\n",
" out_tensor = list_to_tensor(pair[1])\r\n",
" return (in_tensor, out_tensor)\r\n",
"\r\n",
"\r\n",
"def asMinutes(s):\r\n",
" m = math.floor(s / 60)\r\n",
" s -= m * 60\r\n",
" return '%dm %ds' % (m, s)\r\n",
"\r\n",
"\r\n",
"def timeSince(since, percent):\r\n",
" now = time.time()\r\n",
" s = now - since\r\n",
" es = s / (percent)\r\n",
" rs = es - s\r\n",
" return '%s (- %s)' % (asMinutes(s), asMinutes(rs))\r\n",
"\r\n",
"def trainIters(encoder, decoder, n_iters, print_every=1000, learning_rate=0.01):\r\n",
" start = time.time()\r\n",
" plot_losses = []\r\n",
" print_loss_total = 0 # Reset every print_every\r\n",
" plot_loss_total = 0 # Reset every plot_every\r\n",
"\r\n",
" encoder_optimizer = optim.SGD(encoder.parameters(), lr=learning_rate)\r\n",
" decoder_optimizer = optim.SGD(decoder.parameters(), lr=learning_rate)\r\n",
" training_pairs = [pairs_to_tensor(random.choice(code_pairs))\r\n",
" for i in range(n_iters)]\r\n",
" criterion = nn.NLLLoss()\r\n",
"\r\n",
" for iter in range(1, n_iters + 1):\r\n",
" training_pair = training_pairs[iter - 1]\r\n",
" input_tensor = training_pair[0]\r\n",
" target_tensor = training_pair[1]\r\n",
"\r\n",
" loss = train(input_tensor, target_tensor, encoder,\r\n",
" decoder, encoder_optimizer, decoder_optimizer, criterion)\r\n",
" print_loss_total += loss\r\n",
" plot_loss_total += loss\r\n",
"\r\n",
" if iter % print_every == 0:\r\n",
" print_loss_avg = print_loss_total / print_every\r\n",
" print_loss_total = 0\r\n",
" print('%s (%d %d%%) %.4f' % (timeSince(start, iter / n_iters),\r\n",
" iter, iter / n_iters * 100, print_loss_avg))"
],
"execution_count": 21,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "MRuYGa9nzOi9"
},
"source": [
"hidden_size = 256\r\n",
"encoder1 = EncoderRNN(vocab_size, hidden_size).to(device)\r\n",
"decoder1 = DecoderRNN(hidden_size, vocab_size).to(device)"
],
"execution_count": 22,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "LgBYFrTt5Go6"
},
"source": [
"trainIters(encoder1, decoder1, 35000, print_every=5)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "QvoQeb1lzuaB"
},
"source": [
"def evaluate(encoder, decoder, sentence, max_length=MAX_LENGTH):\r\n",
" with torch.no_grad():\r\n",
" #a = sentence_to_codes(sentence, bpe_encoder_pl)\r\n",
" #input_tensor = tensorFromSentence(input_lang, sentence)\r\n",
" input_tensor = list_to_tensor(sentence_to_codes(sentence, bpe_encoder_pl))\r\n",
" input_length = input_tensor.size()[0]\r\n",
" encoder_hidden = encoder.initHidden()\r\n",
"\r\n",
" encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)\r\n",
"\r\n",
" for ei in range(input_length):\r\n",
" encoder_output, encoder_hidden = encoder(input_tensor[ei],\r\n",
" encoder_hidden)\r\n",
" encoder_outputs[ei] += encoder_output[0, 0]\r\n",
"\r\n",
" decoder_input = torch.tensor([[SOS_token]], device=device) # SOS\r\n",
"\r\n",
" decoder_hidden = encoder_hidden\r\n",
"\r\n",
" decoded_words = []\r\n",
" eow_token = 501\r\n",
" last_word = -1\r\n",
" for di in range(max_length):\r\n",
" decoder_output, decoder_hidden = decoder(\r\n",
" decoder_input, decoder_hidden)\r\n",
" topv, topi = decoder_output.data.topk(1)\r\n",
" if topi.item() == last_word and topi.item() == eow_token:\r\n",
" # decoded_words.append('<EOS>')\r\n",
" break\r\n",
" else:\r\n",
" decoded_words.append(topi.item())\r\n",
" last_word = topi.item()\r\n",
"\r\n",
" decoder_input = topi.squeeze().detach()\r\n",
"\r\n",
" decoded_tokens = bpe_encoder_en.inverse_transform([decoded_words])\r\n",
" return decoded_tokens\r\n",
"\r\n",
"def evaluateAndShow(input_sentence):\r\n",
" output_words = evaluate(\r\n",
" encoder1, decoder1, input_sentence)\r\n",
" return next(output_words)\r\n"
],
"execution_count": 14,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "2Y8rj7BhIpBS"
},
"source": [
"\r\n",
"temp = open('test-A/in.tsv', 'r').readlines()\r\n",
"data = []\r\n",
"for sent in temp:\r\n",
" data.append(sent.replace('\\n',''))\r\n",
"\r\n",
"f=open('test-A/out.tsv','w')\r\n",
"for sent in data:\r\n",
" f.write(evaluateAndShow(sent) + '\\n')\r\n",
"\r\n",
"f.close()"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "13Nm_jnCIpnl"
},
"source": [
"temp = open('dev-0/in.tsv', 'r').readlines()\r\n",
"data = []\r\n",
"for sent in temp:\r\n",
" data.append(sent.replace('\\n',''))\r\n",
"\r\n",
"f=open('dev-0/out.tsv','w')\r\n",
"for sent in data:\r\n",
" f.write(evaluateAndShow(sent) + '\\n')\r\n",
"\r\n",
"f.close()"
],
"execution_count": null,
"outputs": []
}
]
}

1200
test-A/out.tsv Normal file

File diff suppressed because it is too large Load Diff