s478855 - correct outs
This commit is contained in:
parent
67f4a46156
commit
bd06f0ed0c
1460
dev-0/out.tsv
1460
dev-0/out.tsv
File diff suppressed because it is too large
Load Diff
2305
run2.ipynb
2305
run2.ipynb
File diff suppressed because it is too large
Load Diff
13
run2.py
13
run2.py
@ -138,8 +138,14 @@ testA_dataset = testA_dataset.map(tokenize, batched=True, batch_size=len(train_d
|
||||
dev_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'label'])
|
||||
testA_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'label'])
|
||||
|
||||
y_pred_dev = trainer.predict(dev_dataset).label_ids
|
||||
y_pred_test = trainer.predict(testA_dataset).label_ids
|
||||
y_pred_dev = trainer.predict(dev_dataset).predictions
|
||||
y_pred_test = trainer.predict(testA_dataset).predictions
|
||||
|
||||
def get_labels(predictions):
|
||||
return [0 if a > b else 1 for a, b in predictions]
|
||||
|
||||
y_pred_dev = get_labels(y_pred_dev)
|
||||
y_pred_test = get_labels(y_pred_test)
|
||||
|
||||
with open('/content/drive/MyDrive/eks/dev-0/out.tsv', 'wt') as f:
|
||||
for pred in y_pred_dev:
|
||||
@ -147,4 +153,5 @@ with open('/content/drive/MyDrive/eks/dev-0/out.tsv', 'wt') as f:
|
||||
|
||||
with open('/content/drive/MyDrive/eks/test-A/out.tsv', 'wt') as f:
|
||||
for pred in y_pred_test:
|
||||
f.write(str(pred)+'\n')
|
||||
f.write(str(pred)+'\n')
|
||||
|
||||
|
3592
test-A/out.tsv
3592
test-A/out.tsv
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user