s478855 - correct outs

This commit is contained in:
ulaniuk 2022-06-22 13:32:19 +02:00
parent 67f4a46156
commit bd06f0ed0c
4 changed files with 3670 additions and 3700 deletions

File diff suppressed because it is too large Load Diff

2305
run2.ipynb

File diff suppressed because it is too large Load Diff

11
run2.py
View File

@ -138,8 +138,14 @@ testA_dataset = testA_dataset.map(tokenize, batched=True, batch_size=len(train_d
dev_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'label'])
testA_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'label'])
y_pred_dev = trainer.predict(dev_dataset).label_ids
y_pred_test = trainer.predict(testA_dataset).label_ids
y_pred_dev = trainer.predict(dev_dataset).predictions
y_pred_test = trainer.predict(testA_dataset).predictions
def get_labels(predictions):
return [0 if a > b else 1 for a, b in predictions]
y_pred_dev = get_labels(y_pred_dev)
y_pred_test = get_labels(y_pred_test)
with open('/content/drive/MyDrive/eks/dev-0/out.tsv', 'wt') as f:
for pred in y_pred_dev:
@ -148,3 +154,4 @@ with open('/content/drive/MyDrive/eks/dev-0/out.tsv', 'wt') as f:
with open('/content/drive/MyDrive/eks/test-A/out.tsv', 'wt') as f:
for pred in y_pred_test:
f.write(str(pred)+'\n')

File diff suppressed because it is too large Load Diff