Add decode.py

This commit is contained in:
Iwona Christop 2022-05-28 17:05:22 +02:00
parent 4ac818e536
commit 65d5471dc4
2 changed files with 19 additions and 1 deletions

18
decode.py Normal file
View File

@ -0,0 +1,18 @@
import inout as io
if __name__ == '__main__':
target = [x[0].replace('\n', '') for x in io.read('train/expected.tsv.xz')]
categories = {}
i = 0
for x in target:
if x not in categories.values():
categories[i] = x
i += 1
files = ['dev-0', 'test-A', 'test-B']
for file in files:
predicted = io.read(file)
predicted = [categories[round(float(x))] for x in predicted]
io.write(predicted, file + '/out.tsv')

View File

@ -22,4 +22,4 @@ def read(dir):
def write(output, dir):
with open(dir, 'w', newline='', encoding='utf-8') as f:
writer = csv.writer(f)
writer.writerows(output)
for row in output: writer.writerow([row])