# This notebook is based on
# https://github.com/bentrevett/pytorch-sentiment-analysis/blob/master/3%20-%20Faster%20Sentiment%20Analysis.ipynb
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))
!conda install torchtext -c pytorch -y
!conda install spacy -y
!python -m spacy download en
Collecting package metadata (current_repodata.json): done
Solving environment: done
# All requested packages already installed.
Requirement already satisfied: en_core_web_sm==2.3.1 (and its spacy<2.4.0,>=2.3.0 dependencies)
✔ Download and installation successful
You can now load the model via spacy.load('en_core_web_sm')
✔ Linking successful
/home/kuba/anaconda3/envs/tau/lib/python3.8/site-packages/en_core_web_sm --> /home/kuba/anaconda3/envs/tau/lib/python3.8/site-packages/spacy/data/en
You can now load the model via spacy.load('en')
import torch
from torchtext import data
from torchtext import datasets
SEED = 1234
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
TEXT = data.Field(tokenize = 'spacy')
LABEL = data.LabelField(dtype = torch.float)
UserWarning: Field class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.
UserWarning: LabelField class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.
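These warnings come from torchtext's API migration. On torchtext 0.9 and later, Field, LabelField, and the bundled datasets live under torchtext.legacy; a minimal sketch of the adjusted imports, assuming such a newer version (this notebook itself ran on the pre-0.9 form above):
# Hedged sketch: imports for torchtext >= 0.9, where the Field API moved.
# Not needed for the torchtext version this notebook actually ran on.
from torchtext.legacy import data
from torchtext.legacy import datasets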
import random
train_data, test_data = datasets.IMDB.splits(TEXT, LABEL)
train_data, valid_data = train_data.split(random_state = random.seed(SEED))
UserWarning: Example class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.
print(f'Number of training examples: {len(train_data)}')
print(f'Number of validation examples: {len(valid_data)}')
print(f'Number of testing examples: {len(test_data)}')
Number of training examples: 17500
Number of validation examples: 7500
Number of testing examples: 25000
print(vars(train_data.examples[0]))
{'text': ['Why', 'do', 'people', 'who', 'do', 'not', 'know', 'what', 'a', 'particular', 'time', 'in', 'the', 'past', 'was', 'like', 'feel', 'the', 'need', 'to', 'try', 'to', 'define', 'that', 'time', 'for', 'others', '?', 'Replace', 'Woodstock', 'with', 'the', 'Civil', 'War', 'and', 'the', 'Apollo', 'moon', '-', 'landing', 'with', 'the', 'Titanic', 'sinking', 'and', 'you', "'ve", 'got', 'as', 'realistic', 'a', 'flick', 'as', 'this', 'formulaic', 'soap', 'opera', 'populated', 'entirely', 'by', 'low', '-', 'life', 'trash', '.', 'Is', 'this', 'what', 'kids', 'who', 'were', 'too', 'young', 'to', 'be', 'allowed', 'to', 'go', 'to', 'Woodstock', 'and', 'who', 'failed', 'grade', 'school', 'composition', 'do', '?', '"', 'I', "'ll", 'show', 'those', 'old', 'meanies', ',', 'I', "'ll", 'put', 'out', 'my', 'own', 'movie', 'and', 'prove', 'that', 'you', 'do', "n't", 'have', 'to', 'know', 'nuttin', 'about', 'your', 'topic', 'to', 'still', 'make', 'money', '!', '"', 'Yeah', ',', 'we', 'already', 'know', 'that', '.', 'The', 'one', 'thing', 'watching', 'this', 'film', 'did', 'for', 'me', 'was', 'to', 'give', 'me', 'a', 'little', 'insight', 'into', 'underclass', 'thinking', '.', 'The', 'next', 'time', 'I', 'see', 'a', 'slut', 'in', 'a', 'bar', 'who', 'looks', 'like', 'Diane', 'Lane', ',', 'I', "'m", 'running', 'the', 'other', 'way', '.', 'It', "'s", 'child', 'abuse', 'to', 'let', 'parents', 'that', 'worthless', 'raise', 'kids', '.', 'It', "'s", 'audience', 'abuse', 'to', 'simply', 'stick', 'Woodstock', 'and', 'the', 'moonlanding', 'into', 'a', 'flick', 'as', 'if', 'that', 'ipso', 'facto', 'means', 'the', 'film', 'portrays', '1969', '.'], 'label': 'neg'}
MAX_VOCAB_SIZE = 25_000
TEXT.build_vocab(train_data, max_size = MAX_VOCAB_SIZE)
LABEL.build_vocab(train_data)
print(f"Unique tokens in TEXT vocabulary: {len(TEXT.vocab)}")
print(f"Unique tokens in LABEL vocabulary: {len(LABEL.vocab)}")
Unique tokens in TEXT vocabulary: 25002
Unique tokens in LABEL vocabulary: 2
print(TEXT.vocab.freqs.most_common(20))
[('the', 203172), (',', 192039), ('.', 165889), ('a', 109265), ('and', 109192), ('of', 100241), ('to', 93511), ('is', 76322), ('in', 61299), ('I', 54013), ('it', 53609), ('that', 48928), ('"', 44101), ("'s", 43213), ('this', 42383), ('-', 36691), ('/><br', 35471), ('was', 34989), ('as', 30252), ('with', 30012)]
print(TEXT.vocab.itos[:10])
['<unk>', '<pad>', 'the', ',', '.', 'a', 'and', 'of', 'to', 'is']
print(LABEL.vocab.stoi)
defaultdict(None, {'neg': 0, 'pos': 1})
BATCH_SIZE = 64
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(
    (train_data, valid_data, test_data),
    batch_size = BATCH_SIZE,
    device = device)
UserWarning: BucketIterator class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.
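A quick sanity check on what the iterators yield (a sketch, assuming the iterators built above): each batch carries a text tensor shaped [sent len, batch size] and a label vector shaped [batch size].
# Hedged sketch: inspect one batch from the training iterator.
batch = next(iter(train_iterator))
print(batch.text.shape)   # torch.Size([sent len, 64]) - tokens first, batch second
print(batch.label.shape)  # torch.Size([64])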
import torch.nn as nn
import torch.nn.functional as F
class FastText(nn.Module):
    def __init__(self, vocab_size, embedding_dim, output_dim, pad_idx):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=pad_idx)
        self.fc = nn.Linear(embedding_dim, output_dim)

    def forward(self, text):
        #text = [sent len, batch size]
        embedded = self.embedding(text)
        #embedded = [sent len, batch size, emb dim]
        embedded = embedded.permute(1, 0, 2)
        #embedded = [batch size, sent len, emb dim]
        pooled = F.avg_pool2d(embedded, (embedded.shape[1], 1)).squeeze(1)
        #pooled = [batch size, emb dim]
        return torch.sigmoid(self.fc(pooled))
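The avg_pool2d call averages the embeddings over the token dimension. The same pooling can be written as a plain mean, which makes the intent clearer; a self-contained sketch showing the equivalence on random data:
# Hedged sketch: avg_pool2d with a (sent_len, 1) window over a
# [batch, sent len, emb dim] tensor is just a mean over the token axis.
import torch
import torch.nn.functional as F

embedded = torch.randn(64, 30, 100)                   # [batch, sent len, emb dim]
pooled_a = F.avg_pool2d(embedded, (embedded.shape[1], 1)).squeeze(1)
pooled_b = embedded.mean(dim=1)                       # same result, plainer intent
print(torch.allclose(pooled_a, pooled_b, atol=1e-6))  # True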
INPUT_DIM = len(TEXT.vocab)
EMBEDDING_DIM = 100
OUTPUT_DIM = 1
PAD_IDX = TEXT.vocab.stoi[TEXT.pad_token]
model = FastText(INPUT_DIM, EMBEDDING_DIM, OUTPUT_DIM, PAD_IDX)
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'The model has {count_parameters(model):,} trainable parameters')
The model has 2,500,301 trainable parameters
UNK_IDX = TEXT.vocab.stoi[TEXT.unk_token]
model.embedding.weight.data[UNK_IDX] = torch.zeros(EMBEDDING_DIM)
model.embedding.weight.data[PAD_IDX] = torch.zeros(EMBEDDING_DIM)
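Zeroing these rows makes <unk> and <pad> contribute nothing to the pooled average (padding_idx already keeps the <pad> row at zero and excludes it from gradient updates). A quick check, as a sketch:
# Hedged sketch: confirm the <unk> and <pad> embedding rows are zero.
assert torch.equal(model.embedding.weight.data[UNK_IDX], torch.zeros(EMBEDDING_DIM))
assert torch.equal(model.embedding.weight.data[PAD_IDX], torch.zeros(EMBEDDING_DIM))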
import torch.optim as optim
optimizer = optim.Adam(model.parameters())
criterion = nn.BCELoss()
model = model.to(device)
criterion = criterion.to(device)
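Since forward already applies a sigmoid, nn.BCELoss is the matching criterion here. A numerically more stable alternative, shown only as a sketch and not what this notebook uses, is to return raw logits from forward and pair them with nn.BCEWithLogitsLoss:
# Hedged sketch: alternative setup, NOT the one used above.
# forward() would return self.fc(pooled) without the sigmoid, and the
# loss would fold the sigmoid into one numerically stabler operation:
criterion_alt = nn.BCEWithLogitsLoss()
# predict_sentiment would then need to apply torch.sigmoid(model(tensor)) itself.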
def binary_accuracy(preds, y):
    """
    Returns accuracy per batch, i.e. if you get 8/10 right, this returns 0.8, NOT 8
    """
    #round predictions to the closest integer
    rounded_preds = torch.round(preds)
    correct = (rounded_preds == y).float() #convert into float for division
    acc = correct.sum() / len(correct)
    return acc
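A tiny worked example of the rounding behaviour (values invented for illustration): 0.9 and 0.6 round to 1 while 0.4 rounds to 0, so two of the three predictions below match their labels.
# Hedged sketch: binary_accuracy on a hand-made batch.
preds = torch.tensor([0.9, 0.4, 0.6])  # rounds to [1., 0., 1.]
y = torch.tensor([1.0, 0.0, 0.0])
print(binary_accuracy(preds, y))       # tensor(0.6667) - 2 of 3 correct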
def train(model, iterator, optimizer, criterion):
    epoch_loss = 0
    epoch_acc = 0
    model.train()
    for batch in iterator:
        optimizer.zero_grad()
        predictions = model(batch.text).squeeze(1)
        loss = criterion(predictions, batch.label)
        acc = binary_accuracy(predictions, batch.label)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        epoch_acc += acc.item()
    return epoch_loss / len(iterator), epoch_acc / len(iterator)
def evaluate(model, iterator, criterion):
    epoch_loss = 0
    epoch_acc = 0
    model.eval()
    with torch.no_grad():
        for batch in iterator:
            predictions = model(batch.text).squeeze(1)
            loss = criterion(predictions, batch.label)
            acc = binary_accuracy(predictions, batch.label)
            epoch_loss += loss.item()
            epoch_acc += acc.item()
    return epoch_loss / len(iterator), epoch_acc / len(iterator)
import time
def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs
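For example (a sketch), 125 elapsed seconds split into 2 minutes and 5 seconds:
# Hedged sketch: epoch_time splits elapsed seconds into (mins, secs).
print(epoch_time(0.0, 125.0))  # (2, 5)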
N_EPOCHS = 3
best_valid_loss = float('inf')
for epoch in range(N_EPOCHS):
    start_time = time.time()
    train_loss, train_acc = train(model, train_iterator, optimizer, criterion)
    valid_loss, valid_acc = evaluate(model, valid_iterator, criterion)
    end_time = time.time()
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'model.pt')
    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')
    print(f'\t Val. Loss: {valid_loss:.3f} | Val. Acc: {valid_acc*100:.2f}%')
UserWarning: Batch class will be retired soon and moved to torchtext.legacy. Please see the most recent release notes for further information.
Epoch: 01 | Epoch Time: 0m 40s
	Train Loss: 0.686 | Train Acc: 59.60%
	 Val. Loss: 0.630 | Val. Acc: 67.82%
Epoch: 02 | Epoch Time: 0m 37s
	Train Loss: 0.639 | Train Acc: 74.52%
	 Val. Loss: 0.502 | Val. Acc: 75.98%
model.load_state_dict(torch.load('model.pt'))
test_loss, test_acc = evaluate(model, test_iterator, criterion)
print(f'Test Loss: {test_loss:.3f} | Test Acc: {test_acc*100:.2f}%')
User Input
import spacy
nlp = spacy.load('en')
def predict_sentiment(model, sentence):
    model.eval()
    tokenized = [tok.text for tok in nlp.tokenizer(sentence)]
    indexed = [TEXT.vocab.stoi[t] for t in tokenized]
    tensor = torch.LongTensor(indexed).to(device)
    tensor = tensor.unsqueeze(1)
    prediction = model(tensor)
    return prediction.item()
An example negative review...
predict_sentiment(model, "This film is terrible")
An example positive review...
predict_sentiment(model, "This film is great")