from config import MODEL, TEST import pickle from datasets import load_dataset from transformers import AutoTokenizer, RobertaModel, RobertaTokenizer from torch.utils.data import DataLoader from transformers import AutoModelForSequenceClassification from transformers import AdamW from transformers import get_scheduler import torch from tqdm.auto import tqdm BATCH_SIZE = 24 EARLY_STOPPING = 3 WARMUP_STEPS = 10_000 STEPS_EVAL = 5_000 if TEST: STEPS_EVAL = 100 WARMUP_STEPS = 10 with open('train_dataset.pickle', 'rb') as f_p: train_dataset = pickle.load(f_p) with open('eval_dataset_small.pickle', 'rb') as f_p: eval_dataset_small = pickle.load(f_p) with open('eval_dataset_full.pickle', 'rb') as f_p: eval_dataset_full = pickle.load(f_p) train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=BATCH_SIZE) eval_dataloader_small = DataLoader(eval_dataset_small, batch_size=BATCH_SIZE) eval_dataloader_full = DataLoader(eval_dataset_full, batch_size=BATCH_SIZE) model = AutoModelForSequenceClassification.from_pretrained(MODEL, num_labels=1) optimizer = AdamW(model.parameters(), lr=1e-6) num_epochs = 15 num_training_steps = num_epochs * len(train_dataloader) lr_scheduler = get_scheduler( "linear", optimizer=optimizer, num_warmup_steps=WARMUP_STEPS, num_training_steps=num_training_steps ) device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") model.to(device) progress_bar = tqdm(range(num_training_steps)) model.train() model.train() model.to(device) def transform_batch(batch): batch['input_ids'] = torch.stack(batch['input_ids']).permute(1,0).to(device) batch['attention_mask'] = torch.stack(batch['attention_mask']).permute(1,0).to(device) batch['labels'] = batch['year_middle_float_scaled'].to(device).float() batch['labels'].to(device) batch['input_ids'].to(device) batch['attention_mask'].to(device) for c in set(batch.keys()) - {'input_ids', 'attention_mask', 'labels'}: del batch[c] return batch def eval(full=False): model.eval() eval_loss = 0.0 dataloader = eval_dataloader_full if full else eval_dataloader_small for i, batch in enumerate(dataloader): batch = transform_batch(batch) outputs = model(**batch) loss = outputs.loss eval_loss += loss.item() print(f'epoch {epoch} eval loss: {eval_loss / i }') model.train() return eval_loss best_eval_loss = 9999 epochs_without_progress = 0 for epoch in range(num_epochs): train_loss = 0.0 for i, batch in enumerate(train_dataloader): batch = transform_batch(batch) outputs = model(**batch) loss = outputs.loss loss.backward() train_loss += loss.item() progress_bar.update(1) # DELAYED UPDATE #if i % 16 == 1 and i > 1: # optimizer.step() # #lr_scheduler.step() # optimizer.zero_grad() # DELAYED UPDATE optimizer.step() lr_scheduler.step() optimizer.zero_grad() if i % STEPS_EVAL == 0 and i > 1: print(f' epoch {epoch} train loss: {train_loss / STEPS_EVAL }', end='\t\t') train_loss = 0.0 eval(full = False) model.save_pretrained(f'roberta_year_prediction/epoch_{epoch}') eval_loss = eval(full=True) if eval_loss < best_eval_loss: model.save_pretrained(f'roberta_year_prediction/epoch_best') best_eval_loss = eval_loss else: epochs_without_progress += 1 if epochs_without_progress > EARLY_STOPPING: break