sklep-internetowy-systemy-d.../chatbot/models/nlu_train2.py

139 lines
4.9 KiB
Python
Raw Normal View History

2024-06-10 20:05:18 +02:00
from conllu import parse_incr
from flair.data import Corpus
from flair.embeddings import StackedEmbeddings
from flair.embeddings import WordEmbeddings
from flair.embeddings import CharacterEmbeddings
from flair.embeddings import FlairEmbeddings
from flair.models import SequenceTagger
from flair.trainers import ModelTrainer
from flair.data import Sentence
from flair.datasets import FlairDatapointDataset
import torch
def nolabel2o(line, i):
    """Column parser that turns the placeholder ``'NoLabel'`` into ``'O'``.

    It is passed as ``field_parsers={'slot': nolabel2o}`` when training the
    slot model, so unlabeled tokens become the BIO "outside" tag.

    Args:
        line: the split columns of the current CoNLL-U row.
        i: index of the column being parsed.

    Returns:
        ``'O'`` if the column holds ``'NoLabel'``, otherwise the column value
        unchanged.
    """
    value = line[i]
    if value == 'NoLabel':
        return 'O'
    return value
def conllu2flair(sentences, label=None):
    """Convert parsed CoNLL-U sentences into a Flair dataset.

    ``label == "frame"`` selects per-token labelling
    (``conllu2flair_frame``). Every other value, including ``None``, selects
    BIO span decoding (``conllu2flair_slot``), which attaches no labels when
    ``label`` is falsy.
    """
    converter = conllu2flair_frame if label == "frame" else conllu2flair_slot
    return converter(sentences, label)
def conllu2flair_frame(sentences, label=None):
    """Build a Flair dataset where every token carries its own *label* value.

    Each token is wrapped in a one-token span, and that span is labelled with
    the token's value in the ``label`` column (e.g. the dialogue frame).

    Args:
        sentences: iterable of parsed CoNLL-U sentences, i.e. sequences of
            token mappings with at least ``"form"`` and ``label`` keys.
        label: name of the column copied onto each token. It must be given;
            ``None`` is only accepted for signature compatibility.

    Returns:
        ``FlairDatapointDataset`` with one Flair ``Sentence`` per input
        sentence.
    """
    fsentences = []
    for sentence in sentences:
        # Hand Flair the forms as a list so it creates exactly one token per
        # CoNLL-U token. Joining with spaces and letting Flair re-split would
        # misalign the labels whenever a form contains whitespace.
        fsentence = Sentence([token["form"] for token in sentence])
        for position, token in enumerate(sentence):
            fsentence[position:position + 1].add_label(label, token[label])
        fsentences.append(fsentence)
    return FlairDatapointDataset(fsentences)
def _bio_spans(tags):
    """Decode BIO tags into ``(start, stop, type)`` spans, with ``stop`` exclusive.

    The decoding is lenient:
    - an ``I-X`` that does not continue an open ``X`` span starts a new span;
    - any tag that is neither ``B-`` nor ``I-`` (``O``, ``NoLabel``, ...)
      closes the current span.

    Tags are assumed to be strings.
    """
    tags = list(tags)
    spans = []
    start = None
    current = None
    for idx, tag in enumerate(tags):
        if tag.startswith('B-'):
            if start is not None:
                spans.append((start, idx, current))
            start, current = idx, tag[2:]
        elif tag.startswith('I-'):
            if start is None or tag[2:] != current:
                # Orphan I- tag, or a change of type: open a fresh span
                # instead of dropping the token or merging it into X.
                if start is not None:
                    spans.append((start, idx, current))
                start, current = idx, tag[2:]
        else:
            if start is not None:
                spans.append((start, idx, current))
            start, current = None, None
    if start is not None:
        spans.append((start, len(tags), current))
    return spans


def conllu2flair_slot(sentences, label=None):
    """Build a Flair dataset whose spans come from the BIO tags in column *label*.

    Args:
        sentences: iterable of parsed CoNLL-U sentences, i.e. sequences of
            token mappings with a ``"form"`` key and, if *label* is given, a
            ``label`` key holding a BIO tag such as ``'B-date'``, ``'I-date'``
            or ``'O'``.
        label: column holding the BIO tags. When falsy (e.g. ``None``, as
            the ``predict_*`` helpers pass), sentences are built without any
            labels.

    Returns:
        ``FlairDatapointDataset`` with one Flair ``Sentence`` per input
        sentence. Each decoded span is labelled ``label`` with its type
        (the tag without the ``B-``/``I-`` prefix).
    """
    fsentences = []
    for sentence in sentences:
        # A list of forms keeps Flair token positions aligned 1:1 with the
        # CoNLL-U tokens, even if a form contains whitespace.
        fsentence = Sentence([token['form'] for token in sentence])
        if label:
            for start, stop, tag in _bio_spans(token[label] for token in sentence):
                fsentence[start:stop].add_label(label, tag)
        fsentences.append(fsentence)
    return FlairDatapointDataset(fsentences)
def predict_frame(model, sentence, label_type):
    """Predict a single frame label for a whole utterance by majority vote.

    The tagger labels spans of the utterance. The label value found on the
    most predicted spans wins, and ties go to the value seen first.

    Args:
        model: object with a Flair-style ``predict`` method. This is
            presumably a trained ``SequenceTagger``; only ``predict`` is used
            here.
        sentence: sequence of word strings, one per token.
        label_type: label type the tagger predicts (e.g. ``'frame'``).

    Returns:
        The winning label value.

    Raises:
        ValueError: if the tagger predicted no ``label_type`` span at all
            (e.g. for an empty *sentence*).
    """
    csentence = [{'form': word, 'slot': 'O'} for word in sentence]
    # No label is passed, so conllu2flair only builds the Flair Sentence and
    # attaches no gold labels.
    fsentence = conllu2flair([csentence])[0]
    model.predict(fsentence)
    votes = {}
    for span in fsentence.get_spans(label_type):
        value = span.get_label(label_type).value
        votes[value] = votes.get(value, 0) + 1
    if not votes:
        # max() would raise an opaque "empty sequence" ValueError here; keep
        # the same exception type but say what actually went wrong.
        forms = [token['form'] for token in csentence]
        raise ValueError(f'no {label_type!r} label predicted for sentence {forms!r}')
    return max(votes, key=votes.get)
def predict_slot(model, sentence, label_type):
    """Tag each word of *sentence* with a predicted BIO slot label.

    Args:
        model: object with a Flair-style ``predict`` method. This is
            presumably a trained ``SequenceTagger``; only ``predict`` is used
            here.
        sentence: sequence of word strings, one per token.
        label_type: label type the tagger predicts (e.g. ``'slot'``).

    Returns:
        A list with one ``{'form': word, 'slot': tag}`` dict per word.
        ``tag`` is ``'B-<type>'`` on a predicted span's first token,
        ``'I-<type>'`` on the rest of that span, and ``'O'`` elsewhere. The
        result key stays ``'slot'`` whatever ``label_type`` is.
    """
    csentence = [{'form': word, 'slot': 'O'} for word in sentence]
    # No label is passed, so conllu2flair only builds the Flair Sentence.
    fsentence = conllu2flair([csentence])[0]
    model.predict(fsentence)
    for span in fsentence.get_spans(label_type):
        # Read the value from the same label type the spans were selected by.
        # The previous hard-coded 'slot' looked up a different label type
        # whenever label_type was not 'slot'.
        tag = span.get_label(label_type).value
        first, *rest = span.tokens
        # Flair token indices are 1-based; csentence is 0-based.
        csentence[first.idx - 1]['slot'] = f'B-{tag}'
        for token in rest:
            csentence[token.idx - 1]['slot'] = f'I-{tag}'
    return csentence
class Model:
    """Train Flair sequence taggers from a pair of CoNLL-U dialogue files.

    Each ``train_model`` call trains one tagger for a single column (e.g.
    ``'frame'`` or ``'slot'``). The trainer is given ``'<label_type>-model'``,
    a path relative to the current working directory, as its output location.
    """

    # Column layout of the CoNLL-U files read by train_model.
    CONLLU_FIELDS = ('id', 'form', 'frame', 'slot')

    def __init__(self, train_dataset, test_dataset):
        """Store the paths of the training and test CoNLL-U files."""
        self.train_dataset = train_dataset
        self.test_dataset = test_dataset

    def _read_conllu(self, path, field_parsers):
        """Parse every sentence of the UTF-8 CoNLL-U file at *path* into a list."""
        with open(path, encoding='utf-8') as f:
            return list(parse_incr(f, fields=list(self.CONLLU_FIELDS),
                                   field_parsers=field_parsers))

    def train_model(self, label_type, field_parsers=None):
        """Train a Flair ``SequenceTagger`` for one CoNLL-U column.

        Args:
            label_type: column to learn. ``'frame'`` is converted token by
                token; any other column is decoded as BIO spans (see
                ``conllu2flair``).
            field_parsers: optional ``conllu`` per-column parsers, e.g.
                ``{'slot': nolabel2o}``. ``None`` means a fresh empty dict,
                which replaces the old shared mutable default.

        Returns:
            Whatever ``ModelTrainer.train`` returns. This is presumably a
            dict of final scores; check it against the installed Flair
            version.
        """
        if field_parsers is None:
            field_parsers = {}
        if torch.cuda.is_available():
            # NOTE(review): this disables cuDNN and forces deterministic
            # kernels, presumably for reproducibility. It mutates
            # process-global torch settings; confirm it is still wanted.
            torch.backends.cudnn.enabled = False
            torch.backends.cudnn.benchmark = False
            torch.backends.cudnn.deterministic = True
        trainset = self._read_conllu(self.train_dataset, field_parsers)
        testset = self._read_conllu(self.test_dataset, field_parsers)
        corpus = Corpus(train=conllu2flair(trainset, label_type),
                        test=conllu2flair(testset, label_type))
        label_dictionary = corpus.make_label_dictionary(label_type=label_type)
        embedding_types = [
            WordEmbeddings('pl'),
            FlairEmbeddings('pl-forward'),
            FlairEmbeddings('pl-backward'),
            CharacterEmbeddings(),
        ]
        embeddings = StackedEmbeddings(embeddings=embedding_types)
        tagger = SequenceTagger(hidden_size=256, embeddings=embeddings,
                                tag_dictionary=label_dictionary,
                                tag_type=label_type, use_crf=True, tag_format="BIO")
        # One trainer per call, used for both frame and slot models.
        trainer = ModelTrainer(tagger, corpus)
        return trainer.train(f'{label_type}-model', learning_rate=0.01,
                             mini_batch_size=16, max_epochs=75,
                             train_with_dev=False)
if __name__ == '__main__':
    # Guarded so that importing this module (e.g. for predict_frame /
    # predict_slot) does not retrain both models as a side effect.
    # The paths are relative to the current working directory.
    # NOTE(review): the training and test sets are the same file, so the
    # reported test scores measure fit on the training data, not
    # generalisation. Point test_dataset at a held-out file if one exists.
    model = Model(train_dataset='../data/test_dialog.conllu', test_dataset='../data/test_dialog.conllu')
    model.train_model('frame')
    # 'NoLabel' entries in the slot column are mapped to 'O' before BIO decoding.
    model.train_model('slot', field_parsers={'slot': nolabel2o})