mt-summit-corpora/inject.py
2022-01-18 10:27:53 +01:00

137 lines
4.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import copy
import pandas as pd
import spacy
from spaczz.matcher import FuzzyMatcher
# spacy.require_gpu()
spacy_nlp_en = spacy.load('en_core_web_sm')
spacy_nlp_pl = spacy.load('pl_core_news_sm')


def _join_glossary_lemmas(nlp_model, text):
    """Lemmatize one glossary term and return its lemmas as a single string.

    The lemmas produced by *nlp_model* are joined with single spaces and then
    passed through a fixed chain of ``str.replace`` calls that strips the
    spacing the join introduced (see the NOTE below about the exact
    characters involved).

    Args:
        nlp_model: a loaded spaCy pipeline (English or Polish in this script).
        text: the raw glossary term.

    Returns:
        The normalized lemma string stored in the ``*_lem`` glossary columns.
    """
    joined = ' '.join(token.lemma_ for token in nlp_model(text))
    # NOTE(review): the hosting page flagged ambiguous Unicode characters in
    # this file; the second pattern below may originally have been a
    # look-alike of ' ' (e.g. a non-breaking space) - confirm. Taken literally
    # as an ASCII space it deletes every space, which also turns the
    # ' / ', ' ( ' and ' ) ' replacements after it into no-ops.
    return (joined.replace(' - ', '-')
            .replace(' ', '')
            .replace(' / ', '/')
            .replace(' ( ', '(')
            .replace(' ) ', ')'))


print('lemmatizing glossary')
glossary = pd.read_csv('glossary.tsv', sep='\t', header=None, names=['source', 'result'])
# English terms are lemmatized with the English model, translations with the Polish one.
source_lemmatized = [_join_glossary_lemmas(spacy_nlp_en, word) for word in glossary['source']]
result_lemmatized = [_join_glossary_lemmas(spacy_nlp_pl, word) for word in glossary['result']]
glossary['source_lem'] = source_lemmatized
glossary['result_lem'] = result_lemmatized
glossary = glossary[['source', 'source_lem', 'result', 'result_lem']]
# Keep the default integer index: glossary_lem.tsv is read back with
# index_col=0, and 'source_lem' must remain an ordinary column there
# (so do not set_index('source_lem') before writing).
glossary.to_csv('glossary_lem.tsv', sep='\t')
dev_path = 'dev-0/'
print('lemmatizing corpus ' + dev_path)
# Punctuation lemmas dropped from the corpus. Membership is tested with
# `lemma not in skip_chars`, i.e. a substring test on this string, so the
# empty lemma and multi-character runs such as ',.' are dropped as well.
skip_chars = ''',./!?'''


def _lemmatize_corpus_file(path, nlp_model):
    """Return one normalized lemma string per line of the text file at *path*.

    Each line is run through *nlp_model*; lemmas found in ``skip_chars`` are
    dropped, the rest are joined with single spaces and passed through the
    same replacement chain used for the glossary. The file is read as UTF-8
    explicitly (the Polish side contains non-ASCII text, and the platform
    default encoding is locale-dependent).

    NOTE(review): the line's trailing newline is presumably kept as a
    whitespace token's lemma (it is not in ``skip_chars``). The output step
    writes these strings without adding a newline - confirm.
    """
    lemmatized = []
    with open(path, 'r', encoding='utf-8') as corpus:
        for line in corpus:
            kept = ' '.join(
                token.lemma_ for token in nlp_model(line) if token.lemma_ not in skip_chars
            )
            # NOTE(review): the ' ' pattern below may be an ambiguous-Unicode
            # look-alike in the original file (the hosting page warned about
            # such characters) - confirm. As an ASCII space it removes every
            # space and makes the later replacements no-ops.
            lemmatized.append(kept.replace(' - ', '-')
                              .replace(' ', '')
                              .replace(' / ', '/')
                              .replace(' ( ', '(')
                              .replace(' ) ', ')'))
    return lemmatized


file_lemmatized = _lemmatize_corpus_file(dev_path + 'in.tsv', spacy_nlp_en)
file_pl_lemmatized = _lemmatize_corpus_file(dev_path + 'expected.tsv', spacy_nlp_pl)
# glossary: reload the lemmatized glossary written earlier in this script by
# glossary.to_csv (column 0 is its default integer index).
glossary = pd.read_csv('glossary_lem.tsv', sep='\t', header=0, index_col=0)
# Cross-validation split: every HOLDOUT_EVERY-th entry (positions 0, 6, 12, ...)
# is held out; only the remaining entries are used to build the matchers.
HOLDOUT_EVERY = 6
train_glossary = glossary.iloc[
    [position for position in range(len(glossary)) if position % HOLDOUT_EVERY != 0]
]
# add rules to English matcher: each lemmatized English term is its own label
nlp = spacy.blank("en")
matcher = FuzzyMatcher(nlp.vocab)
for word in train_glossary['source_lem']:
    matcher.add(word, [nlp(word)])
# add rules to Polish matcher: each lemmatized Polish translation is its own label
nlp_pl = spacy.blank("pl")
matcher_pl = FuzzyMatcher(nlp_pl.vocab)
for word in train_glossary['result_lem']:
    matcher_pl.add(word, [nlp_pl(word)])
# en: output source lines (injected variants or the untouched lemmatized line).
# translation_line_counts[k]: how many consecutive entries of `en` belong to input line k.
en = []
translation_line_counts = []
# source_lem -> (result, result_lem) of its first training-glossary entry.
# Built once instead of scanning the DataFrame with a boolean mask per match;
# using the first entry also avoids gluing several whole sentences together
# when a term occurs more than once in the glossary.
train_lookup = {}
for source_lem, result, result_lem in zip(
        train_glossary['source_lem'], train_glossary['result'], train_glossary['result_lem']):
    train_lookup.setdefault(source_lem, (str(result), result_lem))
total_lines = len(file_lemmatized)
for line_id, line_lem in enumerate(file_lemmatized):
    if line_id % 100 == 0:
        print('injecting glossary: ' + str(line_id) + "/" + str(total_lines), end='\r')
    doc = nlp(line_lem)
    # Labels matched in the reference translation; computed lazily, at most once per line.
    pl_labels = None
    line_counter = 0
    for match in matcher(doc):
        # Only the first four fields are used (label, start, end, ratio);
        # slicing tolerates spaczz versions that append extra trailing fields.
        match_id, _start, end, ratio = match[:4]
        if ratio <= 90:
            continue
        if pl_labels is None:
            doc_pl = nlp_pl(file_pl_lemmatized[line_id])
            pl_labels = {pl_match[0] for pl_match in matcher_pl(doc_pl)}
        result, result_lem = train_lookup[match_id]
        # Inject only when the expected translation really occurs in the
        # reference; one variant per English match, however often the Polish
        # term occurs (repeats would only be identical duplicate lines).
        if result_lem in pl_labels:
            line_counter += 1
            en.append(doc[:end].text + ' ' + result + ' ' + doc[end:].text)
    if line_counter == 0:
        # No injection possible: keep the lemmatized line as a single entry.
        line_counter = 1
        en.append(line_lem)
    translation_line_counts.append(line_counter)
print('saving files')
# Reference translations, one per line of expected.tsv. They are read as raw
# lines (exactly as the lemmatization step iterates the file) rather than with
# pandas, which skips blank lines and interprets quotes/tabs and would let the
# references drift out of step with translation_line_counts.
with open(dev_path + 'expected.tsv', 'r', encoding='utf-8') as reference_file:
    translations = [line.rstrip('\n') for line in reference_file]
if len(translations) < len(translation_line_counts):
    raise ValueError(
        'expected.tsv has fewer lines (' + str(len(translations))
        + ') than in.tsv (' + str(len(translation_line_counts)) + ')'
    )
with open(dev_path + 'in.tsv.injected.crossvalidated', 'w', encoding='utf-8') as file_en:
    with open(dev_path + 'expected.tsv.injected.crossvalidated', 'w', encoding='utf-8') as file_pl:
        position = 0
        for sentence, count in zip(translations, translation_line_counts):
            # `en` stores `count` consecutive variants for this input line.
            variants = en[position:position + count]
            position += count
            # Write each distinct variant once (first-seen order), always paired
            # with this line's reference, so duplicates can never shift later pairs.
            for variant in dict.fromkeys(variants):
                if not variant.endswith('\n'):
                    variant += '\n'
                file_en.write(variant)
                file_pl.write(sentence + '\n')