# ireland-news-headlines/main.py
# Last commit: Andrzej Preibisz, 44398afd35 ("Solution 444465"), 2022-05-28 19:29:06 +02:00
# Listing metadata: 140 lines, 3.9 KiB, Python

import vowpalwabbit
import lzma
from gensim.parsing.preprocessing import remove_stopwords
# NOTE(review): VECTOR_SIZE is not referenced anywhere in the visible code;
# confirm it is unused elsewhere and remove it if so.
VECTOR_SIZE = 100
# VW's text format uses "|" to start a namespace and ":" to attach a value
# to a feature name, so both must be neutralised inside free text.
_VW_SEPARATORS = str.maketrans({"|": " ", ":": " "})


def _to_vw_features(line):
    """Return the feature part (everything after the label) of a VW example.

    ``line`` is a parsed input row. ``line[0]`` and ``line[1]`` become the
    ``year_fraction`` and ``date`` features. VW expects a float after the
    ``:``, so these are assumed numeric (confirm against the data).
    ``line[2]`` is free text, emitted as bag-of-words features in its own
    ``content`` namespace. Writing the text as ``content:<text>`` would make
    VW read the first word as a (non-numeric) feature value.
    """
    words = " ".join(line[2].translate(_VW_SEPARATORS).split())
    return f"| year_fraction:{line[0]} date:{line[1]} |content {words}\n"


def transform_to_vw_train_format(data_x, data_y):
    """Format labelled rows as VW training examples.

    Args:
        data_x: parsed input rows (see ``_to_vw_features``).
        data_y: class labels aligned with ``data_x``; pairing stops at the
            shorter of the two sequences (``zip`` semantics).

    Returns:
        A list of newline-terminated VW example strings, ``"<label> | ..."``.
    """
    return [f"{label} {_to_vw_features(line)}" for line, label in zip(data_x, data_y)]


def transform_to_vw_test_format(data_x):
    """Format unlabelled rows as VW examples for prediction.

    Returns a list of newline-terminated example strings starting with ``|``.
    """
    return [_to_vw_features(line) for line in data_x]
def _parse_row(raw_line):
    """Split one TSV input line into columns and normalise the text column.

    Only the line terminator is removed before splitting on tabs. A full
    ``strip()`` would also drop trailing tab separators and lose an empty
    last column. Column 2 is lower-cased and passed through gensim's
    ``remove_stopwords``.
    """
    row = raw_line.rstrip("\r\n").split("\t")
    row[2] = remove_stopwords(row[2].lower())
    return row


def _read_split(path):
    """Read an uncompressed ``in.tsv`` file as UTF-8 and parse every line."""
    with open(path, encoding="utf-8") as f:
        return [_parse_row(line) for line in f]


def _predict_labels(vw, examples, id_to_label):
    """Return the predicted category name for each VW example string.

    Class ids used with ``--oaa`` are 1-based. A prediction of 0 is treated
    as class 1 so that every input line still receives a label.
    """
    labels = []
    for example in examples:
        class_id = round(vw.predict(example))
        if class_id == 0:
            class_id = 1
        labels.append(id_to_label[class_id])
    return labels


def _write_lines(path, lines):
    """Write each item of ``lines`` on its own line, UTF-8 encoded."""
    with open(path, "w", encoding="utf-8") as f:
        f.writelines(f"{line}\n" for line in lines)


def main():
    """Train a Vowpal Wabbit one-against-all classifier and write predictions.

    Reads ``train/in.tsv.xz`` with ``train/expected.tsv.xz`` (pairing stops
    at the shorter file) and trains in a single pass. It then writes one
    predicted category per input line to ``<split>/out.tsv`` for the
    ``dev-0``, ``test-A`` and ``test-B`` splits.

    Raises:
        ValueError: if no training examples could be read.
    """
    train_x = []
    train_y = []
    with lzma.open("train/in.tsv.xz", "rt", encoding="utf-8") as f_in:
        with lzma.open("train/expected.tsv.xz", "rt", encoding="utf-8") as f_expected:
            for line_in, line_expected in zip(f_in, f_expected):
                train_x.append(_parse_row(line_in))
                train_y.append(line_expected.strip())
    if not train_y:
        raise ValueError("no training examples read from train/")

    # Read every evaluation split before training so a missing file fails fast.
    splits = ("dev-0", "test-A", "test-B")
    split_rows = {split: _read_split(f"{split}/in.tsv") for split in splits}

    # sorted() makes the label -> class-id mapping identical across runs;
    # iterating a bare set of strings depends on hash randomisation.
    label_to_id = {label: i for i, label in enumerate(sorted(set(train_y)), 1)}
    id_to_label = {i: label for label, i in label_to_id.items()}

    # One-against-all over exactly the classes present in the training data.
    vw = vowpalwabbit.Workspace(
        f"--oaa {len(label_to_id)} --ngram 3 --learning_rate=0.1"
    )
    try:
        train_ids = [label_to_id[label] for label in train_y]
        for example in transform_to_vw_train_format(train_x, train_ids):
            vw.learn(example)
        predictions = {
            split: _predict_labels(vw, transform_to_vw_test_format(rows), id_to_label)
            for split, rows in split_rows.items()
        }
    finally:
        # Release the VW workspace even if training or prediction fails.
        vw.finish()

    for split, labels in predictions.items():
        _write_lines(f"{split}/out.tsv", labels)
# Run only when executed as a script, not when imported.
if __name__ == "__main__":
    main()