diff --git a/sentiment_analysis_embed_ff.ipynb b/sentiment_analysis_embed_ff.ipynb
new file mode 100644
index 0000000..b9a65ce
--- /dev/null
+++ b/sentiment_analysis_embed_ff.ipynb
@@ -0,0 +1,648 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Notebook bazuje na \n",
+ "# https://github.com/bentrevett/pytorch-sentiment-analysis/blob/master/3%20-%20Faster%20Sentiment%20Analysis.ipynb"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#conda install torchtext -c pytorch\n",
+ "#conda install spacy\n",
+ "#python -m spacy download en"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/media/kuba/ssd/anaconda3/envs/tau/lib/python3.8/site-packages/torchtext/data/field.py:150: UserWarning: Field class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
+ " warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n",
+ "/media/kuba/ssd/anaconda3/envs/tau/lib/python3.8/site-packages/torchtext/data/field.py:150: UserWarning: LabelField class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
+ " warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n"
+ ]
+ }
+ ],
+ "source": [
+ "import torch\n",
+ "from torchtext import data\n",
+ "from torchtext import datasets\n",
+ "\n",
+ "SEED = 1234\n",
+ "\n",
+ "torch.manual_seed(SEED)\n",
+ "torch.backends.cudnn.deterministic = True\n",
+ "\n",
+ "TEXT = data.Field(tokenize = 'spacy')\n",
+ "LABEL = data.LabelField(dtype = torch.float)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/media/kuba/ssd/anaconda3/envs/tau/lib/python3.8/site-packages/torchtext/data/example.py:78: UserWarning: Example class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
+ " warnings.warn('Example class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.', UserWarning)\n"
+ ]
+ }
+ ],
+ "source": [
+ "import random\n",
+ "\n",
+ "train_data, test_data = datasets.IMDB.splits(TEXT, LABEL)\n",
+ "\n",
+ "train_data, valid_data = train_data.split(random_state = random.seed(SEED))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Number of training examples: 17500\n",
+ "Number of validation examples: 7500\n",
+ "Number of testing examples: 25000\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(f'Number of training examples: {len(train_data)}')\n",
+ "print(f'Number of validation examples: {len(valid_data)}')\n",
+ "print(f'Number of testing examples: {len(test_data)}')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'text': ['...', 'through', 'the', 'similarly', 'minded', 'antics', 'of', 'Eric', 'Stanze', '.', 'A', 'not', '-', 'particularly', 'talented', 'director', 'has', 'helmed', 'a', 'not', '-', 'particularly', 'good', 'movie', ',', 'yet', 'I', 'still', 'found', 'myself', 'sitting', 'through', 'it', 'to', 'the', 'closing', 'credits', ',', 'if', 'for', 'nothing', 'more', 'than', 'to', 'see', 'what', 'happens', 'next.
A', 'rapist', 'escapes', 'from', 'prison', 'and', 'calls', 'up', 'his', 'old', 'flame', '.', 'After', 'capturing', 'her', '(', 'even', 'though', 'she', 'came', 'willingly', ')', 'and', 'threatening', 'her', 'into', 'having', 'sex', '(', 'another', 'event', 'she', 'was', 'also', 'willing', 'to', 'do', ')', 'he', 'reveals', 'that', 'he', 'has', 'kidnapped', 'three', 'guys', 'who', 'wronged', 'her', 'in', 'the', 'past', '.', 'He', 'then', 'decides', 'to', 'kill', 'her', '(', 'huh', '?', ')', 'but', 'is', 'foiled', 'and', 'dies', 'instead', '.', 'The', 'girl', \"'s\", 'mind', 'snaps', '(', 'or', 'something', 'like', 'that', ')', 'and', 'she', 'takes', 'out', 'her', 'rage', 'on', 'the', 'unlucky', 'chaps', 'in', 'the', 'basement.
Alright', ',', 'the', 'writing', 'sucks', ':', 'it', \"'s\", 'long', 'winded', ',', 'loaded', 'with', 'ten', '-', 'cent', 'words', 'and', 'there', 'is', 'WAY', 'too', 'much', 'of', 'it.
The', 'acting', 'sucks', ':', 'what', 'a', 'minute', ',', 'what', 'acting', '?', '<', 'br', '/>
The', 'filming', 'sucks', ':', 'home', 'video', 'is', 'bad', 'enough', ',', 'but', '20', 'minutes', 'of', 'graveyard', 'footage', 'is', 'just', 'a', 'damn', 'insult.
And', 'the', 'budget', 'is', 'a', 'joke', ':', 'get', 'it', '...', \"'budget\", \"'\", ',', 'that', 'was', 'the', 'punchline.
And', 'yet', 'there', 'was', 'a', 'charm', 'to', 'the', 'thing', '.', 'Back', 'in', 'the', '70', \"'s\", 'these', 'kind', 'of', 'movies', 'came', 'out', 'in', 'theatres', 'with', 'actual', 'budgets', 'and', 'talent', 'attached', 'to', 'them', ',', 'not', 'in', 'this', 'day', 'and', 'age', 'though', '.', 'If', 'you', 'want', 'to', 'watch', 'this', 'kind', 'of', 'violent', ',', 'sexually', 'exploitive', 'trash', '(', 'do', \"n't\", 'lie', ',', 'some', 'of', 'us', 'do', ')', 'then', 'this', 'is', 'all', 'your', 'gon', 'na', 'get', 'nowadays.
Some', 'brief', 'hardcore', 'shots', 'in', 'a', 'sex', 'scene', ',', 'torture', 'with', 'fecal', 'material', ',', 'fun', 'with', 'axes', ',', 'anal', 'rape', 'by', 'broom', 'stick', 'and', 'a', 'lengthy', 'shot', 'of', 'the', 'crazy', 'chick', 'masturbating', 'with', 'the', 'same', 'broom', 'stick', 'are', 'some', 'of', 'the', 'better', 'items', 'on', 'the', 'menu.
It', \"'s\", 'not', 'good', 'and', 'it', 'wo', \"n't\", 'be', 'remembered', ',', 'but', 'not', 'since', 'the', 'heyday', 'of', 'Joe', \"D'amato\", 'have', 'people', 'made', 'movies', 'like', 'this.
4/10'], 'label': 'neg'}\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(vars(train_data.examples[0]))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [],
+ "source": [
+ "MAX_VOCAB_SIZE = 25_000\n",
+ "\n",
+ "TEXT.build_vocab(train_data, max_size = MAX_VOCAB_SIZE)\n",
+ "\n",
+ "LABEL.build_vocab(train_data)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Unique tokens in TEXT vocabulary: 25002\n",
+ "Unique tokens in LABEL vocabulary: 2\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(f\"Unique tokens in TEXT vocabulary: {len(TEXT.vocab)}\")\n",
+ "print(f\"Unique tokens in LABEL vocabulary: {len(LABEL.vocab)}\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[('the', 202389), (',', 192527), ('.', 165463), ('a', 109375), ('and', 109303), ('of', 100836), ('to', 93959), ('is', 76223), ('in', 61140), ('I', 54434), ('it', 53612), ('that', 49147), ('\"', 44429), (\"'s\", 43357), ('this', 42421), ('-', 37080), ('/>
', '', 'the', ',', '.', 'a', 'and', 'of', 'to', 'is']\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(TEXT.vocab.itos[:10])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "defaultdict(None, {'neg': 0, 'pos': 1})\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(LABEL.vocab.stoi)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/media/kuba/ssd/anaconda3/envs/tau/lib/python3.8/site-packages/torchtext/data/iterator.py:48: UserWarning: BucketIterator class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
+ " warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n"
+ ]
+ }
+ ],
+ "source": [
+ "BATCH_SIZE = 64\n",
+ "\n",
+ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
+ "\n",
+ "train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(\n",
+ " (train_data, valid_data, test_data), \n",
+ " batch_size = BATCH_SIZE, \n",
+ " device = device)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch.nn as nn\n",
+ "import torch.nn.functional as F\n",
+ "\n",
+ "class FastText(nn.Module):\n",
+ " def __init__(self, vocab_size, embedding_dim, output_dim, pad_idx):\n",
+ " \n",
+ " super().__init__()\n",
+ " \n",
+ " self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_idx)\n",
+ " \n",
+ " self.fc = nn.Linear(embedding_dim, output_dim)\n",
+ " \n",
+ " def forward(self, text):\n",
+ " \n",
+ " #text = [sent len, batch size]\n",
+ " \n",
+ " embedded = self.embedding(text)\n",
+ " \n",
+ " #embedded = [sent len, batch size, emb dim]\n",
+ " \n",
+ " embedded = embedded.permute(1, 0, 2)\n",
+ " \n",
+ " #embedded = [batch size, sent len, emb dim]\n",
+ " \n",
+ " pooled = F.avg_pool2d(embedded, (embedded.shape[1], 1)).squeeze(1) \n",
+ " \n",
+ " #pooled = [batch size, embedding_dim]\n",
+ " \n",
+ " return torch.sigmoid(self.fc(pooled))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "INPUT_DIM = len(TEXT.vocab)\n",
+ "EMBEDDING_DIM = 100\n",
+ "OUTPUT_DIM = 1\n",
+ "PAD_IDX = TEXT.vocab.stoi[TEXT.pad_token]\n",
+ "\n",
+ "model = FastText(INPUT_DIM, EMBEDDING_DIM, OUTPUT_DIM, PAD_IDX)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "The model has 2,500,301 trainable parameters\n"
+ ]
+ }
+ ],
+ "source": [
+ "def count_parameters(model):\n",
+ " return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
+ "\n",
+ "print(f'The model has {count_parameters(model):,} trainable parameters')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "UNK_IDX = TEXT.vocab.stoi[TEXT.unk_token]\n",
+ "\n",
+ "model.embedding.weight.data[UNK_IDX] = torch.zeros(EMBEDDING_DIM)\n",
+ "model.embedding.weight.data[PAD_IDX] = torch.zeros(EMBEDDING_DIM)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch.optim as optim\n",
+ "\n",
+ "optimizer = optim.Adam(model.parameters())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "criterion = nn.BCELoss()\n",
+ "\n",
+ "model = model.to(device)\n",
+ "criterion = criterion.to(device)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def binary_accuracy(preds, y):\n",
+ " \"\"\"\n",
+ " Returns accuracy per batch, i.e. if you get 8/10 right, this returns 0.8, NOT 8\n",
+ " \"\"\"\n",
+ "\n",
+ " #round predictions to the closest integer\n",
+ " rounded_preds = torch.round(preds)\n",
+ " correct = (rounded_preds == y).float() #convert into float for division \n",
+ " acc = correct.sum() / len(correct)\n",
+ " return acc"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def train(model, iterator, optimizer, criterion):\n",
+ " \n",
+ " epoch_loss = 0\n",
+ " epoch_acc = 0\n",
+ " \n",
+ " model.train()\n",
+ " \n",
+ " for batch in iterator:\n",
+ " \n",
+ " optimizer.zero_grad()\n",
+ " \n",
+ " predictions = model(batch.text).squeeze(1)\n",
+ " \n",
+ " loss = criterion(predictions, batch.label)\n",
+ " \n",
+ " acc = binary_accuracy(predictions, batch.label)\n",
+ " \n",
+ " loss.backward()\n",
+ " \n",
+ " optimizer.step()\n",
+ " \n",
+ " epoch_loss += loss.item()\n",
+ " epoch_acc += acc.item()\n",
+ " \n",
+ " return epoch_loss / len(iterator), epoch_acc / len(iterator)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def evaluate(model, iterator, criterion):\n",
+ " \n",
+ " epoch_loss = 0\n",
+ " epoch_acc = 0\n",
+ " \n",
+ " model.eval()\n",
+ " \n",
+ " with torch.no_grad():\n",
+ " \n",
+ " for batch in iterator:\n",
+ "\n",
+ " predictions = model(batch.text).squeeze(1)\n",
+ " \n",
+ " loss = criterion(predictions, batch.label)\n",
+ " \n",
+ " acc = binary_accuracy(predictions, batch.label)\n",
+ "\n",
+ " epoch_loss += loss.item()\n",
+ " epoch_acc += acc.item()\n",
+ " \n",
+ " return epoch_loss / len(iterator), epoch_acc / len(iterator)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import time\n",
+ "\n",
+ "def epoch_time(start_time, end_time):\n",
+ " elapsed_time = end_time - start_time\n",
+ " elapsed_mins = int(elapsed_time / 60)\n",
+ " elapsed_secs = int(elapsed_time - (elapsed_mins * 60))\n",
+ " return elapsed_mins, elapsed_secs"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/media/kuba/ssd/anaconda3/envs/tau/lib/python3.8/site-packages/torchtext/data/batch.py:23: UserWarning: Batch class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.\n",
+ " warnings.warn('{} class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.'.format(self.__class__.__name__), UserWarning)\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Epoch: 01 | Epoch Time: 0m 3s\n",
+ "\tTrain Loss: 0.685 | Train Acc: 59.98%\n",
+ "\t Val. Loss: 0.625 | Val. Acc: 68.93%\n",
+ "Epoch: 02 | Epoch Time: 0m 2s\n",
+ "\tTrain Loss: 0.639 | Train Acc: 73.45%\n",
+ "\t Val. Loss: 0.513 | Val. Acc: 75.19%\n",
+ "Epoch: 03 | Epoch Time: 0m 2s\n",
+ "\tTrain Loss: 0.560 | Train Acc: 79.02%\n",
+ "\t Val. Loss: 0.453 | Val. Acc: 80.60%\n",
+ "Epoch: 04 | Epoch Time: 0m 2s\n",
+ "\tTrain Loss: 0.482 | Train Acc: 83.41%\n",
+ "\t Val. Loss: 0.410 | Val. Acc: 84.11%\n",
+ "Epoch: 05 | Epoch Time: 0m 2s\n",
+ "\tTrain Loss: 0.420 | Train Acc: 86.42%\n",
+ "\t Val. Loss: 0.407 | Val. Acc: 86.05%\n",
+ "Epoch: 06 | Epoch Time: 0m 2s\n",
+ "\tTrain Loss: 0.372 | Train Acc: 88.33%\n",
+ "\t Val. Loss: 0.432 | Val. Acc: 87.06%\n",
+ "Epoch: 07 | Epoch Time: 0m 2s\n",
+ "\tTrain Loss: 0.333 | Train Acc: 89.47%\n",
+ "\t Val. Loss: 0.459 | Val. Acc: 87.87%\n",
+ "Epoch: 08 | Epoch Time: 0m 2s\n",
+ "\tTrain Loss: 0.303 | Train Acc: 90.54%\n",
+ "\t Val. Loss: 0.480 | Val. Acc: 88.36%\n",
+ "Epoch: 09 | Epoch Time: 0m 2s\n",
+ "\tTrain Loss: 0.276 | Train Acc: 91.32%\n",
+ "\t Val. Loss: 0.499 | Val. Acc: 88.69%\n",
+ "Epoch: 10 | Epoch Time: 0m 2s\n",
+ "\tTrain Loss: 0.258 | Train Acc: 91.96%\n",
+ "\t Val. Loss: 0.518 | Val. Acc: 88.91%\n",
+ "Epoch: 11 | Epoch Time: 0m 2s\n",
+ "\tTrain Loss: 0.239 | Train Acc: 92.55%\n",
+ "\t Val. Loss: 0.547 | Val. Acc: 89.06%\n",
+ "Epoch: 12 | Epoch Time: 0m 2s\n",
+ "\tTrain Loss: 0.224 | Train Acc: 93.07%\n",
+ "\t Val. Loss: 0.565 | Val. Acc: 89.14%\n",
+ "Epoch: 13 | Epoch Time: 0m 2s\n",
+ "\tTrain Loss: 0.209 | Train Acc: 93.58%\n",
+ "\t Val. Loss: 0.580 | Val. Acc: 89.26%\n",
+ "Epoch: 14 | Epoch Time: 0m 2s\n",
+ "\tTrain Loss: 0.198 | Train Acc: 94.03%\n",
+ "\t Val. Loss: 0.656 | Val. Acc: 89.36%\n",
+ "Epoch: 15 | Epoch Time: 0m 2s\n",
+ "\tTrain Loss: 0.183 | Train Acc: 94.53%\n",
+ "\t Val. Loss: 0.704 | Val. Acc: 89.48%\n"
+ ]
+ }
+ ],
+ "source": [
+ "N_EPOCHS = 15\n",
+ "\n",
+ "best_valid_loss = float('inf')\n",
+ "\n",
+ "for epoch in range(N_EPOCHS):\n",
+ "\n",
+ " start_time = time.time()\n",
+ " \n",
+ " train_loss, train_acc = train(model, train_iterator, optimizer, criterion)\n",
+ " valid_loss, valid_acc = evaluate(model, valid_iterator, criterion)\n",
+ " \n",
+ " end_time = time.time()\n",
+ "\n",
+ " epoch_mins, epoch_secs = epoch_time(start_time, end_time)\n",
+ " \n",
+ " if valid_loss < best_valid_loss:\n",
+ " best_valid_loss = valid_loss\n",
+ " torch.save(model.state_dict(), 'model.pt')\n",
+ " \n",
+ " print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')\n",
+ " print(f'\\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')\n",
+ " print(f'\\t Val. Loss: {valid_loss:.3f} | Val. Acc: {valid_acc*100:.2f}%')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Test Loss: 0.434 | Test Acc: 82.76%\n"
+ ]
+ }
+ ],
+ "source": [
+ "model.load_state_dict(torch.load('tut3-model.pt'))\n",
+ "\n",
+ "test_loss, test_acc = evaluate(model, test_iterator, criterion)\n",
+ "\n",
+ "print(f'Test Loss: {test_loss:.3f} | Test Acc: {test_acc*100:.2f}%')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# User Input"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import spacy\n",
+ "nlp = spacy.load('en')\n",
+ "\n",
+ "def predict_sentiment(model, sentence):\n",
+ " model.eval()\n",
+ " tokenized = [tok.text for tok in nlp.tokenizer(sentence)]\n",
+ " indexed = [TEXT.vocab.stoi[t] for t in tokenized]\n",
+ " tensor = torch.LongTensor(indexed).to(device)\n",
+ " tensor = tensor.unsqueeze(1)\n",
+ " prediction = model(tensor)\n",
+ " return prediction.item()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "An example negative review..."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 26,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "8.701013030076865e-06"
+ ]
+ },
+ "execution_count": 26,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "predict_sentiment(model, \"This film is terrible\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "An example positive review..."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 27,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "1.0"
+ ]
+ },
+ "execution_count": 27,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "predict_sentiment(model, \"This film is great\")"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.0"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}