minor fixes

This commit is contained in:
Kacper 2022-05-02 17:13:49 +02:00
parent e88c55edb7
commit 8b9d005112
2 changed files with 6 additions and 5 deletions

View File

@@ -108,7 +108,8 @@ def write_to_files(lines_contents):
l = (len(lines_contents) / 10) * 8 l = (len(lines_contents) / 10) * 8
contents_train = lines_contents[:int(l)] contents_train = lines_contents[:int(l)]
contents_test = lines_contents[int(l):] contents_test = lines_contents[int(l):]
with open('train.conllu', 'a', encoding='utf-8') as train_f, open('test.conllu', 'a+', encoding='utf-8') as test_f: with open('train-pl.conllu', 'a', encoding='utf-8') as train_f, open('test-pl.conllu', 'a+', encoding='utf-8') as \
test_f:
for content in contents_train: for content in contents_train:
train_f.write(format_content(content)) train_f.write(format_content(content))
for content in contents_test: for content in contents_test:
@@ -116,8 +117,7 @@ def write_to_files(lines_contents):
def main(): def main():
# path = sys.argv[1] path = sys.argv[1]
path = rf'C:\Users\Kacper Dudzic\Desktop\dane'
dir = Path(rf'{path}') dir = Path(rf'{path}')
for file in dir.glob('*'): for file in dir.glob('*'):
processed_contents = process_file(file) processed_contents = process_file(file)

View File

@@ -68,6 +68,7 @@ embeddings = StackedEmbeddings(embeddings=embedding_types)
tagger = SequenceTagger(hidden_size=256, embeddings=embeddings, tagger = SequenceTagger(hidden_size=256, embeddings=embeddings,
tag_dictionary=tag_dictionary, tag_dictionary=tag_dictionary,
tag_type='slot', use_crf=True) tag_type='slot', use_crf=True)
""" """
trainer = ModelTrainer(tagger, corpus) trainer = ModelTrainer(tagger, corpus)
trainer.train('slot-model-pl', trainer.train('slot-model-pl',
@@ -80,5 +81,5 @@ try:
model = SequenceTagger.load('slot-model-pl/best-model.pt') model = SequenceTagger.load('slot-model-pl/best-model.pt')
except: except:
model = SequenceTagger.load('slot-model-pl/final-model.pt') model = SequenceTagger.load('slot-model-pl/final-model.pt')
print(tabulate(predict(model, 'Jeden bilet na imię Jan Kowalski na film Batman'.split()), tablefmt='html')) print(tabulate(predict(model, 'Jeden bilet na imię Jan Kowalski na film Batman'.split())))