from config import *

import os
import pickle

import torch
from torch.optim import Adam
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import RobertaModel, get_scheduler

from classification_head import YearClassificationHead
# from regressor_head import RegressorHead  # only needed for the MSE regression variant below

os.makedirs('roberta_year_prediction', exist_ok=True)


def pickle_model_save(name):
    with open(f'roberta_year_prediction/{name}', 'wb') as f:
        pickle.dump(model, f)


if TEST:
    STEPS_EVAL = 10
    WARMUP_STEPS = 10

with open('train_dataset.pickle', 'rb') as f_p:
    train_dataset = pickle.load(f_p)
with open('eval_dataset_small.pickle', 'rb') as f_p:
    eval_dataset_small = pickle.load(f_p)
with open('eval_dataset_full.pickle', 'rb') as f_p:
    eval_dataset_full = pickle.load(f_p)

train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=BATCH_SIZE)
eval_dataloader_small = DataLoader(eval_dataset_small, batch_size=20)
eval_dataloader_full = DataLoader(eval_dataset_full, batch_size=BATCH_SIZE)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

model = RobertaModel.from_pretrained('roberta-base')
# model = RobertaModel(model.config)  # uncomment to train from a randomly initialised RoBERTa instead

# Attach the year prediction head as a submodule so model.to(device) and
# model.parameters() cover it as well. The attribute keeps its historical name
# even though the head is a classifier over years.
model.regressor_head = YearClassificationHead(768, MIN_YEAR, MAX_YEAR)
model.to(device)

optimizer = Adam(model.parameters(), lr=LR)
num_training_steps = NUM_EPOCHS * len(train_dataloader)
# lr_scheduler = get_scheduler(
#     "linear",
#     optimizer=optimizer,
#     num_warmup_steps=WARMUP_STEPS,
#     num_training_steps=num_training_steps,
# )
# progress_bar = tqdm(range(num_training_steps))

model.train()


def transform_batch(batch):
    # The collated batch holds lists of per-position tensors; stack and transpose
    # them back into (batch_size, seq_len) tensors and move them to the device.
    batch['input_ids'] = torch.stack(batch['input_ids']).permute(1, 0).to(device)
    batch['attention_mask'] = torch.stack(batch['attention_mask']).permute(1, 0).to(device)
    labels = batch['year'].to(device)
    # Drop every column the model does not accept as a keyword argument.
    for c in set(batch.keys()) - {'input_ids', 'attention_mask'}:
        del batch[c]
    return batch, labels


def eval(full=False):
    model.eval()
    with torch.no_grad():
        eval_loss = 0.0
        dataloader = eval_dataloader_full if full else eval_dataloader_small
        items_passed = 0
        for i, batch in enumerate(dataloader):
            batch, labels = transform_batch(batch)
            items_passed += len(labels)
            outputs = model(**batch)[0]
            outputs = model.regressor_head(outputs)
            loss = criterion(outputs, labels)
            eval_loss += loss.item()
        eval_loss = eval_loss / items_passed
        print(f'eval loss full={full}: {eval_loss:.5f}')
    model.train()
    return eval_loss


# criterion = torch.nn.MSELoss(reduction='sum').to(device)  # regression variant
criterion = torch.nn.CrossEntropyLoss(reduction='sum').to(device)

best_eval_loss = float('inf')
epochs_without_progress = 0
eval()

for epoch in range(NUM_EPOCHS):
    train_loss = 0.0
    items_passed = 0
    for i, batch in enumerate(train_dataloader):
        batch, labels = transform_batch(batch)
        items_passed += len(labels)
        outputs = model(**batch)[0]
        outputs = model.regressor_head(outputs)
        loss = criterion(outputs, labels)
        loss.backward()
        train_loss += loss.item()
        # progress_bar.update(1)
        optimizer.step()
        # lr_scheduler.step()
        optimizer.zero_grad()
        if i % STEPS_EVAL == 0 and i > 1:
            print(f'epoch {epoch} train loss: {(train_loss / items_passed):.5f}', end='\t')
            items_passed = 0
            train_loss = 0.0
            eval(full=False)
    eval_loss = eval(full=True)
    pickle_model_save(f'epoch_{epoch}')
    pickle_model_save('epoch_last')
    if eval_loss < best_eval_loss:
        pickle_model_save('epoch_best')
        print('\nsaving best model')
        best_eval_loss = eval_loss
        epochs_without_progress = 0  # reset the patience counter on improvement
    else:
        epochs_without_progress += 1
        print(f'epochs_without_progress: {epochs_without_progress}')
        if epochs_without_progress > EARLY_STOPPING:
            print('early stopping')
            break

print(f'best_eval_loss: {best_eval_loss:.5f}')
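# Reference note (not executed): besides the dataset pickles, the script relies on
# config.py providing TEST, BATCH_SIZE, LR, NUM_EPOCHS, STEPS_EVAL, WARMUP_STEPS,
# EARLY_STOPPING, MIN_YEAR and MAX_YEAR. It also assumes that
# classification_head.YearClassificationHead(768, MIN_YEAR, MAX_YEAR) maps the
# RoBERTa output to one logit per candidate year so CrossEntropyLoss can be applied
# to the `year` labels. The real implementation lives in classification_head.py and
# may differ; a minimal sketch of such a head, assuming labels are already shifted
# into the 0..(MAX_YEAR - MIN_YEAR) range, could look like:
#
#     from torch import nn
#
#     class YearClassificationHead(nn.Module):
#         def __init__(self, hidden_size, min_year, max_year):
#             super().__init__()
#             self.proj = nn.Linear(hidden_size, max_year - min_year + 1)
#
#         def forward(self, hidden_states):
#             # Classify from the representation of the first (<s>) token.
#             return self.proj(hidden_states[:, 0, :])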