minor fixes
This commit is contained in:
parent
e88c55edb7
commit
8b9d005112
@ -108,7 +108,8 @@ def write_to_files(lines_contents):
|
|||||||
l = (len(lines_contents) / 10) * 8
|
l = (len(lines_contents) / 10) * 8
|
||||||
contents_train = lines_contents[:int(l)]
|
contents_train = lines_contents[:int(l)]
|
||||||
contents_test = lines_contents[int(l):]
|
contents_test = lines_contents[int(l):]
|
||||||
with open('train.conllu', 'a', encoding='utf-8') as train_f, open('test.conllu', 'a+', encoding='utf-8') as test_f:
|
with open('train-pl.conllu', 'a', encoding='utf-8') as train_f, open('test-pl.conllu', 'a+', encoding='utf-8') as \
|
||||||
|
test_f:
|
||||||
for content in contents_train:
|
for content in contents_train:
|
||||||
train_f.write(format_content(content))
|
train_f.write(format_content(content))
|
||||||
for content in contents_test:
|
for content in contents_test:
|
||||||
@ -116,8 +117,7 @@ def write_to_files(lines_contents):
|
|||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
# path = sys.argv[1]
|
path = sys.argv[1]
|
||||||
path = rf'C:\Users\Kacper Dudzic\Desktop\dane'
|
|
||||||
dir = Path(rf'{path}')
|
dir = Path(rf'{path}')
|
||||||
for file in dir.glob('*'):
|
for file in dir.glob('*'):
|
||||||
processed_contents = process_file(file)
|
processed_contents = process_file(file)
|
||||||
|
@ -68,6 +68,7 @@ embeddings = StackedEmbeddings(embeddings=embedding_types)
|
|||||||
tagger = SequenceTagger(hidden_size=256, embeddings=embeddings,
|
tagger = SequenceTagger(hidden_size=256, embeddings=embeddings,
|
||||||
tag_dictionary=tag_dictionary,
|
tag_dictionary=tag_dictionary,
|
||||||
tag_type='slot', use_crf=True)
|
tag_type='slot', use_crf=True)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
trainer = ModelTrainer(tagger, corpus)
|
trainer = ModelTrainer(tagger, corpus)
|
||||||
trainer.train('slot-model-pl',
|
trainer.train('slot-model-pl',
|
||||||
@ -80,5 +81,5 @@ try:
|
|||||||
model = SequenceTagger.load('slot-model-pl/best-model.pt')
|
model = SequenceTagger.load('slot-model-pl/best-model.pt')
|
||||||
except:
|
except:
|
||||||
model = SequenceTagger.load('slot-model-pl/final-model.pt')
|
model = SequenceTagger.load('slot-model-pl/final-model.pt')
|
||||||
|
|
||||||
print(tabulate(predict(model, 'Jeden bilet na imię Jan Kowalski na film Batman'.split()), tablefmt='html'))
|
print(tabulate(predict(model, 'Jeden bilet na imię Jan Kowalski na film Batman'.split())))
|
||||||
|
Loading…
Reference in New Issue
Block a user