# Source listing: 99 lines, 2.7 KiB, Python
|
import pickle
|
||
|
from datasets import load_dataset
|
||
|
from transformers import AutoTokenizer, RobertaModel, RobertaTokenizer
|
||
|
from torch.utils.data import DataLoader
|
||
|
from transformers import AutoModelForSequenceClassification
|
||
|
from transformers import AdamW
|
||
|
from transformers import get_scheduler
|
||
|
import torch
|
||
|
from tqdm.auto import tqdm
|
||
|
|
||
|
# Number of examples per mini-batch, shared by the training and eval loaders.
BATCH_SIZE = 4


def _load_pickle(path):
    """Open *path* in binary mode and return the object unpickled from it.

    NOTE(review): ``pickle.load`` can execute arbitrary code while loading, so
    only point this at files produced by this project's own preprocessing.
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# Pre-tokenized datasets, presumably written by a separate preprocessing step
# (TODO confirm). eval_dataset_full is not referenced by the code shown here.
train_dataset = _load_pickle('train_dataset.pickle')
eval_dataset_small = _load_pickle('eval_dataset_small.pickle')
eval_dataset_full = _load_pickle('eval_dataset_full.pickle')

# Training order is reshuffled every epoch; evaluation order stays fixed.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)
eval_dataloader = DataLoader(eval_dataset_small, batch_size=BATCH_SIZE)
|
||
|
|
||
|
|
||
|
# Pretrained RoBERTa with a single-output head (num_labels=1). The labels fed in
# by transform_batch are float, so transformers computes a regression (MSE) loss.
model = AutoModelForSequenceClassification.from_pretrained('roberta-base', num_labels=1)

# Choose the device and move the model there *before* building the optimizer,
# as recommended by the torch.optim documentation.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)

# torch.optim.AdamW replaces the deprecated transformers.AdamW. eps and
# weight_decay are pinned to transformers.AdamW's defaults (1e-6 and 0.0) so the
# optimisation is unchanged; torch's own defaults would be 1e-8 and 1e-2.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-6, eps=1e-6, weight_decay=0.0)

num_epochs = 1
# One optimizer/scheduler step per training batch.
num_training_steps = num_epochs * len(train_dataloader)
# Linear decay of the learning rate from lr to 0 over the whole run, no warmup.
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)
|
||
|
|
||
|
|
||
|
# One progress-bar tick per optimizer step, across all epochs.
progress_bar = tqdm(range(num_training_steps))

# Enable training-mode behaviour (e.g. dropout). The model was already moved to
# `device` right after loading, so a single call here is sufficient.
model.train()
|
||
|
|
||
|
def transform_batch(batch, target_device=None):
    """Turn a DataLoader batch into keyword arguments for ``model(**batch)``.

    ``input_ids`` and ``attention_mask`` arrive as a list of tensors that are
    stacked and transposed (``permute(1, 0)``). This assumes the default collate
    produced one tensor of shape (batch,) per sequence position, giving
    (batch, seq_len) after the transpose -- TODO confirm against the pickled
    dataset format. ``year_scaled`` becomes the float ``labels`` tensor, and
    every other key is deleted.

    Args:
        batch: dict yielded by the DataLoader; it is mutated in place.
        target_device: device to place the tensors on. Defaults to the
            module-level ``device``.

    Returns:
        The same ``batch`` dict, now holding only ``input_ids``,
        ``attention_mask`` and ``labels``.
    """
    if target_device is None:
        target_device = device

    batch['input_ids'] = torch.stack(batch['input_ids']).permute(1, 0).to(target_device)
    batch['attention_mask'] = torch.stack(batch['attention_mask']).permute(1, 0).to(target_device)
    # Float labels: with the single-output head this yields a regression loss.
    batch['labels'] = batch['year_scaled'].to(target_device).float()

    # Remove everything the model's forward() should not receive (incl. year_scaled).
    for key in set(batch.keys()) - {'input_ids', 'attention_mask', 'labels'}:
        del batch[key]
    return batch
|
||
|
|
||
|
|
||
|
def eval(dataloader=None):
    """Print (and return) the mean loss of ``model`` over an evaluation loader.

    NOTE: this name shadows the builtin ``eval``; it is kept because the
    training loop calls it by this name.

    Args:
        dataloader: batches to evaluate on. Defaults to the module-level
            ``eval_dataloader``.

    Returns:
        The mean of the per-batch losses, or ``nan`` if the loader yielded no
        batches.
    """
    if dataloader is None:
        dataloader = eval_dataloader

    model.eval()
    total_loss = 0.0
    num_batches = 0
    try:
        # Evaluation needs no autograd graph; skipping it saves memory and time.
        with torch.no_grad():
            for batch in dataloader:
                outputs = model(**transform_batch(batch))
                total_loss += outputs.loss.item()
                num_batches += 1
    finally:
        # Always return to training mode, even if evaluation raised.
        model.train()

    # Divide by the number of batches (not the last index, which is one less).
    mean_loss = total_loss / num_batches if num_batches else float('nan')
    print(f'eval loss: {mean_loss}')
    return mean_loss
|
||
|
|
||
|
|
||
|
# Report the running training loss (and run an evaluation) every LOG_EVERY batches.
LOG_EVERY = 5000

for epoch in range(num_epochs):
    train_loss = 0.0
    for i, batch in enumerate(train_dataloader):
        batch = transform_batch(batch)
        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)

        train_loss += loss.item()
        # Using (i + 1) makes every logging window exactly LOG_EVERY batches
        # long, so dividing by LOG_EVERY gives the true mean for the window.
        if (i + 1) % LOG_EVERY == 0:
            print(f'train loss: {train_loss / LOG_EVERY}', end='\t\t')
            train_loss = 0.0
            eval()

model.save_pretrained('roberta_year_prediction')
|