add output

This commit is contained in:
s440058 2021-06-21 13:15:35 +02:00
parent 8f410ae809
commit fb4b0d95e3

View File

@ -2,6 +2,7 @@ from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trai
import torch import torch
PATHS = ['train/in.tsv', 'train/expected.tsv', 'dev-0/in.tsv', 'test-A/in.tsv', './dev-0/out.tsv', './test-A/out.tsv'] PATHS = ['train/in.tsv', 'train/expected.tsv', 'dev-0/in.tsv', 'test-A/in.tsv', './dev-0/out.tsv', './test-A/out.tsv']
OUTPUT_PATHS = ['dev-0/out.tsv', 'test-A/out.tsv']
PRE_TRAINED = 'roberta-base' PRE_TRAINED = 'roberta-base'
def get_data(path): def get_data(path):
@ -11,6 +12,12 @@ def get_data(path):
return data return data
def generate_output(path, trainer, X_data):
data = []
with open(path, encoding='utf-8') as f:
for result in trainer.predict(X_data):
f.write(str(result) + '\n')
class IMDbDataset(torch.utils.data.Dataset): class IMDbDataset(torch.utils.data.Dataset):
def __init__(self, encodings, labels): def __init__(self, encodings, labels):
self.encodings = encodings self.encodings = encodings
@ -32,7 +39,7 @@ def prepare(data_train_X, data_train_Y):
return train_dataset, model return train_dataset, model
def trainer(train_dataset, model): def training(train_dataset, model):
training_args = TrainingArguments( training_args = TrainingArguments(
output_dir='./results', # output directory output_dir='./results', # output directory
num_train_epochs=3, # total number of training epochs num_train_epochs=3, # total number of training epochs
@ -52,6 +59,8 @@ def trainer(train_dataset, model):
trainer.train() trainer.train()
return trainer
def main(): def main():
#data #data
X_train = get_data(PATHS[0]) X_train = get_data(PATHS[0])
@ -63,7 +72,11 @@ def main():
train_dataset, model = prepare(X_train, y_train) train_dataset, model = prepare(X_train, y_train)
#trainer #trainer
trainer(train_dataset, model) trainer = training(train_dataset, model)
#output
generate_output(OUTPUT_PATHS[0], trainer, X_dev)
generate_output(OUTPUT_PATHS[1], trainer, X_test)
if __name__ == '__main__': if __name__ == '__main__':
main() main()