diff --git a/decode.py b/decode.py new file mode 100644 index 0000000..5401926 --- /dev/null +++ b/decode.py @@ -0,0 +1,18 @@ +import inout as io + + +if __name__ == '__main__': + target = [x[0].replace('\n', '') for x in io.read('train/expected.tsv.xz')] + categories = {} + i = 0 + for x in target: + if x not in categories.values(): + categories[i] = x + i += 1 + + files = ['dev-0', 'test-A', 'test-B'] + + for file in files: + predicted = io.read(file) + predicted = [categories[round(float(x))] for x in predicted] + io.write(predicted, file + '/out.tsv') diff --git a/inout.py b/inout.py index b25f678..47e1073 100644 --- a/inout.py +++ b/inout.py @@ -22,4 +22,4 @@ def read(dir): def write(output, dir): with open(dir, 'w', newline='', encoding='utf-8') as f: writer = csv.writer(f) - writer.writerows(output) \ No newline at end of file + for row in output: writer.writerow([row]) \ No newline at end of file