diff --git a/main.py b/main.py
new file mode 100644
index 0000000..c49bcd9
--- /dev/null
+++ b/main.py
@@ -0,0 +1,150 @@
+import pandas as pd
+import torch
+from transformers import BertTokenizer, BertForSequenceClassification
+
+train_input_path = "dev-0/in.tsv"
+train_target_path = "dev-0/expected.tsv"
+
+# the TSV files carry no header row, so name the columns explicitly
+train_input = pd.read_csv(train_input_path, sep="\t", header=None)[:100]
+train_input.columns = ["text", "d"]
+train_target = pd.read_csv(train_target_path, sep="\t", header=None)[0][:100]
+
+tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+MAX_SEQ_LEN = 128
+
+
+class BERT(torch.nn.Module):
+
+    def __init__(self):
+        super(BERT, self).__init__()
+        options_name = "bert-base-uncased"
+        self.encoder = BertForSequenceClassification.from_pretrained(options_name)
+
+    def forward(self, input_ids, attention_mask, labels):
+        # BertForSequenceClassification computes the classification loss
+        # internally when labels are supplied
+        output = self.encoder(input_ids, attention_mask=attention_mask, labels=labels)
+        return output.loss, output.logits
+
+
+def save_checkpoint(save_path, model, valid_loss):
+    if save_path is None:
+        return
+
+    state_dict = {'model_state_dict': model.state_dict(),
+                  'valid_loss': valid_loss}
+
+    torch.save(state_dict, save_path)
+    print(f'Model saved to ==> {save_path}')
+
+
+def load_checkpoint(load_path, model):
+    if load_path is None:
+        return
+
+    state_dict = torch.load(load_path, map_location=device)
+    print(f'Model loaded from <== {load_path}')
+
+    model.load_state_dict(state_dict['model_state_dict'])
+    return state_dict['valid_loss']
+
+
+def save_metrics(save_path, train_loss_list, valid_loss_list, global_steps_list):
+    if save_path is None:
+        return
+
+    state_dict = {'train_loss_list': train_loss_list,
+                  'valid_loss_list': valid_loss_list,
+                  'global_steps_list': global_steps_list}
+
+    torch.save(state_dict, save_path)
+    print(f'Metrics saved to ==> {save_path}')
+
+
+def load_metrics(load_path):
+    if load_path is None:
+        return
+
+    state_dict = torch.load(load_path, map_location=device)
+    print(f'Metrics loaded from <== {load_path}')
+
+    return state_dict['train_loss_list'], state_dict['valid_loss_list'], state_dict['global_steps_list']
+
+
+def train(model,
+          optimizer,
+          train_data=train_input['text'],
+          train_target=train_target,
+          num_epochs=5,
+          eval_every=len(train_input) // 2,
+          file_path="./",
+          best_valid_loss=float("Inf")):
+
+    # initialize running values
+    running_loss = 0.0
+    valid_running_loss = 0.0
+    global_step = 0
+    train_loss_list = []
+    valid_loss_list = []
+    global_steps_list = []
+
+    # training loop
+    model.train()
+    for epoch in range(num_epochs):
+        for text, label in zip(train_data, train_target):
+            # tokenize one example at a time and move it to the device
+            encoding = tokenizer(text, truncation=True, max_length=MAX_SEQ_LEN,
+                                 return_tensors="pt")
+            input_ids = encoding["input_ids"].to(device)
+            attention_mask = encoding["attention_mask"].to(device)
+            labels = torch.tensor([label], dtype=torch.long, device=device)
+
+            loss, _ = model(input_ids, attention_mask, labels)
+
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+
+            # update running values
+            running_loss += loss.item()
+            global_step += 1
+
+            # evaluation step
+            if global_step % eval_every == 0:
+                model.eval()
+
+                # NOTE: no validation pass is run here, so the reported
+                # validation loss stays at zero until one is added
+                average_train_loss = running_loss / eval_every
+                average_valid_loss = valid_running_loss / len(train_data)
+                train_loss_list.append(average_train_loss)
+                valid_loss_list.append(average_valid_loss)
+                global_steps_list.append(global_step)
+
+                # resetting running values
+                running_loss = 0.0
+                valid_running_loss = 0.0
+                model.train()
+
+                # print progress
+                print('Epoch [{}/{}], Step [{}/{}], Train Loss: {:.4f}, Valid Loss: {:.4f}'
+                      .format(epoch + 1, num_epochs, global_step, num_epochs * len(train_data),
+                              average_train_loss, average_valid_loss))
+
+                # checkpoint
+                if best_valid_loss > average_valid_loss:
+                    best_valid_loss = average_valid_loss
+                    save_checkpoint(file_path + '/' + 'model.pt', model, best_valid_loss)
+                    save_metrics(file_path + '/' + 'metrics.pt', train_loss_list, valid_loss_list, global_steps_list)
+
+    save_metrics(file_path + '/' + 'metrics.pt', train_loss_list, valid_loss_list, global_steps_list)
+    print('Finished Training!')
+
+
+model = BERT().to(device)
+optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)
+
+train(model=model, optimizer=optimizer)
\ No newline at end of file
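
Usage note (not part of the diff): a minimal inference sketch showing how the checkpoint written by train() could be reloaded and applied to one string. It assumes main.py has already been run so that ./model.pt exists, reuses the BERT class, tokenizer, device, MAX_SEQ_LEN and load_checkpoint defined above, and the argmax over the logits assumes the 0/1 label scheme implied by expected.tsv.

# sketch only: reload the best checkpoint and classify a single string
model = BERT().to(device)
load_checkpoint('./model.pt', model)
model.eval()

with torch.no_grad():
    encoding = tokenizer("example input text", truncation=True,
                         max_length=MAX_SEQ_LEN, return_tensors="pt")
    # call the wrapped encoder directly, without labels, to get logits only
    logits = model.encoder(encoding["input_ids"].to(device),
                           attention_mask=encoding["attention_mask"].to(device)).logits
    print(logits.argmax(dim=-1).item())  # assumed 0/1 label scheme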