w123

parent 756ef4277a
commit dbadedfc1c

main.py  (new file, 155 lines)

@@ -0,0 +1,155 @@
import pandas as pd
from transformers import BertTokenizer, BertForSequenceClassification
import torch
# from torchtext.data import BucketIterator, Iterator

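# load the dev-0 split: tab-separated input texts and expected labels (first 100 rows only)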
train_input_path = "dev-0/in.tsv"
train_target_path = "dev-0/expected.tsv"

train_input = pd.read_csv(train_input_path, sep="\t")[:100]
train_input.columns = ["text", "d"]
train_target = pd.read_csv(train_target_path, sep="\t")[:100]

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
device = torch.device("cuda")

MAX_SEQ_LEN = 128
PAD_INDEX = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
UNK_INDEX = tokenizer.convert_tokens_to_ids(tokenizer.unk_token)

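# leftover torchtext Field/BucketIterator setup, kept commented out; PAD_INDEX and UNK_INDEX above are only referenced here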
# label_field = Field(sequential=False, use_vocab=False, batch_first=True, dtype=torch.float)
# text_field = Field(use_vocab=False, tokenize=tokenizer.encode, lower=True, include_lengths=False, batch_first=True,
#                    fix_length=MAX_SEQ_LEN, pad_token=PAD_INDEX, unk_token=UNK_INDEX)

# fields = [('label', label_field), ('text', text_field),]

# valid_iter = BucketIterator(train_input["text"], batch_size=16, sort_key=lambda x: len(x.text),
#                             device=device, train=True, sort=True, sort_within_batch=True)

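# thin wrapper around Hugging Face's BertForSequenceClassification:
# forward() passes the labels through and returns (loss, logits)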
class BERT(torch.nn.Module):

    def __init__(self):
        super(BERT, self).__init__()

        options_name = "bert-base-uncased"
        self.encoder = BertForSequenceClassification.from_pretrained(options_name)

    def forward(self, text, label):
        loss, text_fea = self.encoder(text, labels=label)[:2]

        return loss, text_fea

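# checkpoint helpers: persist/restore the model weights together with the best validation loss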
def save_checkpoint(save_path, model, valid_loss):

    if save_path is None:
        return

    state_dict = {'model_state_dict': model.state_dict(),
                  'valid_loss': valid_loss}

    torch.save(state_dict, save_path)
    print(f'Model saved to ==> {save_path}')


def load_checkpoint(load_path, model):

    if load_path is None:
        return

    state_dict = torch.load(load_path, map_location=device)
    print(f'Model loaded from <== {load_path}')

    model.load_state_dict(state_dict['model_state_dict'])
    return state_dict['valid_loss']

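# metrics helpers: persist/restore the loss curves and the global steps at which they were recorded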
def save_metrics(save_path, train_loss_list, valid_loss_list, global_steps_list):

    if save_path is None:
        return

    state_dict = {'train_loss_list': train_loss_list,
                  'valid_loss_list': valid_loss_list,
                  'global_steps_list': global_steps_list}

    torch.save(state_dict, save_path)
    print(f'Metrics saved to ==> {save_path}')


def load_metrics(load_path):

    if load_path is None:
        return

    state_dict = torch.load(load_path, map_location=device)
    print(f'Metrics loaded from <== {load_path}')

    return state_dict['train_loss_list'], state_dict['valid_loss_list'], state_dict['global_steps_list']

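# training loop: fine-tunes the classifier one example at a time and
# evaluates/checkpoints every `eval_every` steps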
def train(model,
          optimizer,
          criterion = torch.nn.BCELoss(),  # unused: the encoder computes its own loss
          train_data = train_input['text'],
          train_target = train_target,
          num_epochs = 5,
          eval_every = len(train_input) // 2,
          file_path = "./",
          best_valid_loss = float("Inf")):

    # initialize running values
    running_loss = 0.0
    valid_running_loss = 0.0
    global_step = 0
    train_loss_list = []
    valid_loss_list = []
    global_steps_list = []

    # training loop
    model.train()
    for epoch in range(num_epochs):
        # iterate over raw texts and the first column of the expected labels
        for text, label in zip(train_data, train_target.iloc[:, 0]):
            # tokenize the raw text and move both tensors to the GPU before the forward pass
            # (assumes the expected labels are integer class ids)
            input_ids = tokenizer.encode(text, max_length=MAX_SEQ_LEN, truncation=True,
                                         padding='max_length', return_tensors='pt').to(device)
            label_tensor = torch.tensor([label], dtype=torch.long).to(device)
            output = model(input_ids, label_tensor)
            loss, _ = output

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # update running values
            running_loss += loss.item()
            global_step += 1

            # evaluation step
            if global_step % eval_every == 0:
                model.eval()

                # evaluation (no separate validation pass is run, so valid_running_loss stays 0.0)
                average_train_loss = running_loss / eval_every
                average_valid_loss = valid_running_loss / len(train_data)
                train_loss_list.append(average_train_loss)
                valid_loss_list.append(average_valid_loss)
                global_steps_list.append(global_step)

                # resetting running values
                running_loss = 0.0
                valid_running_loss = 0.0
                model.train()

                # print progress
                print('Epoch [{}/{}], Step [{}/{}], Train Loss: {:.4f}, Valid Loss: {:.4f}'
                      .format(epoch+1, num_epochs, global_step, num_epochs*len(train_data),
                              average_train_loss, average_valid_loss))

                # checkpoint
                if best_valid_loss > average_valid_loss:
                    best_valid_loss = average_valid_loss
                    save_checkpoint(file_path + '/' + 'model.pt', model, best_valid_loss)
                    save_metrics(file_path + '/' + 'metrics.pt', train_loss_list, valid_loss_list, global_steps_list)

    save_metrics(file_path + '/' + 'metrics.pt', train_loss_list, valid_loss_list, global_steps_list)
    print('Finished Training!')

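# instantiate the classifier on the GPU and fine-tune it with Adam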
model = BERT().to(device)
model.cuda()  # redundant: the model is already on the GPU via .to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)

train(model=model, optimizer=optimizer)