140 lines
3.9 KiB
Python
140 lines
3.9 KiB
Python
import vowpalwabbit
|
|
import lzma
|
|
from gensim.parsing.preprocessing import remove_stopwords
|
|
|
|
|
|
# NOTE(review): VECTOR_SIZE is not referenced anywhere in this script;
# confirm nothing else imports it before removing.
VECTOR_SIZE: int = 100
|
|
|
|
|
|
def _sanitize_vw_text(text):
    """Make free text safe to embed as Vowpal Wabbit features.

    In the VW text format ``|`` starts a new namespace and ``:`` separates a
    feature name from its value, so both are replaced by spaces; runs of
    whitespace (tabs, newlines, repeated spaces) collapse to single spaces.
    """
    return " ".join(text.replace("|", " ").replace(":", " ").split())


def _format_vw_features(line):
    """Return the VW feature section (everything after the label) for one row.

    ``line[0]`` and ``line[1]`` are written as ``name:value`` features in the
    default namespace; ``line[2]`` (the text) is written word by word into
    its own ``content`` namespace.

    NOTE(review): VW requires feature values to parse as numbers - this
    assumes the year-fraction and date columns are numeric; confirm.
    """
    # The text gets its own namespace: writing "content:<text>" would make
    # the first word a non-numeric (malformed) feature value.
    return (
        f"| year_fraction:{line[0]} date:{line[1]} "
        f"|content {_sanitize_vw_text(line[2])}\n"
    )


def transform_to_vw_train_format(data_x, data_y):
    """Convert feature rows and labels into VW training example strings.

    Args:
        data_x: iterable of rows, each indexable as (year_fraction, date, text).
        data_y: iterable of labels paired positionally with ``data_x``;
            pairing stops at the shorter of the two (``zip`` semantics).

    Returns:
        A list of newline-terminated strings of the form
        ``"<label> | year_fraction:<x> date:<y> |content <words>"``.
    """
    return [
        f"{label} {_format_vw_features(line)}"
        for line, label in zip(data_x, data_y)
    ]


def transform_to_vw_test_format(data_x):
    """Convert feature rows into unlabeled VW example strings for prediction.

    Uses exactly the same feature layout as ``transform_to_vw_train_format``
    (minus the leading label) so train and test examples agree.

    Args:
        data_x: iterable of rows, each indexable as (year_fraction, date, text).

    Returns:
        A list of newline-terminated example strings.
    """
    return [_format_vw_features(line) for line in data_x]
|
|
|
|
|
|
def _split_tsv_line(line):
    """Split one TSV line into whitespace-trimmed fields.

    Only the line terminator is removed before splitting, so an empty first
    or last column keeps its position (a whole-line ``strip()`` would eat the
    adjacent tab and shift or lose columns).
    """
    return [field.strip() for field in line.rstrip("\r\n").split("\t")]


def _preprocess_row(fields):
    """Lower-case the text column (index 2) and drop stopwords; mutates and returns ``fields``."""
    fields[2] = remove_stopwords(fields[2].lower())
    return fields


def _read_rows(path, open_fn=open):
    """Open ``path`` with ``open_fn`` as UTF-8 text and return its preprocessed TSV rows."""
    with open_fn(path, "rt", encoding="utf-8") as f:
        return [_preprocess_row(_split_tsv_line(line)) for line in f]


def _read_labels(path):
    """Read one stripped label per line from an xz-compressed UTF-8 file."""
    with lzma.open(path, "rt", encoding="utf-8") as f:
        return [line.strip() for line in f]


def _predict_labels(vw, examples, id_to_label):
    """Predict a class id for each VW example string and map it back to its label text."""
    predictions = []
    for example in examples:
        class_id = round(vw.predict(example))
        # Keep the original fallback: a prediction of 0 (outside the 1..k
        # range used for --oaa class ids) is treated as class 1.
        if class_id == 0:
            class_id = 1
        predictions.append(id_to_label[class_id])
    return predictions


def _write_lines(path, lines):
    """Write each item of ``lines`` on its own line to ``path`` as UTF-8."""
    with open(path, "w", encoding="utf-8") as f:
        f.writelines(f"{line}\n" for line in lines)


def main():
    """Train a one-against-all VW classifier and write predictions for each split.

    Reads training rows from ``train/in.tsv.xz`` and labels from
    ``train/expected.tsv.xz``, trains a ``vowpalwabbit.Workspace`` on them,
    then predicts a label for every row of ``dev-0/in.tsv``, ``test-A/in.tsv``
    and ``test-B/in.tsv`` and writes the labels to the matching ``out.tsv``.

    Raises:
        ValueError: if the training label file yields no labels.
    """
    train_x = _read_rows("train/in.tsv.xz", open_fn=lzma.open)
    train_y = _read_labels("train/expected.tsv.xz")

    # Read every evaluation split before training so a missing input file
    # fails fast instead of after the training pass.
    eval_splits = ("dev-0", "test-A", "test-B")
    eval_rows = {split: _read_rows(f"{split}/in.tsv") for split in eval_splits}

    # Sorting makes the label -> class-id mapping reproducible across runs;
    # iterating a set of strings varies with hash randomization.
    labels = sorted(set(train_y))
    if not labels:
        raise ValueError("no training labels found in train/expected.tsv.xz")
    label_to_id = {label: class_id for class_id, label in enumerate(labels, 1)}
    id_to_label = {class_id: label for label, class_id in label_to_id.items()}

    # Size --oaa from the data rather than hard-coding 7, so a different
    # number of classes neither breaks learning nor yields unmappable ids.
    vw = vowpalwabbit.Workspace(f"--oaa {len(labels)} --ngram 3 --learning_rate=0.1")
    try:
        train_ids = [label_to_id[label] for label in train_y]
        for example in transform_to_vw_train_format(train_x, train_ids):
            vw.learn(example)

        predictions = {
            split: _predict_labels(vw, transform_to_vw_test_format(rows), id_to_label)
            for split, rows in eval_rows.items()
        }
    finally:
        # Release the VW workspace even if training or prediction fails.
        vw.finish()

    for split, split_predictions in predictions.items():
        _write_lines(f"{split}/out.tsv", split_predictions)
|
|
|
|
|
|
# Run only when executed as a script, so importing this module (e.g. to reuse
# the VW formatting helpers) does not trigger training and file I/O.
if __name__ == "__main__":
    main()
|