Add transformers fine-tuning
This commit is contained in:
parent
a489085007
commit
aed2e01f68
5272
dev-0/out.tsv
5272
dev-0/out.tsv
File diff suppressed because it is too large
Load Diff
56
fine_tuning.py
Normal file
56
fine_tuning.py
Normal file
@ -0,0 +1,56 @@
|
|||||||
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
|
||||||
|
import random
|
||||||
|
import torch
|
||||||
|
|
||||||
|
with open('train/in.tsv') as f:
|
||||||
|
data_train_X = f.readlines()
|
||||||
|
|
||||||
|
with open('train/expected.tsv') as f:
|
||||||
|
data_train_Y = f.readlines()
|
||||||
|
|
||||||
|
with open('dev-0/in.tsv') as f:
|
||||||
|
data_dev_X = f.readlines()
|
||||||
|
|
||||||
|
with open('test-A/in.tsv') as f:
|
||||||
|
data_test_X = f.readlines()
|
||||||
|
|
||||||
|
class CustomDataset(torch.utils.data.Dataset):
|
||||||
|
def __init__(self, encodings, labels):
|
||||||
|
self.encodings = encodings
|
||||||
|
self.labels = labels
|
||||||
|
|
||||||
|
def __getitem__(self, idx):
|
||||||
|
item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
|
||||||
|
item['labels'] = torch.tensor(self.labels[idx])
|
||||||
|
return item
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.labels)
|
||||||
|
|
||||||
|
data_train = list(zip(data_train_X, data_train_Y))
|
||||||
|
data_train = random.sample(data_train, 150000)
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
|
||||||
|
train_encodings = tokenizer([text[0] for text in data_train], truncation=True, padding=True)
|
||||||
|
|
||||||
|
train_dataset = CustomDataset(train_encodings, [int(text[1]) for text in data_train])
|
||||||
|
|
||||||
|
model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=2)
|
||||||
|
|
||||||
|
training_args = TrainingArguments("test_trainer")
|
||||||
|
trainer = Trainer(
|
||||||
|
model=model, args=training_args, train_dataset=train_dataset)
|
||||||
|
trainer.train()
|
||||||
|
|
||||||
|
with open('train/out.tsv', 'w') as writer:
|
||||||
|
for result in trainer.predict(data_train_X):
|
||||||
|
writer.write(str(result) + '\n')
|
||||||
|
|
||||||
|
with open('dev-0/out.tsv', 'w') as writer:
|
||||||
|
for result in trainer.predict(data_dev_X):
|
||||||
|
writer.write(str(result) + '\n')
|
||||||
|
|
||||||
|
with open('test-A/out.tsv', 'w') as writer:
|
||||||
|
for result in trainer.predict(data_test_X):
|
||||||
|
writer.write(str(result) + '\n')
|
||||||
|
|
5152
test-A/out.tsv
5152
test-A/out.tsv
File diff suppressed because it is too large
Load Diff
289579
train/out.tsv
289579
train/out.tsv
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user