Initial commit
This commit is contained in:
parent
2c3f06239c
commit
5c591315b5
916
seq2seq.ipynb
Normal file
916
seq2seq.ipynb
Normal file
@ -0,0 +1,916 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"source": [
|
||||||
|
"## Seq2Seq translation"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false
|
||||||
|
},
|
||||||
|
"id": "9748a77fd2a3cac3"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 3,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import pandas as pd\n",
|
||||||
|
"import numpy as np\n",
|
||||||
|
"\n",
|
||||||
|
"# Pytorch\n",
|
||||||
|
"import torch\n",
|
||||||
|
"import torch.nn as nn\n",
|
||||||
|
"import torch.optim as optim\n",
|
||||||
|
"\n",
|
||||||
|
"from torch.utils.data import TensorDataset, DataLoader, RandomSampler\n",
|
||||||
|
"import torch.nn.functional as F\n",
|
||||||
|
"\n",
|
||||||
|
"import torchtext\n",
|
||||||
|
"torchtext.disable_torchtext_deprecation_warning()\n",
|
||||||
|
"\n",
|
||||||
|
"from torchtext.data.metrics import bleu_score\n",
|
||||||
|
"\n",
|
||||||
|
"from unidecode import unidecode\n",
|
||||||
|
"\n",
|
||||||
|
"import regex as re\n",
|
||||||
|
"from string import punctuation\n",
|
||||||
|
"import random\n",
|
||||||
|
"\n",
|
||||||
|
"from tqdm.notebook import tqdm"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2024-05-23T07:51:30.799418700Z",
|
||||||
|
"start_time": "2024-05-23T07:51:30.786186900Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "initial_id"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 4,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Use GPU if available\n",
|
||||||
|
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2024-05-23T07:51:31.808275400Z",
|
||||||
|
"start_time": "2024-05-23T07:51:31.791795200Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "373fcd790697fce0"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"source": [
|
||||||
|
"### Load data corpus"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false
|
||||||
|
},
|
||||||
|
"id": "81adf3bad0a07802"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 5,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Load the data\n",
|
||||||
|
"df = pd.read_csv('pol-eng/pol.txt', sep='\\t', header=None, usecols=[0, 1], names=['source', 'target'])\n",
|
||||||
|
"df_1000 = df.head(1000)\n",
|
||||||
|
"df = df_1000"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2024-05-23T07:51:33.322595100Z",
|
||||||
|
"start_time": "2024-05-23T07:51:33.202326400Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "695222e59fb7c9fd"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 6,
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": " source target\n0 Go. Idź.\n1 Hi. Cześć.\n2 Run! Uciekaj!\n3 Run. Biegnij.\n4 Run. Uciekaj.\n.. ... ...\n995 We walked. Poszliśmy pieszo.\n996 We yawned. Ziewaliśmy.\n997 We'll see. Zobaczymy.\n998 We're men. Jesteśmy mężczyznami.\n999 We're sad. Jesteśmy smutni.\n\n[1000 rows x 2 columns]",
|
||||||
|
"text/html": "<div>\n<style scoped>\n .dataframe tbody tr th:only-of-type {\n vertical-align: middle;\n }\n\n .dataframe tbody tr th {\n vertical-align: top;\n }\n\n .dataframe thead th {\n text-align: right;\n }\n</style>\n<table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: right;\">\n <th></th>\n <th>source</th>\n <th>target</th>\n </tr>\n </thead>\n <tbody>\n <tr>\n <th>0</th>\n <td>Go.</td>\n <td>Idź.</td>\n </tr>\n <tr>\n <th>1</th>\n <td>Hi.</td>\n <td>Cześć.</td>\n </tr>\n <tr>\n <th>2</th>\n <td>Run!</td>\n <td>Uciekaj!</td>\n </tr>\n <tr>\n <th>3</th>\n <td>Run.</td>\n <td>Biegnij.</td>\n </tr>\n <tr>\n <th>4</th>\n <td>Run.</td>\n <td>Uciekaj.</td>\n </tr>\n <tr>\n <th>...</th>\n <td>...</td>\n <td>...</td>\n </tr>\n <tr>\n <th>995</th>\n <td>We walked.</td>\n <td>Poszliśmy pieszo.</td>\n </tr>\n <tr>\n <th>996</th>\n <td>We yawned.</td>\n <td>Ziewaliśmy.</td>\n </tr>\n <tr>\n <th>997</th>\n <td>We'll see.</td>\n <td>Zobaczymy.</td>\n </tr>\n <tr>\n <th>998</th>\n <td>We're men.</td>\n <td>Jesteśmy mężczyznami.</td>\n </tr>\n <tr>\n <th>999</th>\n <td>We're sad.</td>\n <td>Jesteśmy smutni.</td>\n </tr>\n </tbody>\n</table>\n<p>1000 rows × 2 columns</p>\n</div>"
|
||||||
|
},
|
||||||
|
"execution_count": 6,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"df"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2024-05-23T07:51:34.389440400Z",
|
||||||
|
"start_time": "2024-05-23T07:51:34.365627300Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "e2804a7cbe464e9a"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"source": [
|
||||||
|
"### Initial data preprocessing"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false
|
||||||
|
},
|
||||||
|
"id": "b583da4ce157173c"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 7,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Helper class to create language vocabularies\n",
|
||||||
|
"# <bos> - beginning of sentence token - 0\n",
|
||||||
|
"# <eos> - end of sentence token - 1\n",
|
||||||
|
"bos_token = 0\n",
|
||||||
|
"eos_token = 1\n",
|
||||||
|
"\n",
|
||||||
|
"class Lang:\n",
|
||||||
|
" def __init__(self, name):\n",
|
||||||
|
" self.name = name\n",
|
||||||
|
" self.word2index = {'<bos>': 0, '<eos>': 1}\n",
|
||||||
|
" self.word2count = {}\n",
|
||||||
|
" self.index2word = {0: '<bos>', 1: '<eos>'}\n",
|
||||||
|
" self.n_words = 2 # Count <bos> and <eos>\n",
|
||||||
|
" \n",
|
||||||
|
" def add_sentence(self, sentence):\n",
|
||||||
|
" for word in sentence.split(' '):\n",
|
||||||
|
" self.add_word(word)\n",
|
||||||
|
" \n",
|
||||||
|
" def add_word(self, word):\n",
|
||||||
|
" if word not in self.word2index:\n",
|
||||||
|
" self.word2index[word] = self.n_words\n",
|
||||||
|
" self.word2count[word] = 1\n",
|
||||||
|
" self.index2word[self.n_words] = word\n",
|
||||||
|
" self.n_words += 1\n",
|
||||||
|
" else:\n",
|
||||||
|
" self.word2count[word] += 1"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2024-05-23T07:52:08.652414900Z",
|
||||||
|
"start_time": "2024-05-23T07:52:08.645734700Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "df1cce85da86dd9b"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 8,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Helper methods to preprocess data\n",
|
||||||
|
"def normalizeText(text, ascii: bool = False):\n",
|
||||||
|
" # Convert to ASCII\n",
|
||||||
|
" if ascii:\n",
|
||||||
|
" text = unidecode(text)\n",
|
||||||
|
" \n",
|
||||||
|
" # Lowercase and trim whitespace\n",
|
||||||
|
" text = text.lower().strip()\n",
|
||||||
|
" \n",
|
||||||
|
" # Remove non-letter characters\n",
|
||||||
|
" text = re.sub(r\"[\" + punctuation + \"]\", \"\", text)\n",
|
||||||
|
" \n",
|
||||||
|
" return text"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2024-05-23T07:52:09.334329900Z",
|
||||||
|
"start_time": "2024-05-23T07:52:09.325331900Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "e9a7467f36ffb00b"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 9,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Method for data preparation (vocabularies, pairs of sentences)\n",
|
||||||
|
"def prepareData(df, source_lang, target_lang, ascii: bool = False):\n",
|
||||||
|
" # Normalize source and target sentences\n",
|
||||||
|
" df['source'] = df['source'].apply(lambda x: normalizeText(x, ascii=ascii))\n",
|
||||||
|
" df['target'] = df['target'].apply(lambda x: normalizeText(x, ascii=ascii))\n",
|
||||||
|
" \n",
|
||||||
|
" # Get pairs of sentences\n",
|
||||||
|
" pairs = list(zip(df['source'], df['target']))\n",
|
||||||
|
" \n",
|
||||||
|
" # Create language vocabularies\n",
|
||||||
|
" source_lang = Lang(source_lang)\n",
|
||||||
|
" target_lang = Lang(target_lang)\n",
|
||||||
|
" \n",
|
||||||
|
" for source_sentence, target_sentence in pairs:\n",
|
||||||
|
" source_lang.add_sentence(source_sentence)\n",
|
||||||
|
" target_lang.add_sentence(target_sentence)\n",
|
||||||
|
" \n",
|
||||||
|
" return source_lang, target_lang, pairs"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2024-05-23T07:52:09.731269100Z",
|
||||||
|
"start_time": "2024-05-23T07:52:09.711245600Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "fcc216f173c8f688"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"source": [
|
||||||
|
"### Prepare data for training"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false
|
||||||
|
},
|
||||||
|
"id": "8811870cb738a9dd"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 37,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Convert sentence to list of indexes (with <bos> and <eos> tokens)\n",
|
||||||
|
"def indexesFromSentence(lang, sentence):\n",
|
||||||
|
" return [bos_token] + [lang.word2index[word] for word in sentence.split(' ')] + [eos_token]\n",
|
||||||
|
"\n",
|
||||||
|
"# Convert sentence to tensor of indexes (with <bos> and <eos> tokens)\n",
|
||||||
|
"def tensorFromSentence(lang, sentence):\n",
|
||||||
|
" return torch.tensor(indexesFromSentence(lang, sentence), dtype=torch.long, device=device).view(-1, 1)\n",
|
||||||
|
"\n",
|
||||||
|
"# Data loader\n",
|
||||||
|
"def prepareDataLoader(df, source_lang, target_lang, batch_size: int = 32, ascii: bool = False):\n",
|
||||||
|
" # Prepare data (vocabularies, pairs of sentences)\n",
|
||||||
|
" source_lang, target_lang, pairs = prepareData(df, source_lang, target_lang, ascii=ascii)\n",
|
||||||
|
" \n",
|
||||||
|
" # Get maximum length of sentence\n",
|
||||||
|
" MAX_LENGTH = max(np.max([len(sentence.split(' ')) for sentence in df['source']]), np.max([len(sentence.split(' ')) for sentence in df['target']])) + 2\n",
|
||||||
|
" \n",
|
||||||
|
" # Get number of pairs\n",
|
||||||
|
" n_pairs = len(pairs)\n",
|
||||||
|
" \n",
|
||||||
|
" # Initialize tensors (source and target)\n",
|
||||||
|
" source_indexes = np.zeros((n_pairs, MAX_LENGTH), dtype=np.int32)\n",
|
||||||
|
" target_indexes = np.zeros((n_pairs, MAX_LENGTH), dtype=np.int32)\n",
|
||||||
|
" \n",
|
||||||
|
" # Fill tensors\n",
|
||||||
|
" for idx, (source_sentence, target_sentence) in enumerate(pairs):\n",
|
||||||
|
" source_idx = indexesFromSentence(source_lang, source_sentence)\n",
|
||||||
|
" target_idx = indexesFromSentence(target_lang, target_sentence)\n",
|
||||||
|
" \n",
|
||||||
|
" source_indexes[idx, :len(source_idx)] = source_idx\n",
|
||||||
|
" target_indexes[idx, :len(target_idx)] = target_idx\n",
|
||||||
|
" \n",
|
||||||
|
" # Tensor dataset\n",
|
||||||
|
" train_data = TensorDataset(torch.LongTensor(source_indexes).to(device),\n",
|
||||||
|
" torch.LongTensor(target_indexes).to(device))\n",
|
||||||
|
" \n",
|
||||||
|
" # Sampler\n",
|
||||||
|
" train_sampler = RandomSampler(train_data)\n",
|
||||||
|
" \n",
|
||||||
|
" # Data loader\n",
|
||||||
|
" train_loader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)\n",
|
||||||
|
" \n",
|
||||||
|
" return source_lang, target_lang, pairs, train_loader, MAX_LENGTH"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2024-05-23T16:52:29.833983200Z",
|
||||||
|
"start_time": "2024-05-23T16:52:29.762233200Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "776b2c6f672a8234"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"source": [
|
||||||
|
"### Seq2Seq model - Encoder and Decoder"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false
|
||||||
|
},
|
||||||
|
"id": "cf83019158526a0e"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 11,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Encoder\n",
|
||||||
|
"class EncoderRNN(nn.Module):\n",
|
||||||
|
" def __init__(self, input_size: int, hidden_size: int = 100, dropout: float = 0.1):\n",
|
||||||
|
" super(EncoderRNN, self).__init__()\n",
|
||||||
|
" self.hidden_size = hidden_size\n",
|
||||||
|
" \n",
|
||||||
|
" # Embedding layer\n",
|
||||||
|
" self.embedding = nn.Embedding(input_size, hidden_size)\n",
|
||||||
|
" \n",
|
||||||
|
" # GRU layer\n",
|
||||||
|
" self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)\n",
|
||||||
|
" \n",
|
||||||
|
" # Dropout layer for regularization\n",
|
||||||
|
" self.dropout = nn.Dropout(p=dropout)\n",
|
||||||
|
" \n",
|
||||||
|
" def forward(self, input):\n",
|
||||||
|
" # Transform input (as tensor of word indexes) to embeddings vectors\n",
|
||||||
|
" embedded = self.embedding(input)\n",
|
||||||
|
" \n",
|
||||||
|
" # Apply dropout to embeddings\n",
|
||||||
|
" embedded = self.dropout(embedded)\n",
|
||||||
|
" \n",
|
||||||
|
" # Pass embeddings through GRU and get output and hidden state\n",
|
||||||
|
" output, hidden = self.gru(embedded)\n",
|
||||||
|
" \n",
|
||||||
|
" return output, hidden"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2024-05-23T07:52:11.571658900Z",
|
||||||
|
"start_time": "2024-05-23T07:52:11.552337200Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "dea97063c24c0ad6"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 12,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# What is Tearcher Forcing - https://saturncloud.io/glossary/teacher-forcing/#:~:text=What%20is%20Teacher%20Forcing%3F,%2C%20translation%2C%20and%20text%20generation.\n",
|
||||||
|
"\n",
|
||||||
|
"class BahdanauAttention(nn.Module):\n",
|
||||||
|
" def __init__(self, hidden_size):\n",
|
||||||
|
" super(BahdanauAttention, self).__init__()\n",
|
||||||
|
" self.hidden_size = hidden_size\n",
|
||||||
|
" \n",
|
||||||
|
" # Linear layer to transform encoder output to attention weights\n",
|
||||||
|
" self.Wa = nn.Linear(hidden_size, hidden_size)\n",
|
||||||
|
" self.Ua = nn.Linear(hidden_size, hidden_size)\n",
|
||||||
|
" self.Va = nn.Linear(hidden_size, 1)\n",
|
||||||
|
" \n",
|
||||||
|
" def forward(self, query, keys):\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" :param query: hidden state from decoder \n",
|
||||||
|
" :param keys: output from encoder\n",
|
||||||
|
" :return: context vector and attention weights\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" scores = self.Va(torch.tanh(self.Wa(query) + self.Ua(keys)))\n",
|
||||||
|
" scores = scores.squeeze(2).unsqueeze(1)\n",
|
||||||
|
" \n",
|
||||||
|
" # Apply softmax to get attention weights\n",
|
||||||
|
" weights = F.softmax(scores, dim=-1)\n",
|
||||||
|
" \n",
|
||||||
|
" # Calculate context vector\n",
|
||||||
|
" context = torch.bmm(weights, keys)\n",
|
||||||
|
" \n",
|
||||||
|
" return context, weights\n",
|
||||||
|
" \n",
|
||||||
|
"# Decoder\n",
|
||||||
|
"class DecoderRNN(nn.Module):\n",
|
||||||
|
" def __init__(self, hidden_size, output_size, dropout_p: float = 0.1):\n",
|
||||||
|
" super(DecoderRNN, self).__init__()\n",
|
||||||
|
" \n",
|
||||||
|
" # Embedding layer\n",
|
||||||
|
" self.embedding = nn.Embedding(output_size, hidden_size)\n",
|
||||||
|
" \n",
|
||||||
|
" # Attention layer\n",
|
||||||
|
" self.attention = BahdanauAttention(hidden_size)\n",
|
||||||
|
" \n",
|
||||||
|
" # GRU layer - input is concatenation of embeddings and context vector\n",
|
||||||
|
" self.gru = nn.GRU(2 * hidden_size, hidden_size, batch_first=True)\n",
|
||||||
|
" \n",
|
||||||
|
" # Linear layer to get output\n",
|
||||||
|
" self.out = nn.Linear(hidden_size, output_size)\n",
|
||||||
|
" \n",
|
||||||
|
" # Dropout layer for regularization\n",
|
||||||
|
" self.dropout = nn.Dropout(p=dropout_p)\n",
|
||||||
|
" \n",
|
||||||
|
" def forward(self, encoder_outputs, encoder_hidden, target_tensor=None):\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" :param encoder_outputs: output from encoder\n",
|
||||||
|
" :param encoder_hidden: last hidden states from encoder, used as initial hidden states for decoder\n",
|
||||||
|
" :param target_tensor: target tensor - used in training with teacher forcing\n",
|
||||||
|
" :return: \n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" # Batch size\n",
|
||||||
|
" batch_size = encoder_outputs.size(0)\n",
|
||||||
|
" \n",
|
||||||
|
" # Decoder input - initialize with <bos> token index\n",
|
||||||
|
" decoder_input = torch.empty(batch_size, 1, dtype=torch.long, device=device).fill_(bos_token)\n",
|
||||||
|
" \n",
|
||||||
|
" # Decoder hidden state - initialize with encoder hidden state\n",
|
||||||
|
" decoder_hidden = encoder_hidden\n",
|
||||||
|
" \n",
|
||||||
|
" # List to store decoder outputs\n",
|
||||||
|
" decoder_outputs = []\n",
|
||||||
|
" \n",
|
||||||
|
" # List to store attention weights\n",
|
||||||
|
" attention_weights = []\n",
|
||||||
|
" \n",
|
||||||
|
" # Determine the maximum length of the sequence to generate\n",
|
||||||
|
" max_length = target_tensor.size(1) if target_tensor is not None else MAX_LENGTH\n",
|
||||||
|
" \n",
|
||||||
|
" # Decoder loop\n",
|
||||||
|
" for i in range(max_length):\n",
|
||||||
|
" # Forward step\n",
|
||||||
|
" decoder_output, decoder_hidden, attn_weights = self.forward_step(decoder_input, decoder_hidden, encoder_outputs)\n",
|
||||||
|
" \n",
|
||||||
|
" # Save output and attention weights\n",
|
||||||
|
" decoder_outputs.append(decoder_output)\n",
|
||||||
|
" attention_weights.append(attn_weights)\n",
|
||||||
|
" \n",
|
||||||
|
" # If target tensor is provided, use it for next input\n",
|
||||||
|
" if target_tensor is not None:\n",
|
||||||
|
" # Teacher forcing: next input is current target\n",
|
||||||
|
" decoder_input = target_tensor[:, i].unsqueeze(1)\n",
|
||||||
|
" else:\n",
|
||||||
|
" # Otherwise use output from current step (own prediction)\n",
|
||||||
|
" _, topi = decoder_output.topk(1)\n",
|
||||||
|
" decoder_input = topi.squeeze(-1).detach()\n",
|
||||||
|
" \n",
|
||||||
|
" # Break if decoder input is <eos> token\n",
|
||||||
|
" if torch.any(decoder_input == eos_token):\n",
|
||||||
|
" break\n",
|
||||||
|
" \n",
|
||||||
|
" # Concatenate outputs\n",
|
||||||
|
" decoder_outputs = torch.cat(decoder_outputs, dim=1)\n",
|
||||||
|
" \n",
|
||||||
|
" # Apply log softmax to get probabilities\n",
|
||||||
|
" decoder_outputs = F.log_softmax(decoder_outputs, dim=-1)\n",
|
||||||
|
" \n",
|
||||||
|
" # Concatenate attention weights\n",
|
||||||
|
" attention_weights = torch.cat(attention_weights, dim=1)\n",
|
||||||
|
" \n",
|
||||||
|
" return decoder_outputs, decoder_hidden, attention_weights\n",
|
||||||
|
" \n",
|
||||||
|
" \n",
|
||||||
|
" def forward_step(self, decoder_input, decoder_hidden, encoder_outputs):\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" Forward step of decoder\n",
|
||||||
|
" :param decoder_input: current input tensor for decoder\n",
|
||||||
|
" :param decoder_hidden: current hidden state of decoder\n",
|
||||||
|
" :param encoder_outputs: output from encoder\n",
|
||||||
|
" :return: output and hidden state\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" # Transform input (as tensor of word indexes) to embeddings vectors\n",
|
||||||
|
" embedded = self.embedding(decoder_input)\n",
|
||||||
|
" \n",
|
||||||
|
" # Apply dropout to embeddings\n",
|
||||||
|
" embedded = self.dropout(embedded)\n",
|
||||||
|
" \n",
|
||||||
|
" # Query\n",
|
||||||
|
" query = decoder_hidden.permute(1, 0, 2)\n",
|
||||||
|
" \n",
|
||||||
|
" # Context vector and attention weights\n",
|
||||||
|
" context, attn_weights = self.attention(query, encoder_outputs)\n",
|
||||||
|
" \n",
|
||||||
|
" # Concatenate embeddings and context vector\n",
|
||||||
|
" input_gru = torch.cat((embedded, context), dim=2)\n",
|
||||||
|
" \n",
|
||||||
|
" # GRU\n",
|
||||||
|
" output, hidden = self.gru(input_gru, decoder_hidden)\n",
|
||||||
|
" \n",
|
||||||
|
" # Pass output through linear layer to get final output\n",
|
||||||
|
" output = self.out(output)\n",
|
||||||
|
" \n",
|
||||||
|
" return output, hidden, attn_weights"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2024-05-23T07:52:12.294694400Z",
|
||||||
|
"start_time": "2024-05-23T07:52:12.283155700Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "8263b6cbafc89328"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"source": [
|
||||||
|
"### Model training"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false
|
||||||
|
},
|
||||||
|
"id": "85a962cdbcbd4361"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 66,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Model training\n",
|
||||||
|
"def train(dataloader, encoder, decoder, epochs: int = 100, learning_rate: float = 0.01, info_every: int = 10, plot_every: int = 10):\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" :param dataloader: DataLoader with training data\n",
|
||||||
|
" :param encoder: Encoder model\n",
|
||||||
|
" :param decoder: Decoder model\n",
|
||||||
|
" :param epochs: Number of epochs\n",
|
||||||
|
" :param learning_rate: Learning rate\n",
|
||||||
|
" :param info_every: Specify how often to print information about training (default: every 10 epochs)\n",
|
||||||
|
" :param plot_every: Specify how often to plot loss (default: every 10 epochs)\n",
|
||||||
|
" :return: None\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" # Set models to training mode\n",
|
||||||
|
" encoder.train()\n",
|
||||||
|
" decoder.train()\n",
|
||||||
|
" \n",
|
||||||
|
" # Initialize optimizer\n",
|
||||||
|
" encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)\n",
|
||||||
|
" decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate)\n",
|
||||||
|
" \n",
|
||||||
|
" # Initialize loss function\n",
|
||||||
|
" criterion = nn.NLLLoss()\n",
|
||||||
|
" \n",
|
||||||
|
" # Initialize loss lists\n",
|
||||||
|
" plot_losses = []\n",
|
||||||
|
" \n",
|
||||||
|
" # Training loop\n",
|
||||||
|
" for epoch in tqdm(range(epochs + 1)):\n",
|
||||||
|
" total_loss = 0\n",
|
||||||
|
" \n",
|
||||||
|
" for data in dataloader:\n",
|
||||||
|
" source_tensor, target_tensor = data\n",
|
||||||
|
" \n",
|
||||||
|
" # Zero gradients\n",
|
||||||
|
" encoder_optimizer.zero_grad()\n",
|
||||||
|
" decoder_optimizer.zero_grad()\n",
|
||||||
|
" \n",
|
||||||
|
" # Forward pass\n",
|
||||||
|
" encoder_outputs, encoder_hidden = encoder(source_tensor)\n",
|
||||||
|
" decoder_outputs, _, _ = decoder(encoder_outputs, encoder_hidden, target_tensor)\n",
|
||||||
|
" \n",
|
||||||
|
" # Calculate loss\n",
|
||||||
|
" loss = criterion(\n",
|
||||||
|
" decoder_outputs.view(-1, decoder_outputs.size(-1)),\n",
|
||||||
|
" target_tensor.view(-1)\n",
|
||||||
|
" )\n",
|
||||||
|
" \n",
|
||||||
|
" # Backward pass\n",
|
||||||
|
" loss.backward()\n",
|
||||||
|
" \n",
|
||||||
|
" # Update weights\n",
|
||||||
|
" encoder_optimizer.step()\n",
|
||||||
|
" decoder_optimizer.step()\n",
|
||||||
|
" \n",
|
||||||
|
" total_loss += loss.item()\n",
|
||||||
|
" \n",
|
||||||
|
" plot_losses.append(total_loss / len(dataloader))\n",
|
||||||
|
" \n",
|
||||||
|
" if epoch % info_every == 0:\n",
|
||||||
|
" print(f'Epoch: {epoch}, Loss: {total_loss / len(dataloader)}')"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2024-05-23T17:09:17.257069500Z",
|
||||||
|
"start_time": "2024-05-23T17:09:17.231097900Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "c690814243db9f65"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 67,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Model predictions\n",
|
||||||
|
"def predict(encoder, decoder, sentence, source_lang, target_lang):\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" :param encoder: Encoder model\n",
|
||||||
|
" :param decoder: Decoder model\n",
|
||||||
|
" :param sentence: Sentence to translate\n",
|
||||||
|
" :param source_lang: Source language vocabulary\n",
|
||||||
|
" :param target_lang: Target language vocabulary\n",
|
||||||
|
" :return: predicted sentence\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" # Set models to evaluation mode\n",
|
||||||
|
" encoder.eval()\n",
|
||||||
|
" decoder.eval()\n",
|
||||||
|
" \n",
|
||||||
|
" with torch.no_grad():\n",
|
||||||
|
" # Prepare input tensor\n",
|
||||||
|
" input_tensor = tensorFromSentence(source_lang, sentence).view(1, -1)\n",
|
||||||
|
" \n",
|
||||||
|
" # Forward pass\n",
|
||||||
|
" encoder_outputs, encoder_hidden = encoder(input_tensor)\n",
|
||||||
|
" decoder_outputs, _, _ = decoder(encoder_outputs, encoder_hidden)\n",
|
||||||
|
" \n",
|
||||||
|
" # Get indexes of the most probable words\n",
|
||||||
|
" _, topi = decoder_outputs.topk(1)\n",
|
||||||
|
" decoded_ids = topi.squeeze()\n",
|
||||||
|
" \n",
|
||||||
|
" # Check if tensor if zero-dimensional\n",
|
||||||
|
" if decoded_ids.dim() == 0:\n",
|
||||||
|
" decoded_ids = decoded_ids.view(1)\n",
|
||||||
|
" \n",
|
||||||
|
" # Convert indexes to words\n",
|
||||||
|
" decoded_words = []\n",
|
||||||
|
" \n",
|
||||||
|
" for idx in decoded_ids:\n",
|
||||||
|
" if idx.item() == eos_token:\n",
|
||||||
|
" decoded_words.append('<eos>')\n",
|
||||||
|
" break\n",
|
||||||
|
" decoded_words.append(target_lang.index2word[idx.item()])\n",
|
||||||
|
" \n",
|
||||||
|
" return decoded_words"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2024-05-23T17:09:30.412069400Z",
|
||||||
|
"start_time": "2024-05-23T17:09:30.387968700Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "4b3cf1614d7eac6e"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 73,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Random evaluation\n",
|
||||||
|
"def random_evaluation(encoder, decoder, n: int = 10):\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" :param encoder: Encoder model\n",
|
||||||
|
" :param decoder: Decoder model\n",
|
||||||
|
" :param n: Number of sentences to evaluate\n",
|
||||||
|
" :return: None\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" # Set models to evaluation mode\n",
|
||||||
|
" encoder.eval()\n",
|
||||||
|
" decoder.eval()\n",
|
||||||
|
" \n",
|
||||||
|
" # Get random pairs and make predictions\n",
|
||||||
|
" for i in range(n):\n",
|
||||||
|
" pair = random.choice(pairs)\n",
|
||||||
|
" print('[source]>', pair[0])\n",
|
||||||
|
" print('[target]=', pair[1])\n",
|
||||||
|
" output_words = predict(encoder, decoder, pair[0], source_lang, target_lang)\n",
|
||||||
|
" output_words = list(filter(lambda x: x not in ['<bos>', '<eos>'], output_words))\n",
|
||||||
|
" output_sentence = ' '.join(output_words)\n",
|
||||||
|
" print('[prediction]<', output_sentence)\n",
|
||||||
|
" print('')"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2024-05-23T17:13:56.568407800Z",
|
||||||
|
"start_time": "2024-05-23T17:13:56.546161600Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "82192424a14423cf"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 74,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# BLEU score\n",
|
||||||
|
"def calculate_bleu_score(encoder, decoder, pairs, source_lang, target_lang):\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" :param encoder: Encoder model\n",
|
||||||
|
" :param decoder: Decoder model\n",
|
||||||
|
" :param pairs: List of pairs of sentences\n",
|
||||||
|
" :param source_lang: Source language vocabulary\n",
|
||||||
|
" :param target_lang: Target language vocabulary\n",
|
||||||
|
" :return: BLEU score\n",
|
||||||
|
" \"\"\"\n",
|
||||||
|
" # Initialize lists for references and candidates\n",
|
||||||
|
" references = []\n",
|
||||||
|
" candidates = []\n",
|
||||||
|
" \n",
|
||||||
|
" # Loop through pairs\n",
|
||||||
|
" for idx, (source_sentence, target_sentence) in enumerate(pairs):\n",
|
||||||
|
" # Get predicted sentence\n",
|
||||||
|
" predicted_sentence = predict(encoder, decoder, source_sentence, source_lang, target_lang)\n",
|
||||||
|
" \n",
|
||||||
|
" # Remove <bos> and <eos> tokens\n",
|
||||||
|
" predicted_sentence = list(filter(lambda x: x not in ['<bos>', '<eos>'], predicted_sentence))\n",
|
||||||
|
" \n",
|
||||||
|
" # Add reference and candidate\n",
|
||||||
|
" references.append([target_sentence.split(' ')])\n",
|
||||||
|
" candidates.append(predicted_sentence)\n",
|
||||||
|
" \n",
|
||||||
|
" return bleu_score(candidates, references)"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2024-05-23T17:13:57.245060700Z",
|
||||||
|
"start_time": "2024-05-23T17:13:57.198562600Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "e5dbb62e23d0f341"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 13,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Model parameters\n",
|
||||||
|
"hidden_size = 128\n",
|
||||||
|
"batch_size = 32\n",
|
||||||
|
"\n",
|
||||||
|
"# Prepare data\n",
|
||||||
|
"source_lang, target_lang, pairs, train_loader, MAX_LENGTH = prepareDataLoader(df, 'angielski', 'polski', batch_size=batch_size)\n",
|
||||||
|
"\n",
|
||||||
|
"# Initialize encoder and decoder\n",
|
||||||
|
"encoder = EncoderRNN(source_lang.n_words, hidden_size).to(device)\n",
|
||||||
|
"decoder = DecoderRNN(hidden_size, target_lang.n_words).to(device)"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"start_time": "2024-05-23T07:52:13.338545Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "7313d7aaa632a02f"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 22,
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": " 0%| | 0/1001 [00:00<?, ?it/s]",
|
||||||
|
"application/vnd.jupyter.widget-view+json": {
|
||||||
|
"version_major": 2,
|
||||||
|
"version_minor": 0,
|
||||||
|
"model_id": "7b0116c1e8ba4e06adeec22bcf807076"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "display_data"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Epoch: 0, Loss: 0.347686683293432\n",
|
||||||
|
"Epoch: 100, Loss: 0.2592399111017585\n",
|
||||||
|
"Epoch: 200, Loss: 0.2301939323078841\n",
|
||||||
|
"Epoch: 300, Loss: 0.246875268407166\n",
|
||||||
|
"Epoch: 400, Loss: 0.260721294907853\n",
|
||||||
|
"Epoch: 500, Loss: 0.258150483481586\n",
|
||||||
|
"Epoch: 600, Loss: 0.23601281619630754\n",
|
||||||
|
"Epoch: 700, Loss: 0.24906805180944502\n",
|
||||||
|
"Epoch: 800, Loss: 0.22962150094099343\n",
|
||||||
|
"Epoch: 900, Loss: 0.22537698200903833\n",
|
||||||
|
"Epoch: 1000, Loss: 0.22563873510807753\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"# Train the model\n",
|
||||||
|
"train(train_loader, encoder, decoder, epochs=100, learning_rate=0.01, info_every=2, plot_every=10)"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"start_time": "2024-05-23T08:35:55.096786500Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "b155b364a0473d84"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"source": [
|
||||||
|
"### Model evaluation"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false
|
||||||
|
},
|
||||||
|
"id": "64fa964a6efefc8d"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 78,
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"[source]> i shouted\n",
|
||||||
|
"[target]= krzyknąłem\n",
|
||||||
|
"[prediction]< krzyknąłem\n",
|
||||||
|
"\n",
|
||||||
|
"[source]> i ran\n",
|
||||||
|
"[target]= pobiegłem\n",
|
||||||
|
"[prediction]< pobiegłam\n",
|
||||||
|
"\n",
|
||||||
|
"[source]> we danced\n",
|
||||||
|
"[target]= tańczyliśmy\n",
|
||||||
|
"[prediction]< jestem\n",
|
||||||
|
"\n",
|
||||||
|
"[source]> i moved\n",
|
||||||
|
"[target]= ruszyłem się\n",
|
||||||
|
"[prediction]< jestem zły\n",
|
||||||
|
"\n",
|
||||||
|
"[source]> tom swore\n",
|
||||||
|
"[target]= tom przysiągł\n",
|
||||||
|
"[prediction]< tom przysiągł\n",
|
||||||
|
"\n",
|
||||||
|
"[source]> i failed\n",
|
||||||
|
"[target]= poniosłem klęskę\n",
|
||||||
|
"[prediction]< zawiodłem\n",
|
||||||
|
"\n",
|
||||||
|
"[source]> unlock it\n",
|
||||||
|
"[target]= otwórz to\n",
|
||||||
|
"[prediction]< otwórz to\n",
|
||||||
|
"\n",
|
||||||
|
"[source]> i inhaled\n",
|
||||||
|
"[target]= wciągnąłem powietrze\n",
|
||||||
|
"[prediction]< wciągnąłem powietrze\n",
|
||||||
|
"\n",
|
||||||
|
"[source]> who won\n",
|
||||||
|
"[target]= kto wygrał\n",
|
||||||
|
"[prediction]< kto wygrał\n",
|
||||||
|
"\n",
|
||||||
|
"[source]> tom works\n",
|
||||||
|
"[target]= tom pracuje\n",
|
||||||
|
"[prediction]< tom pracuje\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"# Random evaluation\n",
|
||||||
|
"random_evaluation(encoder, decoder, n=10)"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2024-05-23T17:15:10.014018Z",
|
||||||
|
"start_time": "2024-05-23T17:15:09.866410900Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "d476fb8300ffc274"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 77,
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"data": {
|
||||||
|
"text/plain": "0.5394429729650964"
|
||||||
|
},
|
||||||
|
"execution_count": 77,
|
||||||
|
"metadata": {},
|
||||||
|
"output_type": "execute_result"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"# BLEU score\n",
|
||||||
|
"calculate_bleu_score(encoder, decoder, pairs, source_lang, target_lang)"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"ExecuteTime": {
|
||||||
|
"end_time": "2024-05-23T17:15:07.815302100Z",
|
||||||
|
"start_time": "2024-05-23T17:15:00.325393Z"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"id": "2dc143478e0b7e2c"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "Python 3",
|
||||||
|
"language": "python",
|
||||||
|
"name": "python3"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 2
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython2",
|
||||||
|
"version": "2.7.6"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user