GRU attention with bpe
This commit is contained in:
parent ba5d69080f
commit eba47142be
448
gru_attention.ipynb
Normal file
@@ -0,0 +1,448 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "gru_attention.ipynb",
"provenance": [],
"collapsed_sections": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "code",
"metadata": {
"id": "oinaxXuLWqvW"
},
"source": [
"! pip install bpe"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "0x-5p_va6YMa"
},
"source": [
"import torch\r\n",
"import re\r\n",
"import random\r\n",
"import pandas\r\n",
"import numpy\r\n",
"from torch.autograd import Variable\r\n",
"import torch.nn as nn\r\n",
"import time\r\n",
"import math\r\n",
"from torch import optim\r\n",
"import torch.nn.functional as F\r\n",
"from bpe import Encoder\r\n",
"\r\n",
"\r\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\r\n",
"SOS_token = 0\r\n",
"EOS_token = -1  # unused: the EOS handling below is commented out, and -1 would not be a valid embedding index"
],
"execution_count": 5,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "_FAYpub6Icy6"
},
"source": [
"class EncoderRNN(nn.Module):\r\n",
"    def __init__(self, input_size, hidden_size):\r\n",
"        super(EncoderRNN, self).__init__()\r\n",
"        self.hidden_size = hidden_size\r\n",
"\r\n",
"        self.embedding = nn.Embedding(input_size, hidden_size)\r\n",
"        self.gru = nn.GRU(hidden_size, hidden_size)\r\n",
"\r\n",
"    def forward(self, input, hidden):\r\n",
"        embedded = self.embedding(input).view(1, 1, -1)\r\n",
"        output = embedded\r\n",
"        output, hidden = self.gru(output, hidden)\r\n",
"        return output, hidden\r\n",
"\r\n",
"    def initHidden(self):\r\n",
"        return torch.zeros(1, 1, self.hidden_size, device=device)\r\n",
"\r\n",
"\r\n",
"class DecoderRNN(nn.Module):\r\n",
"    def __init__(self, hidden_size, output_size):\r\n",
"        super(DecoderRNN, self).__init__()\r\n",
"        self.hidden_size = hidden_size\r\n",
"\r\n",
"        self.embedding = nn.Embedding(output_size, hidden_size)\r\n",
"        self.gru = nn.GRU(hidden_size, hidden_size)\r\n",
"        self.out = nn.Linear(hidden_size, output_size)\r\n",
"        self.softmax = nn.LogSoftmax(dim=1)\r\n",
"\r\n",
"    def forward(self, input, hidden):\r\n",
"        output = self.embedding(input).view(1, 1, -1)\r\n",
"        output = F.relu(output)\r\n",
"        output, hidden = self.gru(output, hidden)\r\n",
"        output = self.softmax(self.out(output[0]))\r\n",
"        return output, hidden\r\n",
"\r\n",
"    def initHidden(self):\r\n",
"        return torch.zeros(1, 1, self.hidden_size, device=device)"
],
"execution_count": 6,
"outputs": []
},
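{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note: the notebook title mentions attention, but `DecoderRNN` above never looks at the encoder outputs that `train` collects. The next cell is a minimal sketch of an additive-attention decoder (a hypothetical drop-in following the standard pattern, not part of the original pipeline; the name `AttnDecoderRNN` is an assumption). Using it would also require passing `encoder_outputs` into the decoder calls in `train` and `evaluate`."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"class AttnDecoderRNN(nn.Module):\r\n",
"    # Hypothetical sketch: attends over encoder_outputs instead of ignoring them.\r\n",
"    def __init__(self, hidden_size, output_size, max_length=80):\r\n",
"        super(AttnDecoderRNN, self).__init__()\r\n",
"        self.hidden_size = hidden_size\r\n",
"        self.embedding = nn.Embedding(output_size, hidden_size)\r\n",
"        self.attn = nn.Linear(hidden_size * 2, max_length)  # one score per source position\r\n",
"        self.attn_combine = nn.Linear(hidden_size * 2, hidden_size)\r\n",
"        self.gru = nn.GRU(hidden_size, hidden_size)\r\n",
"        self.out = nn.Linear(hidden_size, output_size)\r\n",
"\r\n",
"    def forward(self, input, hidden, encoder_outputs):\r\n",
"        embedded = self.embedding(input).view(1, 1, -1)\r\n",
"        # attention weights computed from the current embedding and hidden state\r\n",
"        attn_weights = F.softmax(\r\n",
"            self.attn(torch.cat((embedded[0], hidden[0]), 1)), dim=1)\r\n",
"        # weighted sum of encoder outputs -> context vector\r\n",
"        attn_applied = torch.bmm(attn_weights.unsqueeze(0),\r\n",
"                                 encoder_outputs.unsqueeze(0))\r\n",
"        output = torch.cat((embedded[0], attn_applied[0]), 1)\r\n",
"        output = self.attn_combine(output).unsqueeze(0)\r\n",
"        output = F.relu(output)\r\n",
"        output, hidden = self.gru(output, hidden)\r\n",
"        output = F.log_softmax(self.out(output[0]), dim=1)\r\n",
"        return output, hidden, attn_weights\r\n",
"\r\n",
"    def initHidden(self):\r\n",
"        return torch.zeros(1, 1, self.hidden_size, device=device)"
],
"execution_count": null,
"outputs": []
},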
{
"cell_type": "code",
"metadata": {
"id": "Rf550t9_ec3g"
},
"source": [
"vocab_size = 1500\r\n",
"bpe_encoder_pl = Encoder(vocab_size=vocab_size, pct_bpe=0.5)\r\n",
"bpe_encoder_en = Encoder(vocab_size=vocab_size, pct_bpe=0.5)"
],
"execution_count": 7,
"outputs": []
},
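{
"cell_type": "markdown",
"metadata": {},
"source": [
"Quick sanity check of the `bpe` Encoder API used above, on a throwaway toy instance (the sentences are made-up placeholders; the real encoders are only fitted later, inside `read_langs`): `fit` learns the vocabulary, `transform` yields one list of token ids per sentence, and `inverse_transform` maps ids back to text."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"toy = Encoder(vocab_size=100, pct_bpe=0.5)\r\n",
"toy.fit([\"ala ma kota\", \"kot ma ale\"])  # learn word + subword vocab from a toy corpus\r\n",
"ids = next(toy.transform([\"ala ma kota\"]))  # transform returns a generator of id lists\r\n",
"print(ids)\r\n",
"print(next(toy.inverse_transform([ids])))  # back to (normalized) text"
],
"execution_count": null,
"outputs": []
},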
{
"cell_type": "code",
"metadata": {
"id": "uJObUcjOe_SM"
},
"source": [
"MAX_LENGTH = 80\r\n",
"\r\n",
"\r\n",
"def filter_pair(p):\r\n",
"    return len(p[0]) < MAX_LENGTH and \\\r\n",
"        len(p[1]) < MAX_LENGTH and \\\r\n",
"        len(p[0]) > 0 and \\\r\n",
"        len(p[1]) > 0\r\n",
"\r\n",
"\r\n",
"def filter_pairs(pairs):\r\n",
"    return [pair for pair in pairs if filter_pair(pair)]\r\n",
"\r\n",
"\r\n",
"def normalize_string(s):\r\n",
"    s = s.lower().strip()\r\n",
"    s = re.sub(r\"([.!?~])\", r\" \\1\", s)\r\n",
"    return s\r\n",
"\r\n",
"\r\n",
"def sentence_to_codes(s, bpe_coder):\r\n",
"    s = normalize_string(s)\r\n",
"    #s += \" ___\"\r\n",
"    c = next(bpe_coder.transform([s]))\r\n",
"    #c.append(EOS_token)\r\n",
"    return c\r\n",
"\r\n",
"\r\n",
"def read_langs(in_f, exp_f, lines=150):\r\n",
"    print(\"Reading lines...\")\r\n",
"\r\n",
"    # Read the files and split into lines\r\n",
"    linesIn = open(in_f).read().strip().split('\\n')[:lines]\r\n",
"    linesOut = open(exp_f).read().strip().split('\\n')[:lines]\r\n",
"    #for i, (line_in, line_out) in enumerate(zip(linesIn, linesOut)):\r\n",
"    #    linesIn[i] += normalize_string(line_in)\r\n",
"    #    linesOut[i] += normalize_string(line_out) + \" ~\"\r\n",
"    bpe_encoder_pl.fit(linesIn)\r\n",
"    bpe_encoder_en.fit(linesOut)\r\n",
"    # Encode every (source, target) line pair as BPE code lists\r\n",
"    pairs = [[sentence_to_codes(a, bpe_encoder_pl), sentence_to_codes(b, bpe_encoder_en)] for a, b in zip(linesIn, linesOut)]\r\n",
"\r\n",
"    pairs = filter_pairs(pairs)\r\n",
"    print(\"Pairs created\")\r\n",
"    return pairs"
],
"execution_count": 8,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "POCQzFXTmnPx"
},
"source": [
"code_pairs = read_langs('train/in.tsv', 'train/expected.tsv', 2500)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "VwsDASIQCyOz"
},
"source": [
"#code_pairs[0]"
],
"execution_count": 18,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "ZmamapSQw1S5"
},
"source": [
"#bpe_encoder_en.bpe_vocab"
],
"execution_count": 19,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "PVfbcwHenhEB"
},
"source": [
"teacher_forcing_ratio = 0.95\r\n",
"\r\n",
"def train(input_tensor, target_tensor, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion, max_length=MAX_LENGTH):\r\n",
"    encoder_hidden = encoder.initHidden()\r\n",
"\r\n",
"    encoder_optimizer.zero_grad()\r\n",
"    decoder_optimizer.zero_grad()\r\n",
"\r\n",
"    input_length = input_tensor.size(0)\r\n",
"    target_length = target_tensor.size(0)\r\n",
"\r\n",
"    encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)\r\n",
"\r\n",
"    loss = 0\r\n",
"\r\n",
"    for ei in range(input_length):\r\n",
"        encoder_output, encoder_hidden = encoder(\r\n",
"            input_tensor[ei], encoder_hidden)\r\n",
"        encoder_outputs[ei] = encoder_output[0, 0]\r\n",
"\r\n",
"    decoder_input = torch.tensor([[SOS_token]], device=device)\r\n",
"\r\n",
"    decoder_hidden = encoder_hidden\r\n",
"\r\n",
"    use_teacher_forcing = random.random() < teacher_forcing_ratio\r\n",
"\r\n",
"    if use_teacher_forcing:\r\n",
"        # Teacher forcing: feed the target as the next input\r\n",
"        for di in range(target_length):\r\n",
"            decoder_output, decoder_hidden = decoder(\r\n",
"                decoder_input, decoder_hidden)\r\n",
"            loss += criterion(decoder_output, target_tensor[di])\r\n",
"            decoder_input = target_tensor[di]  # Teacher forcing\r\n",
"\r\n",
"    else:\r\n",
"        # Without teacher forcing: use its own predictions as the next input\r\n",
"        for di in range(target_length):\r\n",
"            decoder_output, decoder_hidden = decoder(\r\n",
"                decoder_input, decoder_hidden)\r\n",
"            topv, topi = decoder_output.topk(1)\r\n",
"            decoder_input = topi.squeeze().detach()  # detach from history as input\r\n",
"\r\n",
"            loss += criterion(decoder_output, target_tensor[di])\r\n",
"            #if decoder_input.item() == EOS_token:\r\n",
"            #    break\r\n",
"    #print(loss)\r\n",
"    try:\r\n",
"        loss.backward()\r\n",
"    except AttributeError:\r\n",
"        # loss is still the int 0 when target_length == 0; log the pair and skip the update\r\n",
"        print(f\"loss: {loss}\")\r\n",
"        print(f\"input_tensor: {input_tensor}\")\r\n",
"        print(f\"target_tensor: {target_tensor}\")\r\n",
"        return 0.0\r\n",
"    encoder_optimizer.step()\r\n",
"    decoder_optimizer.step()\r\n",
"\r\n",
"    return loss.item() / target_length"
],
"execution_count": 20,
"outputs": []
},
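{
"cell_type": "markdown",
"metadata": {},
"source": [
"A common safeguard that `train` above does not use (an optional addition, not the original method): clip gradients between `loss.backward()` and the optimizer steps, since the NLL loss summed over long targets can make GRU gradients explode. The cell below demonstrates the call on a dummy module; inside `train` it would read `torch.nn.utils.clip_grad_norm_(encoder.parameters(), max_norm=1.0)` and likewise for the decoder."
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"# Demo of gradient clipping on a throwaway module (hypothetical, not wired into train())\r\n",
"m = nn.Linear(4, 4)\r\n",
"m(torch.randn(2, 4)).sum().backward()\r\n",
"total_norm = torch.nn.utils.clip_grad_norm_(m.parameters(), max_norm=1.0)\r\n",
"print(float(total_norm))  # norm before clipping; grads are now scaled to at most 1.0"
],
"execution_count": null,
"outputs": []
},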
{
"cell_type": "code",
"metadata": {
"id": "nItV2ibhr7HA"
},
"source": [
"def list_to_tensor(l):\r\n",
"    return torch.tensor(l, dtype=torch.long, device=device).view(-1, 1)\r\n",
"\r\n",
"def pairs_to_tensor(pair):\r\n",
"    in_tensor = list_to_tensor(pair[0])\r\n",
"    out_tensor = list_to_tensor(pair[1])\r\n",
"    return (in_tensor, out_tensor)\r\n",
"\r\n",
"\r\n",
"def asMinutes(s):\r\n",
"    m = math.floor(s / 60)\r\n",
"    s -= m * 60\r\n",
"    return '%dm %ds' % (m, s)\r\n",
"\r\n",
"\r\n",
"def timeSince(since, percent):\r\n",
"    now = time.time()\r\n",
"    s = now - since\r\n",
"    es = s / percent\r\n",
"    rs = es - s\r\n",
"    return '%s (- %s)' % (asMinutes(s), asMinutes(rs))\r\n",
"\r\n",
"def trainIters(encoder, decoder, n_iters, print_every=1000, learning_rate=0.01):\r\n",
"    start = time.time()\r\n",
"    plot_losses = []\r\n",
"    print_loss_total = 0  # Reset every print_every\r\n",
"    plot_loss_total = 0  # Reset every plot_every\r\n",
"\r\n",
"    encoder_optimizer = optim.SGD(encoder.parameters(), lr=learning_rate)\r\n",
"    decoder_optimizer = optim.SGD(decoder.parameters(), lr=learning_rate)\r\n",
"    training_pairs = [pairs_to_tensor(random.choice(code_pairs))\r\n",
"                      for i in range(n_iters)]\r\n",
"    criterion = nn.NLLLoss()\r\n",
"\r\n",
"    for iter in range(1, n_iters + 1):\r\n",
"        training_pair = training_pairs[iter - 1]\r\n",
"        input_tensor = training_pair[0]\r\n",
"        target_tensor = training_pair[1]\r\n",
"\r\n",
"        loss = train(input_tensor, target_tensor, encoder,\r\n",
"                     decoder, encoder_optimizer, decoder_optimizer, criterion)\r\n",
"        print_loss_total += loss\r\n",
"        plot_loss_total += loss\r\n",
"\r\n",
"        if iter % print_every == 0:\r\n",
"            print_loss_avg = print_loss_total / print_every\r\n",
"            print_loss_total = 0\r\n",
"            print('%s (%d %d%%) %.4f' % (timeSince(start, iter / n_iters),\r\n",
"                                         iter, iter / n_iters * 100, print_loss_avg))"
],
"execution_count": 21,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "MRuYGa9nzOi9"
},
"source": [
"hidden_size = 256\r\n",
"encoder1 = EncoderRNN(vocab_size, hidden_size).to(device)\r\n",
"decoder1 = DecoderRNN(hidden_size, vocab_size).to(device)"
],
"execution_count": 22,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "LgBYFrTt5Go6"
},
"source": [
"trainIters(encoder1, decoder1, 35000, print_every=5)"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "QvoQeb1lzuaB"
},
"source": [
"def evaluate(encoder, decoder, sentence, max_length=MAX_LENGTH):\r\n",
"    with torch.no_grad():\r\n",
"        input_tensor = list_to_tensor(sentence_to_codes(sentence, bpe_encoder_pl))\r\n",
"        input_length = input_tensor.size()[0]\r\n",
"        encoder_hidden = encoder.initHidden()\r\n",
"\r\n",
"        encoder_outputs = torch.zeros(max_length, encoder.hidden_size, device=device)\r\n",
"\r\n",
"        for ei in range(input_length):\r\n",
"            encoder_output, encoder_hidden = encoder(input_tensor[ei],\r\n",
"                                                     encoder_hidden)\r\n",
"            encoder_outputs[ei] += encoder_output[0, 0]\r\n",
"\r\n",
"        decoder_input = torch.tensor([[SOS_token]], device=device)  # SOS\r\n",
"\r\n",
"        decoder_hidden = encoder_hidden\r\n",
"\r\n",
"        decoded_words = []\r\n",
"        eow_token = 501  # hard-coded id used as a stop marker in the target BPE vocab\r\n",
"        last_word = -1\r\n",
"        for di in range(max_length):\r\n",
"            decoder_output, decoder_hidden = decoder(\r\n",
"                decoder_input, decoder_hidden)\r\n",
"            topv, topi = decoder_output.data.topk(1)\r\n",
"            if topi.item() == last_word and topi.item() == eow_token:\r\n",
"                # stop once the end-of-word marker is emitted twice in a row\r\n",
"                break\r\n",
"            else:\r\n",
"                decoded_words.append(topi.item())\r\n",
"                last_word = topi.item()\r\n",
"\r\n",
"            decoder_input = topi.squeeze().detach()\r\n",
"\r\n",
"        decoded_tokens = bpe_encoder_en.inverse_transform([decoded_words])\r\n",
"        return decoded_tokens\r\n",
"\r\n",
"def evaluateAndShow(input_sentence):\r\n",
"    output_words = evaluate(\r\n",
"        encoder1, decoder1, input_sentence)\r\n",
"    return next(output_words)\r\n"
],
"execution_count": 14,
"outputs": []
},
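{
"cell_type": "markdown",
"metadata": {},
"source": [
"Single-sentence usage of the helper above (the input here is a made-up placeholder; any source-language string works once training has run):"
]
},
{
"cell_type": "code",
"metadata": {},
"source": [
"print(evaluateAndShow(\"to jest przykladowe zdanie .\"))"
],
"execution_count": null,
"outputs": []
},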
{
"cell_type": "code",
"metadata": {
"id": "2Y8rj7BhIpBS"
},
"source": [
"temp = open('test-A/in.tsv', 'r').readlines()\r\n",
"data = []\r\n",
"for sent in temp:\r\n",
"    data.append(sent.replace('\\n', ''))\r\n",
"\r\n",
"f = open('test-A/out.tsv', 'w')\r\n",
"for sent in data:\r\n",
"    f.write(evaluateAndShow(sent) + '\\n')\r\n",
"\r\n",
"f.close()"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "13Nm_jnCIpnl"
},
"source": [
"temp = open('dev-0/in.tsv', 'r').readlines()\r\n",
"data = []\r\n",
"for sent in temp:\r\n",
"    data.append(sent.replace('\\n', ''))\r\n",
"\r\n",
"f = open('dev-0/out.tsv', 'w')\r\n",
"for sent in data:\r\n",
"    f.write(evaluateAndShow(sent) + '\\n')\r\n",
"\r\n",
"f.close()"
],
"execution_count": null,
"outputs": []
}
]
}
1200
test-A/out.tsv
Normal file
File diff suppressed because it is too large