Europarl/TAU_translator_from_scratch.ipynb

1 line
45 KiB
Plaintext
Raw Permalink Normal View History

2020-01-28 18:51:11 +01:00
{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"name":"TAU_translator_from_scratch.ipynb","provenance":[],"collapsed_sections":["RF1vdsADCAM1"],"toc_visible":true,"authorship_tag":"ABX9TyMRnEKpWjs7cheHuy93zBBW"},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{"id":"EPqzHV7BucP2","colab_type":"text"},"source":["Install\n","---\n","#### TO DO\n","* Load and prepare data for train, dev-0 and test-A sets\n","* Prepare basic Encoder and Decoder\n","* Test training basic Encoder-Decoder + Adam optimizer + Cross-Entropy loss function\n","* Report model training status (epoch, time, iterations, current loss)\n","* Model saving to drive, saving evaluation on test-A (and evaluation on dev-0?) at the end of each epoch\n","* Add pretrained embeddings\n","* Add attention mechanism\n","* Reverse input sentence?\n","* BiRNN?"]},{"cell_type":"code","metadata":{"id":"Bx7TQD2DsDg7","colab_type":"code","colab":{}},"source":["from google.colab import drive\n","drive.mount('/content/drive')"],"execution_count":0,"outputs":[]},{"cell_type":"code","metadata":{"id":"c3NR3DOisZoQ","colab_type":"code","colab":{}},"source":["!pip install torch torchvision"],"execution_count":0,"outputs":[]},{"cell_type":"code","metadata":{"id":"BRcLW_KWxD2I","colab_type":"code","colab":{}},"source":["import re\n","import math\n","import time\n","import os.path\n","\n","import numpy as np\n","import torch\n","import torch.nn as nn\n","import torch.nn.functional as F\n","from torch.utils.data import Dataset, DataLoader\n","from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence\n","from torch import optim\n","from tqdm import tqdm"],"execution_count":0,"outputs":[]},{"cell_type":"code","metadata":{"colab_type":"code","id":"a6r74YL2-8Jm","outputId":"346ccb2b-6b19-4628-caa1-34b8318a9ca1","executionInfo":{"status":"ok","timestamp":1580115861850,"user_tz":-60,"elapsed":2527,"user":{"displayName":"Stanisław 
Gołębiewski","photoUrl":"","userId":"02205040307954405899"}},"colab":{"base_uri":"https://localhost:8080/","height":34}},"source":["device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n","print(device)"],"execution_count":0,"outputs":[{"output_type":"stream","text":["cuda:0\n"],"name":"stdout"}]},{"cell_type":"markdown","metadata":{"id":"hDsPNpPN-EeV","colab_type":"text"},"source":["### Load data"]},{"cell_type":"code","metadata":{"id":"ZujLKZUKsvS8","colab_type":"code","colab":{}},"source":["BLANK_TOKEN = 0\n","SOS_TOKEN = 1\n","EOS_TOKEN = 2\n","UNK_TOKEN = 3\n","\n","class Lang:\n"," # https://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html\n","\n"," def __init__(self, name):\n"," self.name = name\n"," self.word2index = {\"<BLANK>\": 0, \"<SOS>\": 1, \"<EOS>\": 2, \"<UNK>\": 3}\n"," self.word2count = {}\n"," self.index2word = {0: \"<BLANK>\", 1: \"<SOS>\", 2: \"<EOS>\", 3: \"<UNK>\"}\n"," self.n_words = 4\n","\n"," def addSentence(self, sentence):\n"," for word in sentence.split(' '):\n"," self.addWord(word)\n","\n"," def addWord(self, word):\n"," if word not in self.word2index:\n"," self.word2index[word] = self.n_words\n"," self.word2count[word] = 1\n"," self.index2word[self.n_words] = word\n"," self.n_words += 1\n"," else:\n"," self.word2count[word] += 1\n","\n"," def vector2Sentence(self, vector):\n"," words = []\n"," for index in vector:\n"," if index not in self.index2word:\n"," words.append(\"<UNK>\")\n"," else:\n"," words.append(self.index2word[index])\n"," return \" \".join(words)\n","\n"," # def blankAllLowFreqWorlds(self):\n"," # new_word2index = {\"SOS\": 0, \"EOS\": 1, \"<UNK>\": 2}\n"," # new_index2word = {0: \"SOS\", 1: \"EOS\", 2: \"<UNK>\"}\n"," # new_n_words = 3\n","\n"," # for word in self.word2count:\n"," # if self.word2count[word] == 1:\n"," # new