add bert script
This commit is contained in:
parent
69719655e6
commit
305ae96fda
6
.ipynb_checkpoints/Untitled-checkpoint.ipynb
Normal file
6
.ipynb_checkpoints/Untitled-checkpoint.ipynb
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
{
|
||||||
|
"cells": [],
|
||||||
|
"metadata": {},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 4
|
||||||
|
}
|
6
.ipynb_checkpoints/Untitled1-checkpoint.ipynb
Normal file
6
.ipynb_checkpoints/Untitled1-checkpoint.ipynb
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
{
|
||||||
|
"cells": [],
|
||||||
|
"metadata": {},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 4
|
||||||
|
}
|
72
bert.py
Normal file
72
bert.py
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
|
||||||
|
import random
|
||||||
|
import torch
|
||||||
|
|
||||||
|
# Data file locations. The rest of the script selects entries by list
# index, so the order of this list matters.
PATHS = [
    'train/in.tsv',
    'train/expected.tsv',
    'dev-0/in.tsv',
    'test-A/in.tsv',
    './dev-0/out.tsv',
    './test-A/out.tsv',
]

# Checkpoint identifier passed to the `from_pretrained` loaders.
PRE_TRAINED = 'roberta-base'
|
||||||
|
|
||||||
|
def get_data(path):
    """Read a UTF-8 text file and return its lines.

    Args:
        path: Location of the file to read.

    Returns:
        A list with one string per line. Line terminators are kept,
        exactly as ``readlines()`` produces them; an empty file yields
        an empty list.

    Raises:
        OSError: If the file cannot be opened.
    """
    with open(path, encoding='utf-8') as f:
        return f.readlines()
|
||||||
|
|
||||||
|
class IMDbDataset(torch.utils.data.Dataset):
    """Map-style dataset over per-sample encodings and optional labels.

    Args:
        encodings: Mapping from field name to a per-sample sequence. Only
            ``.items()`` and integer indexing of each value are used.
        labels: Per-sample labels aligned with ``encodings``, or ``None``
            for unlabeled data (e.g. inputs that are only to be predicted).
    """

    def __init__(self, encodings, labels=None):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        """Return sample ``idx`` as a dict of tensors.

        Every field of ``encodings`` is converted with ``torch.tensor``;
        a ``'labels'`` entry is added only when labels were supplied.
        """
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        if self.labels is not None:
            item['labels'] = torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        """Return the number of samples."""
        if self.labels is not None:
            return len(self.labels)
        # Without labels, the first encoding field determines the size.
        for _key, column in self.encodings.items():
            return len(column)
        return 0
|
||||||
|
|
||||||
|
def prepare(data_train_X, data_train_Y):
    """Tokenize the training texts and build the dataset and model.

    Args:
        data_train_X: Iterable of input texts.
        data_train_Y: Iterable of labels, each convertible with ``int()``.

    Returns:
        A ``(train_dataset, model)`` tuple: an ``IMDbDataset`` over the
        tokenized texts and their integer labels, and a sequence
        classification model loaded from ``PRE_TRAINED`` with
        ``num_labels=2``.

    Raises:
        ValueError: If a label cannot be converted to ``int``.
    """
    tokenizer = AutoTokenizer.from_pretrained(PRE_TRAINED)
    model = AutoModelForSequenceClassification.from_pretrained(PRE_TRAINED, num_labels=2)

    # Pair texts with labels in a single pass. zip() stops at the shorter
    # input, so any unmatched tail is dropped from both sides.
    texts = []
    labels = []
    for text, label in zip(data_train_X, data_train_Y):
        texts.append(text)
        labels.append(int(label))

    encoded_input = tokenizer(texts, truncation=True, padding=True)
    train_dataset = IMDbDataset(encoded_input, labels)

    return train_dataset, model
|
||||||
|
|
||||||
|
|
||||||
|
def trainer(train_dataset, model):
    """Fine-tune ``model`` on ``train_dataset`` using ``Trainer``.

    Args:
        train_dataset: Dataset passed to ``Trainer`` as ``train_dataset``.
        model: The model instance to train.

    Returns:
        The ``Trainer`` instance, after ``train()`` has completed, so the
        caller can keep using it once training is done.
    """
    training_args = TrainingArguments(
        output_dir='./results',          # output directory
        num_train_epochs=3,              # total number of training epochs
        per_device_train_batch_size=16,  # batch size per device during training
        per_device_eval_batch_size=64,   # batch size for evaluation
        warmup_steps=500,                # number of warmup steps for learning rate scheduler
        weight_decay=0.01,               # strength of weight decay
        logging_dir='./logs',            # directory for storing logs
        logging_steps=10,
    )

    # Named distinctly from this function so the local does not shadow it.
    hf_trainer = Trainer(
        model=model,                     # the instantiated Transformers model to be trained
        args=training_args,              # training arguments, defined above
        train_dataset=train_dataset,     # training dataset
    )

    hf_trainer.train()
    return hf_trainer
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Load the training data, build the dataset and model, and train."""
    # Only the training inputs (PATHS[0]) and labels (PATHS[1]) are
    # consumed by the steps below.
    # NOTE(review): this script has no prediction step yet, so the dev and
    # test inputs are not loaded here; read them where they are used.
    X_train = get_data(PATHS[0])
    y_train = get_data(PATHS[1])

    # Tokenize the texts and instantiate the model.
    train_dataset, model = prepare(X_train, y_train)

    # Run fine-tuning.
    trainer(train_dataset, model)


if __name__ == '__main__':
    main()
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
logs/events.out.tfevents.1624271570.POZ-PIOTRU.20712.0
Normal file
BIN
logs/events.out.tfevents.1624271570.POZ-PIOTRU.20712.0
Normal file
Binary file not shown.
BIN
logs/events.out.tfevents.1624272912.POZ-PIOTRU.4624.0
Normal file
BIN
logs/events.out.tfevents.1624272912.POZ-PIOTRU.4624.0
Normal file
Binary file not shown.
BIN
logs/events.out.tfevents.1624273023.POZ-PIOTRU.19528.0
Normal file
BIN
logs/events.out.tfevents.1624273023.POZ-PIOTRU.19528.0
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Loading…
Reference in New Issue
Block a user