2022-05-28 17:05:22 +02:00
|
|
|
import inout as io
|
|
|
|
|
|
|
|
|
|
|
|
def build_categories(target):
    """Return a dict mapping 0-based indices to the distinct labels in *target*.

    Labels are numbered in order of first appearance: the first distinct
    label gets index 0, the second index 1, and so on.
    """
    # dict.fromkeys de-duplicates while preserving insertion order, in O(n)
    # instead of scanning the already-seen labels for every element.
    return {index: label for index, label in enumerate(dict.fromkeys(target))}


def to_category(value, categories):
    """Convert one regression output into its category label.

    *value* is anything ``float()`` accepts; it is treated as a 1-based
    category number. It is rounded with Python's ``round`` (ties go to the
    even integer) and clamped to ``[1, len(categories)]``, so predictions
    below or above the known range map to the first or last category
    instead of raising ``KeyError``.

    Raises:
        ValueError: if *categories* is empty (nothing to map onto).
    """
    if not categories:
        raise ValueError("no categories to map predictions onto")
    number = round(float(value))
    # Clamp both ends; the original code only guarded the lower end.
    number = min(max(number, 1), len(categories))
    return categories[number - 1]


if __name__ == '__main__':
    # NOTE(review): assumes io.read yields one row per line whose first field
    # is the label (trailing newline stripped here) — confirm against inout.
    target = [row[0].replace('\n', '') for row in io.read('train/expected.tsv.xz')]
    categories = build_categories(target)

    for split in ['dev-0', 'test-A', 'test-B']:
        # Predictions are numeric category numbers (1-based), one per item.
        predicted = io.read('predicted-' + split)
        labels = [to_category(value, categories) for value in predicted]
        io.write(labels, split + '/out.tsv')
|