669 lines
23 KiB
Plaintext
669 lines
23 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 1,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# This notebook is based on\n",
|
|
"# https://github.com/bentrevett/pytorch-sentiment-analysis/blob/master/3%20-%20Faster%20Sentiment%20Analysis.ipynb"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 2,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"text/html": [
|
|
"<style>.container { width:100% !important; }</style>"
|
|
],
|
|
"text/plain": [
|
|
"<IPython.core.display.HTML object>"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
|
|
"# `display` and `HTML` are public under IPython.display; IPython.core.display is a deprecated path.\n",
"from IPython.display import HTML, display\n",
"\n",
"# Stretch the notebook container to the full browser width so wide outputs stay readable.\n",
"display(HTML(\"<style>.container { width:100% !important; }</style>\"))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"metadata": {
|
|
"scrolled": true
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Collecting package metadata (current_repodata.json): done\n",
|
|
"Solving environment: done\n",
|
|
"\n",
|
|
"\n",
|
|
"==> WARNING: A newer version of conda exists. <==\n",
|
|
" current version: 4.8.3\n",
|
|
" latest version: 4.9.2\n",
|
|
"\n",
|
|
"Please update conda by running\n",
|
|
"\n",
|
|
" $ conda update -n base -c defaults conda\n",
|
|
"\n",
|
|
"\n",
|
|
"\n",
|
|
"# All requested packages already installed.\n",
|
|
"\n",
|
|
"Collecting package metadata (current_repodata.json): done\n",
|
|
"Solving environment: done\n",
|
|
"\n",
|
|
"\n",
|
|
"==> WARNING: A newer version of conda exists. <==\n",
|
|
" current version: 4.8.3\n",
|
|
" latest version: 4.9.2\n",
|
|
"\n",
|
|
"Please update conda by running\n",
|
|
"\n",
|
|
" $ conda update -n base -c defaults conda\n",
|
|
"\n",
|
|
"\n",
|
|
"\n",
|
|
"# All requested packages already installed.\n",
|
|
"\n",
|
|
"Requirement already satisfied: en_core_web_sm==2.3.1 from https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-2.3.1/en_core_web_sm-2.3.1.tar.gz#egg=en_core_web_sm==2.3.1 in /home/kuba/anaconda3/envs/tau/lib/python3.8/site-packages (2.3.1)\n",
|
|
"Requirement already satisfied: spacy<2.4.0,>=2.3.0 in /home/kuba/anaconda3/envs/tau/lib/python3.8/site-packages (from en_core_web_sm==2.3.1) (2.3.2)\n",
|
|
"Requirement already satisfied: cymem<2.1.0,>=2.0.2 in /home/kuba/anaconda3/envs/tau/lib/python3.8/site-packages (from spacy<2.4.0,>=2.3.0->en_core_web_sm==2.3.1) (2.0.4)\n",
|
|
"Requirement already satisfied: preshed<3.1.0,>=3.0.2 in /home/kuba/anaconda3/envs/tau/lib/python3.8/site-packages (from spacy<2.4.0,>=2.3.0->en_core_web_sm==2.3.1) (3.0.2)\n",
|
|
"Requirement already satisfied: requests<3.0.0,>=2.13.0 in /home/kuba/anaconda3/envs/tau/lib/python3.8/site-packages (from spacy<2.4.0,>=2.3.0->en_core_web_sm==2.3.1) (2.25.0)\n",
|
|
"Requirement already satisfied: wasabi<1.1.0,>=0.4.0 in /home/kuba/anaconda3/envs/tau/lib/python3.8/site-packages (from spacy<2.4.0,>=2.3.0->en_core_web_sm==2.3.1) (0.8.0)\n",
|
|
"Requirement already satisfied: numpy>=1.15.0 in /home/kuba/anaconda3/envs/tau/lib/python3.8/site-packages (from spacy<2.4.0,>=2.3.0->en_core_web_sm==2.3.1) (1.19.2)\n",
|
|
"Requirement already satisfied: murmurhash<1.1.0,>=0.28.0 in /home/kuba/anaconda3/envs/tau/lib/python3.8/site-packages (from spacy<2.4.0,>=2.3.0->en_core_web_sm==2.3.1) (1.0.5)\n",
|
|
"Requirement already satisfied: srsly<1.1.0,>=1.0.2 in /home/kuba/anaconda3/envs/tau/lib/python3.8/site-packages (from spacy<2.4.0,>=2.3.0->en_core_web_sm==2.3.1) (1.0.5)\n",
|
|
"Requirement already satisfied: catalogue<1.1.0,>=0.0.7 in /home/kuba/anaconda3/envs/tau/lib/python3.8/site-packages (from spacy<2.4.0,>=2.3.0->en_core_web_sm==2.3.1) (1.0.0)\n",
|
|
"Requirement already satisfied: tqdm<5.0.0,>=4.38.0 in /home/kuba/anaconda3/envs/tau/lib/python3.8/site-packages (from spacy<2.4.0,>=2.3.0->en_core_web_sm==2.3.1) (4.54.1)\n",
|
|
"Requirement already satisfied: plac<1.2.0,>=0.9.6 in /home/kuba/anaconda3/envs/tau/lib/python3.8/site-packages (from spacy<2.4.0,>=2.3.0->en_core_web_sm==2.3.1) (0.9.6)\n",
|
|
"Requirement already satisfied: blis<0.5.0,>=0.4.0 in /home/kuba/anaconda3/envs/tau/lib/python3.8/site-packages (from spacy<2.4.0,>=2.3.0->en_core_web_sm==2.3.1) (0.4.1)\n",
|
|
"Requirement already satisfied: setuptools in /home/kuba/anaconda3/envs/tau/lib/python3.8/site-packages (from spacy<2.4.0,>=2.3.0->en_core_web_sm==2.3.1) (50.3.1.post20201107)\n",
|
|
"Requirement already satisfied: thinc==7.4.1 in /home/kuba/anaconda3/envs/tau/lib/python3.8/site-packages (from spacy<2.4.0,>=2.3.0->en_core_web_sm==2.3.1) (7.4.1)\n",
|
|
"Requirement already satisfied: certifi>=2017.4.17 in /home/kuba/anaconda3/envs/tau/lib/python3.8/site-packages (from requests<3.0.0,>=2.13.0->spacy<2.4.0,>=2.3.0->en_core_web_sm==2.3.1) (2020.12.5)\n",
|
|
"Requirement already satisfied: chardet<4,>=3.0.2 in /home/kuba/anaconda3/envs/tau/lib/python3.8/site-packages (from requests<3.0.0,>=2.13.0->spacy<2.4.0,>=2.3.0->en_core_web_sm==2.3.1) (3.0.4)\n",
|
|
"Requirement already satisfied: idna<3,>=2.5 in /home/kuba/anaconda3/envs/tau/lib/python3.8/site-packages (from requests<3.0.0,>=2.13.0->spacy<2.4.0,>=2.3.0->en_core_web_sm==2.3.1) (2.10)\n",
|
|
"Requirement already satisfied: urllib3<1.27,>=1.21.1 in /home/kuba/anaconda3/envs/tau/lib/python3.8/site-packages (from requests<3.0.0,>=2.13.0->spacy<2.4.0,>=2.3.0->en_core_web_sm==2.3.1) (1.25.11)\n",
|
|
"\u001b[38;5;2m✔ Download and installation successful\u001b[0m\n",
|
|
"You can now load the model via spacy.load('en_core_web_sm')\n",
|
|
"\u001b[38;5;2m✔ Linking successful\u001b[0m\n",
|
|
"/home/kuba/anaconda3/envs/tau/lib/python3.8/site-packages/en_core_web_sm -->\n",
|
|
"/home/kuba/anaconda3/envs/tau/lib/python3.8/site-packages/spacy/data/en\n",
|
|
"You can now load the model via spacy.load('en')\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# Install into the environment of the running kernel: a bare `!conda` / `!python` is resolved\n",
"# through the shell PATH and may point at a different environment than this notebook uses.\n",
"import sys\n",
"\n",
"%conda install torchtext -c pytorch -y\n",
"%conda install spacy -y\n",
"!{sys.executable} -m spacy download en"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"metadata": {
|
|
"scrolled": true
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"/home/kuba/anaconda3/envs/tau/lib/python3.8/site-packages/torchtext/data/field.py:150: UserWarning: Field class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
|
|
" warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n",
|
|
"/home/kuba/anaconda3/envs/tau/lib/python3.8/site-packages/torchtext/data/field.py:150: UserWarning: LabelField class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
|
|
" warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"import torch\n",
|
|
"from torchtext import data\n",
|
|
"from torchtext import datasets\n",
|
|
"\n",
|
|
"# Fix the PyTorch seed and request deterministic cuDNN kernels for repeatable runs.\n",
"SEED = 1234\n",
"torch.manual_seed(SEED)\n",
"torch.backends.cudnn.deterministic = True\n",
"\n",
"# TEXT tokenizes reviews with spaCy; LABEL produces float targets.\n",
"TEXT = data.Field(tokenize='spacy')\n",
"LABEL = data.LabelField(dtype=torch.float)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 5,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"/home/kuba/anaconda3/envs/tau/lib/python3.8/site-packages/torchtext/data/example.py:78: UserWarning: Example class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
|
|
" warnings.warn('Example class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.', UserWarning)\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"import random\n",
|
|
"\n",
|
|
"# Load the IMDB train/test splits using the fields defined above.\n",
"train_data, test_data = datasets.IMDB.splits(TEXT, LABEL)\n",
"\n",
"# random.seed() returns None, so seed the global RNG first and then pass its\n",
"# state explicitly; the train/validation split is driven by that state.\n",
"random.seed(SEED)\n",
"train_data, valid_data = train_data.split(random_state=random.getstate())"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 6,
|
|
"metadata": {
|
|
"scrolled": true
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Number of training examples: 17500\n",
|
|
"Number of validation examples: 7500\n",
|
|
"Number of testing examples: 25000\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# Report how many reviews ended up in each of the three splits.\n",
"print(f'Number of training examples: {len(train_data)}')\n",
|
|
"print(f'Number of validation examples: {len(valid_data)}')\n",
|
|
"print(f'Number of testing examples: {len(test_data)}')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 7,
|
|
"metadata": {
|
|
"scrolled": true
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"{'text': ['Why', 'do', 'people', 'who', 'do', 'not', 'know', 'what', 'a', 'particular', 'time', 'in', 'the', 'past', 'was', 'like', 'feel', 'the', 'need', 'to', 'try', 'to', 'define', 'that', 'time', 'for', 'others', '?', 'Replace', 'Woodstock', 'with', 'the', 'Civil', 'War', 'and', 'the', 'Apollo', 'moon', '-', 'landing', 'with', 'the', 'Titanic', 'sinking', 'and', 'you', \"'ve\", 'got', 'as', 'realistic', 'a', 'flick', 'as', 'this', 'formulaic', 'soap', 'opera', 'populated', 'entirely', 'by', 'low', '-', 'life', 'trash', '.', 'Is', 'this', 'what', 'kids', 'who', 'were', 'too', 'young', 'to', 'be', 'allowed', 'to', 'go', 'to', 'Woodstock', 'and', 'who', 'failed', 'grade', 'school', 'composition', 'do', '?', '\"', 'I', \"'ll\", 'show', 'those', 'old', 'meanies', ',', 'I', \"'ll\", 'put', 'out', 'my', 'own', 'movie', 'and', 'prove', 'that', 'you', 'do', \"n't\", 'have', 'to', 'know', 'nuttin', 'about', 'your', 'topic', 'to', 'still', 'make', 'money', '!', '\"', 'Yeah', ',', 'we', 'already', 'know', 'that', '.', 'The', 'one', 'thing', 'watching', 'this', 'film', 'did', 'for', 'me', 'was', 'to', 'give', 'me', 'a', 'little', 'insight', 'into', 'underclass', 'thinking', '.', 'The', 'next', 'time', 'I', 'see', 'a', 'slut', 'in', 'a', 'bar', 'who', 'looks', 'like', 'Diane', 'Lane', ',', 'I', \"'m\", 'running', 'the', 'other', 'way', '.', 'It', \"'s\", 'child', 'abuse', 'to', 'let', 'parents', 'that', 'worthless', 'raise', 'kids', '.', 'It', \"'s\", 'audience', 'abuse', 'to', 'simply', 'stick', 'Woodstock', 'and', 'the', 'moonlanding', 'into', 'a', 'flick', 'as', 'if', 'that', 'ipso', 'facto', 'means', 'the', 'film', 'portrays', '1969', '.'], 'label': 'neg'}\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"print(vars(train_data.examples[0]))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 8,
|
|
"metadata": {
|
|
"scrolled": true
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Keep only the most frequent tokens of the training split in the text vocabulary.\n",
"MAX_VOCAB_SIZE = 25_000\n",
"\n",
"TEXT.build_vocab(train_data, max_size=MAX_VOCAB_SIZE)\n",
"LABEL.build_vocab(train_data)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 9,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Unique tokens in TEXT vocabulary: 25002\n",
|
|
"Unique tokens in LABEL vocabulary: 2\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"# Sizes of the vocabularies built from the training split.\n",
"print(f\"Unique tokens in TEXT vocabulary: {len(TEXT.vocab)}\")\n",
|
|
"print(f\"Unique tokens in LABEL vocabulary: {len(LABEL.vocab)}\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 10,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"[('the', 203172), (',', 192039), ('.', 165889), ('a', 109265), ('and', 109192), ('of', 100241), ('to', 93511), ('is', 76322), ('in', 61299), ('I', 54013), ('it', 53609), ('that', 48928), ('\"', 44101), (\"'s\", 43213), ('this', 42383), ('-', 36691), ('/><br', 35471), ('was', 34989), ('as', 30252), ('with', 30012)]\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"print(TEXT.vocab.freqs.most_common(20))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 11,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"['<unk>', '<pad>', 'the', ',', '.', 'a', 'and', 'of', 'to', 'is']\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"print(TEXT.vocab.itos[:10])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 12,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"defaultdict(None, {'neg': 0, 'pos': 1})\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"print(LABEL.vocab.stoi)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 13,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"/home/kuba/anaconda3/envs/tau/lib/python3.8/site-packages/torchtext/data/iterator.py:48: UserWarning: BucketIterator class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
|
|
" warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"BATCH_SIZE = 64\n",
"\n",
"# Use the GPU when one is available, otherwise fall back to the CPU.\n",
"device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
"\n",
"# One iterator per split, all sharing the same batch size and target device.\n",
"dataset_splits = (train_data, valid_data, test_data)\n",
"train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(\n",
"    dataset_splits,\n",
"    batch_size=BATCH_SIZE,\n",
"    device=device,\n",
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 14,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import torch.nn as nn\n",
|
|
"import torch.nn.functional as F\n",
|
|
"\n",
|
|
"class FastText(nn.Module):\n",
|
|
" def __init__(self, vocab_size, embedding_dim, output_dim, pad_idx):\n",
|
|
" \n",
|
|
" super().__init__()\n",
|
|
" \n",
|
|
" self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_idx)\n",
|
|
" \n",
|
|
" self.fc = nn.Linear(embedding_dim, output_dim)\n",
|
|
" \n",
|
|
" def forward(self, text):\n",
|
|
" \n",
|
|
" #text = [sent len, batch size]\n",
|
|
" \n",
|
|
" embedded = self.embedding(text)\n",
|
|
" \n",
|
|
" #embedded = [sent len, batch size, emb dim]\n",
|
|
" \n",
|
|
" embedded = embedded.permute(1, 0, 2)\n",
|
|
" \n",
|
|
" #embedded = [batch size, sent len, emb dim]\n",
|
|
" \n",
|
|
" pooled = F.avg_pool2d(embedded, (embedded.shape[1], 1)).squeeze(1) \n",
|
|
" \n",
|
|
" #pooled = [batch size, embedding_dim]\n",
|
|
" \n",
|
|
" return torch.sigmoid(self.fc(pooled))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 15,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Model hyper-parameters; the input size and padding index come from the text vocabulary.\n",
"INPUT_DIM = len(TEXT.vocab)\n",
"EMBEDDING_DIM = 100\n",
"OUTPUT_DIM = 1\n",
"PAD_IDX = TEXT.vocab.stoi[TEXT.pad_token]\n",
"\n",
"model = FastText(\n",
"    vocab_size=INPUT_DIM,\n",
"    embedding_dim=EMBEDDING_DIM,\n",
"    output_dim=OUTPUT_DIM,\n",
"    pad_idx=PAD_IDX,\n",
")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 16,
|
|
"metadata": {
|
|
"scrolled": true
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"The model has 2,500,301 trainable parameters\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"def count_parameters(model):\n",
|
|
" return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
|
|
"\n",
|
|
"print(f'The model has {count_parameters(model):,} trainable parameters')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 17,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"UNK_IDX = TEXT.vocab.stoi[TEXT.unk_token]\n",
"\n",
"# Start the <unk> and <pad> rows of the embedding table from all-zero vectors.\n",
"for special_idx in (UNK_IDX, PAD_IDX):\n",
"    model.embedding.weight.data[special_idx] = torch.zeros(EMBEDDING_DIM)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 18,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import torch.optim as optim\n",
|
|
"\n",
|
|
"optimizer = optim.Adam(model.parameters())"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 19,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# The model already applies a sigmoid, so the loss is plain binary cross-entropy.\n",
"# Both the loss module and the model are moved to the selected device.\n",
"criterion = nn.BCELoss().to(device)\n",
"model = model.to(device)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 20,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def binary_accuracy(preds, y):\n",
|
|
" \"\"\"\n",
|
|
" Returns accuracy per batch, i.e. if you get 8/10 right, this returns 0.8, NOT 8\n",
|
|
" \"\"\"\n",
|
|
"\n",
|
|
" #round predictions to the closest integer\n",
|
|
" rounded_preds = torch.round(preds)\n",
|
|
" correct = (rounded_preds == y).float() #convert into float for division \n",
|
|
" acc = correct.sum() / len(correct)\n",
|
|
" return acc"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 21,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def train(model, iterator, optimizer, criterion):\n",
"    \"\"\"Run one training pass over `iterator` and update the model in place.\n",
"\n",
"    Args:\n",
"        model: module called as model(batch.text).\n",
"        iterator: iterable of batches exposing `.text` and `.label`; must support len().\n",
"        optimizer: optimizer stepping the model's parameters.\n",
"        criterion: loss called as criterion(predictions, batch.label).\n",
"\n",
"    Returns:\n",
"        Tuple (loss, accuracy), each the unweighted mean of the per-batch values.\n",
"    \"\"\"\n",
"    epoch_loss = 0\n",
"    epoch_acc = 0\n",
"\n",
"    model.train()\n",
"\n",
"    for batch in iterator:\n",
"        optimizer.zero_grad()\n",
"\n",
"        # Drop axis 1 of the model output so it lines up with batch.label.\n",
"        predictions = model(batch.text).squeeze(1)\n",
"\n",
"        loss = criterion(predictions, batch.label)\n",
"        acc = binary_accuracy(predictions, batch.label)\n",
"\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"\n",
"        epoch_loss += loss.item()\n",
"        epoch_acc += acc.item()\n",
"\n",
"    return epoch_loss / len(iterator), epoch_acc / len(iterator)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 22,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"def evaluate(model, iterator, criterion):\n",
"    \"\"\"Score the model on `iterator` without updating any parameters.\n",
"\n",
"    Args:\n",
"        model: module called as model(batch.text).\n",
"        iterator: iterable of batches exposing `.text` and `.label`; must support len().\n",
"        criterion: loss called as criterion(predictions, batch.label).\n",
"\n",
"    Returns:\n",
"        Tuple (loss, accuracy), each the unweighted mean of the per-batch values.\n",
"    \"\"\"\n",
"    epoch_loss = 0\n",
"    epoch_acc = 0\n",
"\n",
"    model.eval()\n",
"\n",
"    # No gradients are needed for evaluation.\n",
"    with torch.no_grad():\n",
"        for batch in iterator:\n",
"            # Drop axis 1 of the model output so it lines up with batch.label.\n",
"            predictions = model(batch.text).squeeze(1)\n",
"\n",
"            loss = criterion(predictions, batch.label)\n",
"            acc = binary_accuracy(predictions, batch.label)\n",
"\n",
"            epoch_loss += loss.item()\n",
"            epoch_acc += acc.item()\n",
"\n",
"    return epoch_loss / len(iterator), epoch_acc / len(iterator)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 23,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import time\n",
|
|
"\n",
|
|
"def epoch_time(start_time, end_time):\n",
|
|
" elapsed_time = end_time - start_time\n",
|
|
" elapsed_mins = int(elapsed_time / 60)\n",
|
|
" elapsed_secs = int(elapsed_time - (elapsed_mins * 60))\n",
|
|
" return elapsed_mins, elapsed_secs"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {
|
|
"scrolled": true
|
|
},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"/home/kuba/anaconda3/envs/tau/lib/python3.8/site-packages/torchtext/data/batch.py:23: UserWarning: Batch class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
|
|
" warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n"
|
|
]
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Epoch: 01 | Epoch Time: 0m 40s\n",
|
|
"\tTrain Loss: 0.686 | Train Acc: 59.60%\n",
|
|
"\t Val. Loss: 0.630 | Val. Acc: 67.82%\n",
|
|
"Epoch: 02 | Epoch Time: 0m 37s\n",
|
|
"\tTrain Loss: 0.639 | Train Acc: 74.52%\n",
|
|
"\t Val. Loss: 0.502 | Val. Acc: 75.98%\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
"N_EPOCHS = 3\n",
"\n",
"# Lowest validation loss seen so far; decides when a checkpoint is written.\n",
"best_valid_loss = float('inf')\n",
"\n",
"for epoch in range(N_EPOCHS):\n",
"\n",
"    start_time = time.time()\n",
"\n",
"    train_loss, train_acc = train(model, train_iterator, optimizer, criterion)\n",
"    valid_loss, valid_acc = evaluate(model, valid_iterator, criterion)\n",
"\n",
"    end_time = time.time()\n",
"\n",
"    epoch_mins, epoch_secs = epoch_time(start_time, end_time)\n",
"\n",
"    # Save the weights only when the validation loss improves.\n",
"    if valid_loss < best_valid_loss:\n",
"        best_valid_loss = valid_loss\n",
"        torch.save(model.state_dict(), 'model.pt')\n",
"\n",
"    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')\n",
"    print(f'\\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')\n",
"    print(f'\\t Val. Loss: {valid_loss:.3f} | Val. Acc: {valid_acc*100:.2f}%')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Restore the checkpoint written by the training loop. map_location keeps the load\n",
"# working when the file was saved on a different device than the current one.\n",
"model.load_state_dict(torch.load('model.pt', map_location=device))\n",
"\n",
"test_loss, test_acc = evaluate(model, test_iterator, criterion)\n",
"\n",
"print(f'Test Loss: {test_loss:.3f} | Test Acc: {test_acc*100:.2f}%')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# User Input"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import spacy\n",
|
|
"nlp = spacy.load('en')\n",
|
|
"\n",
|
|
"def predict_sentiment(model, sentence):\n",
"    \"\"\"Return the model's output for one raw sentence as a Python float.\n",
"\n",
"    The sentence is tokenized with the spaCy tokenizer, mapped to indices with\n",
"    TEXT.vocab and fed to the model as a single-example batch.\n",
"\n",
"    Args:\n",
"        model: module mapping a [sent len, 1] index tensor to a single value.\n",
"        sentence: text to score.\n",
"\n",
"    Returns:\n",
"        The scalar model output; with the sigmoid head and the label mapping\n",
"        {'neg': 0, 'pos': 1}, values near 1 mean positive.\n",
"\n",
"    Raises:\n",
"        ValueError: if the sentence yields no tokens.\n",
"    \"\"\"\n",
"    model.eval()\n",
"    tokenized = [tok.text for tok in nlp.tokenizer(sentence)]\n",
"    # An empty token list would give a zero-length sentence axis and fail inside the pooling.\n",
"    if not tokenized:\n",
"        raise ValueError('sentence must contain at least one token')\n",
"    indexed = [TEXT.vocab.stoi[t] for t in tokenized]\n",
"    tensor = torch.LongTensor(indexed).to(device)\n",
"    # Add the batch axis: [sent len] -> [sent len, 1].\n",
"    tensor = tensor.unsqueeze(1)\n",
"    # Inference only, so skip autograd bookkeeping.\n",
"    with torch.no_grad():\n",
"        prediction = model(tensor)\n",
"    return prediction.item()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"An example negative review..."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"predict_sentiment(model, \"This film is terrible\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"An example positive review..."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"predict_sentiment(model, \"This film is great\")"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.8.0"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|