2022-05-28 16:27:40 +02:00
|
|
|
import inout as io
|
|
|
|
|
|
|
|
def generateChooChoo(data, target, categories):
    """Load a text/label dataset and build (or reuse) a label -> index mapping.

    Args:
        data: Path handed to ``io.read``; column 2 of every returned row is
            used as the document text.
        target: Path handed to ``io.read``; column 0 of every returned row is
            used as the label.
        categories: Existing ``label -> int`` mapping. When it is empty (or
            ``None``) it is populated from the labels in first-seen order,
            starting at 0. A non-empty mapping is returned untouched, so
            labels missing from it stay missing.

    Returns:
        A ``({'data': texts, 'target': labels}, categories)`` tuple. A dict
        passed in as ``categories`` is filled in place and returned.
    """
    rows = io.read(data)

    # Newlines and colons are removed from the text before the literal
    # ' year:' suffix is appended.
    # NOTE(review): column 0 of the rows (previously read into an unused
    # `years` list) is never attached after ' year:' -- confirm whether the
    # value was meant to be appended here.
    texts = [
        row[2].replace('\n', '').replace(':', '') + ' year:'
        for row in rows
    ]

    labels = [row[0].replace('\n', '') for row in io.read(target)]

    if categories is None:
        categories = {}

    if not categories:
        for label in labels:
            if label not in categories:
                # The mapping starts out empty, so its current size is the
                # next free consecutive index.
                categories[label] = len(categories)

    return {'data': texts, 'target': labels}, categories
|
|
|
|
|
|
|
|
def predictFuture(test):
    """Write the rows of an unlabeled dataset as ``1 |text ...`` lines.

    Args:
        test: Path handed to ``io.read``; also used to name the output file,
            which is ``'vw-' + test``.

    Column 2 of every row is used as the text. Newlines and colons are
    removed from it, the literal ' year:' suffix is appended, and each
    output line carries the fixed label ``1``.

    Returns:
        None. The only effect is the written file.
    """
    rows = io.read(test)

    # NOTE(review): column 0 of the rows (previously read into an unused
    # `years` list) is never attached after ' year:' -- confirm whether the
    # value was meant to be appended here.
    lines = [
        '1 |text ' + row[2].replace('\n', '').replace(':', '') + ' year:' + '\n'
        for row in rows
    ]

    with open('vw-' + test, 'w', encoding='utf-8') as f:
        f.writelines(lines)
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Build the label mapping from the training split, then reuse it for dev
    # so both splits share the same label indices.
    train_set, categories = generateChooChoo(
        'train/in.tsv.xz', 'train/expected.tsv.xz', categories={}
    )
    dev_set, _ = generateChooChoo('dev-0/in.tsv', 'dev-0/expected.tsv', categories)

    # Training file: "<label index + 1> |text <document>" per line.
    with open('vw-train', 'w', encoding='utf-8') as train_out:
        train_out.writelines(
            f"{categories[label] + 1} |text {document}\n"
            for label, document in zip(train_set['target'], train_set['data'])
        )

    # Dev file carries a fixed label of 1; the expected labels go to a
    # separate file, one per line, in the same order.
    with open('vw-dev0', 'w', encoding='utf-8') as dev_out, \
            open('vw-dev0-targets', 'w', encoding='utf-8') as dev_labels_out:
        for label, document in zip(dev_set['target'], dev_set['data']):
            dev_out.write(f"1 |text {document}\n")
            dev_labels_out.write(f"{categories[label] + 1}\n")

    for split in ('test-A', 'test-B'):
        predictFuture(split)
|