
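# Recovered from a git web "Raw / Normal View / History" page (215 lines, 5.7 KiB);
# the per-line commit timestamps (2021-06-20 +02:00) that were interleaved with the code are not part of the program.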
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from sklearn.utils.class_weight import compute_class_weight
from torch.utils.data import DataLoader, RandomSampler, TensorDataset
from transformers import AdamW, AutoModelForSequenceClassification, BertTokenizer

from model import BERT_Arch
# --- Data -------------------------------------------------------------------
# Both TSV files are headerless. With pandas' default header=0 the first sample
# was silently consumed as column names (hence a target column named '1'),
# so read them with header=None and name/index the columns explicitly.
train_input_path = "train/in.tsv"
train_target_path = "train/expected.tsv"

train_input = pd.read_csv(train_input_path, sep="\t", header=None)
train_input.columns = ["text", "d"]
train_target = pd.read_csv(train_target_path, sep="\t", header=None)

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# Fall back to CPU so the script still runs on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

bert = AutoModelForSequenceClassification.from_pretrained('bert-base-uncased')

# Encode every text to exactly 25 token ids (pad shorter, truncate longer) so
# the per-sample lists stack into rectangular tensors below.
tokens_train = tokenizer.batch_encode_plus(
    train_input["text"].tolist(),
    max_length=25,
    padding='max_length',
    truncation=True,
)

train_seq = torch.tensor(tokens_train['input_ids'])
train_mask = torch.tensor(tokens_train['attention_mask'])
# Kept 2-D (one column per row of expected.tsv); train()/evaluate() take
# column 0 of each label row as the class index.
train_y = torch.tensor(train_target.to_numpy())

# --- Batching -----------------------------------------------------------------
batch_size = 32
train_data = TensorDataset(train_seq, train_mask, train_y)
# Random order each epoch.
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)

# --- Model, optimizer, loss -------------------------------------------------
# Every parameter of the pretrained model is frozen; presumably only layers that
# BERT_Arch adds on top are trainable (see model.py to confirm).
for param in bert.parameters():
    param.requires_grad = False
model = BERT_Arch(bert)
model = model.to(device)
optimizer = AdamW(model.parameters(), lr=1e-5)

# 'balanced' weights are inversely proportional to class frequency, countering
# label imbalance. `classes` and `y` are keyword-only since scikit-learn 1.0.
target_labels = train_target[0].to_numpy()
class_weights = compute_class_weight(
    class_weight='balanced',
    classes=np.unique(target_labels),
    y=target_labels,
)
weights = torch.tensor(class_weights, dtype=torch.float)
weights = weights.to(device)
# NLLLoss expects log-probabilities; this assumes BERT_Arch ends with a
# LogSoftmax -- TODO confirm in model.py.
cross_entropy = nn.NLLLoss(weight=weights)

# number of training epochs
epochs = 10
def train():
    """Run one training epoch over the module-level ``train_dataloader``.

    Uses the module-level ``model``, ``optimizer``, ``cross_entropy`` and
    ``device``, and updates the model's trainable parameters in place.

    Returns:
        (avg_loss, total_preds): the mean per-batch loss for the epoch, and the
        model outputs for every sample as a NumPy array, concatenated along
        axis 0 in dataloader iteration order.
    """
    # Put dropout/normalization layers into training mode.
    model.train()

    total_loss = 0
    # Per-batch prediction arrays, concatenated after the loop.
    total_preds = []

    for step, batch in enumerate(train_dataloader):
        # Progress update after every 50 batches.
        if step % 50 == 0 and not step == 0:
            print(' Batch {:>5,} of {:>5,}.'.format(step, len(train_dataloader)))

        # Move the batch to the training device.
        sent_id, mask, labels = [t.to(device) for t in batch]

        # Clear gradients left over from the previous step.
        model.zero_grad()
        preds = model(sent_id, mask)

        # Labels arrive as (batch, 1); the loss wants a 1-D vector of class indices.
        loss = cross_entropy(preds, labels[:, 0])
        total_loss += loss.item()

        loss.backward()
        # Clip the gradient norm to 1.0 to guard against exploding gradients.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        # Predictions live on the device; detach and bring them to the CPU.
        total_preds.append(preds.detach().cpu().numpy())

    avg_loss = total_loss / len(train_dataloader)
    # (no. of batches, batch size, no. of classes) -> (no. of samples, no. of classes)
    total_preds = np.concatenate(total_preds, axis=0)
    return avg_loss, total_preds
def evaluate(dataloader=None):
    """Compute the mean loss and predictions of ``model`` without updating it.

    Args:
        dataloader: batches of (input_ids, attention_mask, labels) to score.
            Defaults to the module-level ``train_dataloader``, the only loader
            this script builds. Without an explicit held-out loader, the
            reported "validation" loss is therefore measured on training data.

    Returns:
        (avg_loss, total_preds): the mean per-batch loss, and the model outputs
        for every sample as a NumPy array concatenated along axis 0.
    """
    if dataloader is None:
        dataloader = train_dataloader

    # Disable dropout for scoring, then restore whatever mode the caller had,
    # so a following training epoch is unaffected.
    was_training = model.training
    model.eval()
    try:
        total_loss = 0
        total_preds = []

        for step, batch in enumerate(dataloader):
            # Progress update every 50 batches.
            if step % 50 == 0 and not step == 0:
                print(' Batch {:>5,} of {:>5,}.'.format(step, len(dataloader)))

            # Move the batch to the model's device.
            sent_id, mask, labels = [t.to(device) for t in batch]

            # No autograd graph is needed when only scoring.
            with torch.no_grad():
                preds = model(sent_id, mask)
                # Labels arrive as (batch, 1); the loss wants 1-D class indices.
                loss = cross_entropy(preds, labels[:, 0])

            total_loss += loss.item()
            total_preds.append(preds.detach().cpu().numpy())

        avg_loss = total_loss / len(dataloader)
        # (no. of batches, batch size, no. of classes) -> (no. of samples, no. of classes)
        total_preds = np.concatenate(total_preds, axis=0)
        return avg_loss, total_preds
    finally:
        model.train(was_training)
# Lowest evaluation loss seen so far; the weights of that epoch are checkpointed.
best_valid_loss = float('inf')

# Per-epoch loss history.
train_losses, valid_losses = [], []

print("Started training!")
for epoch in range(epochs):
    print('\n Epoch {:} / {:}'.format(epoch + 1, epochs))

    train_loss, _ = train()
    valid_loss, _ = evaluate()

    # Save the best model so far.
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        # NOTE(review): the original checkpoint filename is not recoverable
        # from the source; 'saved_weights.pt' is a placeholder -- confirm.
        torch.save(model.state_dict(), 'saved_weights.pt')

    train_losses.append(train_loss)
    valid_losses.append(valid_loss)

    print(f'\nTraining Loss: {train_loss:.3f}')
    print(f'Validation Loss: {valid_loss:.3f}')

print("Finished !!!")