2024-06-10 20:05:18 +02:00
|
|
|
from conllu import parse_incr
|
|
|
|
from flair.data import Corpus
|
|
|
|
from flair.embeddings import StackedEmbeddings
|
|
|
|
from flair.embeddings import WordEmbeddings
|
|
|
|
from flair.embeddings import CharacterEmbeddings
|
|
|
|
from flair.embeddings import FlairEmbeddings
|
|
|
|
from flair.models import SequenceTagger
|
|
|
|
from flair.trainers import ModelTrainer
|
|
|
|
from flair.data import Sentence
|
|
|
|
from flair.datasets import FlairDatapointDataset
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
|
|
|
|
|
|
|
def nolabel2o(line, i):
    """Return column ``i`` of ``line``, mapping the placeholder ``'NoLabel'`` to ``'O'``.

    The ``(line, i)`` signature matches what a conllu field parser receives,
    so this can be passed in a ``field_parsers`` mapping.

    Args:
        line: Indexable sequence of column values for one token line.
        i: Index of the column to read.

    Returns:
        ``'O'`` when ``line[i] == 'NoLabel'``, otherwise ``line[i]`` unchanged.
    """
    value = line[i]
    return 'O' if value == 'NoLabel' else value
|
|
|
|
|
|
|
|
|
|
|
|
def conllu2flair(sentences, label=None):
    """Convert parsed CoNLL-U sentences into a Flair dataset.

    Args:
        sentences: Iterable of sentences, each an iterable of token mappings
            that have at least a ``'form'`` key.
        label: Name of the token column to turn into Flair labels.
            ``"frame"`` selects the per-token frame conversion; any other
            value, including ``None``, selects the BIO slot conversion.

    Returns:
        The ``FlairDatapointDataset`` built by the selected converter.
    """
    converter = conllu2flair_frame if label == "frame" else conllu2flair_slot
    return converter(sentences, label)
|
|
|
|
|
|
|
|
|
|
|
|
def conllu2flair_frame(sentences, label=None):
    """Build a Flair dataset in which every token carries its own label span.

    Each input sentence becomes a whitespace-joined ``Sentence`` (no
    tokenizer), and every token position ``i`` gets a single-token span
    labelled with ``sentence[i][label]``.

    Args:
        sentences: Iterable of sentences, each an indexable sequence of token
            mappings with a ``'form'`` key and a ``label`` key.
        label: Name of the token column to read and of the Flair label type
            to write. Must be a key present on every token.

    Returns:
        ``FlairDatapointDataset`` wrapping the labelled sentences.
    """
    fsentences = []
    for sentence in sentences:
        text = ' '.join(token["form"] for token in sentence)
        fsentence = Sentence(text, use_tokenizer=False)

        # One single-token span per Flair token, labelled from the matching
        # CoNLL-U token at the same position.
        for idx, _ in enumerate(fsentence):
            fsentence[idx:idx + 1].add_label(label, sentence[idx][label])

        fsentences.append(fsentence)

    return FlairDatapointDataset(fsentences)
|
|
|
|
|
|
|
|
|
|
|
|
def _bio_spans(tags):
    """Yield ``(start, end, tag)`` for each BIO span in ``tags`` (``end`` inclusive)."""
    start = end = current = None
    for idx, raw in enumerate(tags):
        is_text = isinstance(raw, str)
        if is_text and raw.startswith('B-'):
            # A new span begins; emit the one that was open, if any.
            if start is not None:
                yield start, end, current
            start, end, current = idx, idx, raw[2:]
        elif is_text and raw.startswith('I-'):
            # Continuation only extends an open span; an orphan I- is ignored.
            if start is not None:
                end = idx
        else:
            # 'O' or any unrecognised/non-string value closes the open span so
            # that a later I- cannot stretch it across the gap.
            if start is not None:
                yield start, end, current
            start = end = current = None
    if start is not None:
        yield start, end, current


def conllu2flair_slot(sentences, label=None):
    """Build a Flair dataset with multi-token spans decoded from BIO tags.

    Each input sentence becomes a whitespace-joined ``Sentence`` (no
    tokenizer). When ``label`` is truthy, the ``label`` column of the tokens
    is read as BIO tags (``B-x`` opens a span, ``I-x`` extends it, anything
    else closes it) and each resulting span is labelled with the tag text
    that follows ``B-``.

    Args:
        sentences: Iterable of sentences, each an iterable of token mappings
            with a ``'form'`` key and, if ``label`` is given, a ``label`` key.
        label: Name of the BIO column and of the Flair label type. When falsy
            the sentences are created without any labels.

    Returns:
        ``FlairDatapointDataset`` wrapping the sentences.
    """
    fsentences = []
    for sentence in sentences:
        fsentence = Sentence(' '.join(token['form'] for token in sentence), use_tokenizer=False)

        if label:
            # zip() bounds the tag sequence by the shorter of the two token
            # lists, so span indices always fall inside fsentence.
            tags = [token[label] for token, _ in zip(sentence, fsentence)]
            for start, end, tag in _bio_spans(tags):
                fsentence[start:end + 1].add_label(label, tag)

        fsentences.append(fsentence)

    return FlairDatapointDataset(fsentences)
|
|
|
|
|
|
|
|
|
|
|
|
def predict_frame(model, sentence, label_type):
    """Predict a single frame label for a tokenised utterance.

    The model labels spans of the sentence; the label value that occurs on
    the most spans is returned. Ties go to the value seen first.

    Args:
        model: Object exposing ``predict(sentence)`` that annotates a Flair
            sentence in place.
        sentence: Sequence of word strings. May be empty.
        label_type: Flair label type to read from the predicted spans.

    Returns:
        The most frequent predicted label value, or ``'unknown'`` when the
        input is empty or the model predicts no spans.
    """
    if not sentence:
        return 'unknown'

    csentence = [{'form': word, 'slot': 'O'} for word in sentence]
    # No label is passed, so the sentence is built without gold annotations.
    fsentence = conllu2flair([csentence])[0]
    model.predict(fsentence)

    label_cnt = {}
    for span in fsentence.get_spans(label_type):
        tag = span.get_label(label_type).value
        label_cnt[tag] = label_cnt.get(tag, 0) + 1

    # max() raises ValueError on an empty mapping; fall back to the same
    # sentinel as the empty-input case when nothing was predicted.
    if not label_cnt:
        return 'unknown'

    return max(label_cnt, key=label_cnt.get)
|
|
|
|
|
|
|
|
|
|
|
|
def predict_slot(model, sentence, label_type):
    """Predict BIO slot tags for a tokenised utterance.

    Args:
        model: Object exposing ``predict(sentence)`` that annotates a Flair
            sentence in place.
        sentence: Sequence of word strings. May be empty.
        label_type: Flair label type to read from the predicted spans.

    Returns:
        A list with one ``{'form': ..., 'slot': ...}`` dict per input word.
        ``'slot'`` is ``'B-<tag>'`` for the first token of a predicted span,
        ``'I-<tag>'`` for its remaining tokens and ``'O'`` elsewhere. For
        empty input a single-element list
        ``[{'form': '', 'slot': 'unknown'}]`` is returned.
    """
    if not sentence:
        # A list, matching the container type of the normal return path
        # (a stray trailing comma previously produced a 1-tuple here).
        return [{'form': '', 'slot': 'unknown'}]

    csentence = [{'form': word, 'slot': 'O'} for word in sentence]
    # No label is passed, so the sentence is built without gold annotations.
    fsentence = conllu2flair([csentence])[0]
    model.predict(fsentence)

    for span in fsentence.get_spans(label_type):
        # Read the label under the same type the spans were fetched with.
        tag = span.get_label(label_type).value
        first, *rest = span.tokens
        # Token.idx is 1-based, csentence is 0-based.
        csentence[first.idx - 1]['slot'] = f'B-{tag}'
        for token in rest:
            csentence[token.idx - 1]['slot'] = f'I-{tag}'

    return csentence
|
|
|
|
|
|
|
|
|
|
|
|
class Model:
    """Trains a Flair ``SequenceTagger`` from a pair of CoNLL-U style files."""

    # Column layout passed to the conllu parser for both files.
    _FIELDS = ('id', 'form', 'frame', 'slot')

    def __init__(self, train_dataset, test_dataset):
        """Store the dataset locations; nothing is read until training.

        Args:
            train_dataset: Path of the training file.
            test_dataset: Path of the test file.
        """
        self.train_dataset = train_dataset
        self.test_dataset = test_dataset

    @classmethod
    def _read_conllu(cls, path, field_parsers):
        """Parse the UTF-8 file at ``path`` into a list of sentences."""
        with open(path, encoding='utf-8') as f:
            return list(parse_incr(f, fields=list(cls._FIELDS), field_parsers=field_parsers))

    def train_model(self, label_type, field_parsers=None):
        """Train a tagger for ``label_type`` and write it to ``'<label_type>-model'``.

        Args:
            label_type: Token column to learn; also used as the Flair label
                type and as the prefix of the output directory name.
            field_parsers: Optional mapping of column name to conllu field
                parser. ``None`` (the default) is treated as an empty
                mapping.
        """
        # Avoid a shared mutable default; an empty dict is what was passed
        # to the parser before when the caller supplied nothing.
        parsers = {} if field_parsers is None else field_parsers

        if torch.cuda.is_available():
            # cuDNN is switched off and deterministic mode is requested
            # whenever a CUDA device is present.
            torch.backends.cudnn.enabled = False
            torch.backends.cudnn.benchmark = False
            torch.backends.cudnn.deterministic = True

        trainset = self._read_conllu(self.train_dataset, parsers)
        testset = self._read_conllu(self.test_dataset, parsers)

        corpus = Corpus(train=conllu2flair(trainset, label_type), test=conllu2flair(testset, label_type))
        label_dictionary = corpus.make_label_dictionary(label_type=label_type)

        embedding_types = [
            WordEmbeddings('pl'),
            FlairEmbeddings('pl-forward'),
            FlairEmbeddings('pl-backward'),
            CharacterEmbeddings(),
        ]
        embeddings = StackedEmbeddings(embeddings=embedding_types)

        tagger = SequenceTagger(
            hidden_size=256,
            embeddings=embeddings,
            tag_dictionary=label_dictionary,
            tag_type=label_type,
            use_crf=True,
            tag_format="BIO",
        )

        trainer = ModelTrainer(tagger, corpus)
        trainer.train(
            f'{label_type}-model',
            learning_rate=0.01,
            mini_batch_size=16,
            max_epochs=75,
            train_with_dev=False,
        )
|
|
|
|
|
|
|
|
|
2024-06-11 19:30:24 +02:00
|
|
|
# model = Model(train_dataset='../data/test_dialog.conllu', test_dataset='../data/test_dialog.conllu')
|
|
|
|
# model.train_model('frame')
|
2024-06-10 22:27:30 +02:00
|
|
|
# model2 = Model(train_dataset='../data/test_dialog.conllu', test_dataset='../data/test_dialog.conllu')
|
|
|
|
# model2.train_model('slot', field_parsers={'slot': nolabel2o})
|