Add decode.py
This commit is contained in:
parent
4ac818e536
commit
65d5471dc4
18
decode.py
Normal file
18
decode.py
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
import inout as io
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
target = [x[0].replace('\n', '') for x in io.read('train/expected.tsv.xz')]
|
||||||
|
categories = {}
|
||||||
|
i = 0
|
||||||
|
for x in target:
|
||||||
|
if x not in categories.values():
|
||||||
|
categories[i] = x
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
files = ['dev-0', 'test-A', 'test-B']
|
||||||
|
|
||||||
|
for file in files:
|
||||||
|
predicted = io.read(file)
|
||||||
|
predicted = [categories[round(float(x))] for x in predicted]
|
||||||
|
io.write(predicted, file + '/out.tsv')
|
2
inout.py
2
inout.py
@ -22,4 +22,4 @@ def read(dir):
|
|||||||
def write(output, dir):
|
def write(output, dir):
|
||||||
with open(dir, 'w', newline='', encoding='utf-8') as f:
|
with open(dir, 'w', newline='', encoding='utf-8') as f:
|
||||||
writer = csv.writer(f)
|
writer = csv.writer(f)
|
||||||
writer.writerows(output)
|
for row in output: writer.writerow([row])
|
Loading…
Reference in New Issue
Block a user