chat-restaruacja/src/utils/nlu_utils.py

100 lines
3.4 KiB
Python
Raw Normal View History

2024-05-22 23:45:33 +02:00
from flair.data import Sentence
from flair.datasets import FlairDatapointDataset
def nolabel2o(line, i):
    """Return ``line[i]``, with the placeholder ``'NoLabel'`` mapped to ``'O'``.

    Args:
        line: Any indexable container (list, dict, ...) of tag strings.
        i: Index or key to read from *line*.

    Returns:
        ``'O'`` when the stored value is ``'NoLabel'``, otherwise the stored value.
    """
    value = line[i]
    if value == 'NoLabel':
        return 'O'
    return value
def conllu2flair(sentences, label=None):
    """Convert annotated sentences into a Flair dataset.

    ``label == "frame"`` selects the per-token frame converter; any other value
    (including ``None``) selects the BIO slot converter.

    Args:
        sentences: Iterable of sentences, each a sequence of token mappings.
        label: Annotation column to convert, or ``None`` for no labels.

    Returns:
        Whatever the selected converter returns (a ``FlairDatapointDataset``).
    """
    converter = conllu2flair_frame if label == "frame" else conllu2flair_slot
    return converter(sentences, label)
def conllu2flair_frame(sentences, label=None):
    """Convert token-level frame annotations into a Flair dataset.

    Every token becomes its own one-token span labelled with ``token[label]``.

    Args:
        sentences: Iterable of sentences; each sentence is a sequence of token
            mappings with a ``'form'`` key and a *label* key (presumably
            ``conllu`` TokenLists -- confirm against callers).
        label: Annotation column to read (e.g. ``"frame"``). It must be a key
            present on every token.

    Returns:
        ``FlairDatapointDataset`` with one ``Sentence`` per input sentence.
    """
    fsentences = []
    for sentence in sentences:
        # Build from the token list rather than a space-joined string. A form
        # containing whitespace (CoNLL-U allows spaces in FORM) would otherwise
        # be re-split by Flair and shift its indices out of step with *sentence*.
        fsentence = Sentence([token["form"] for token in sentence], use_tokenizer=False)
        for position, token in enumerate(sentence):
            fsentence[position:position + 1].add_label(label, token[label])
        fsentences.append(fsentence)
    return FlairDatapointDataset(fsentences)
def _bio_spans(tags):
    """Yield ``(start, end, tag)`` chunks (``end`` inclusive) from a list of BIO tags.

    Decoding is conlleval-style:
      * ``B-X`` always opens a new chunk;
      * ``I-X`` continues an open chunk of the same type X. It opens a new
        chunk when no chunk is open or the open chunk has a different type;
      * any other value (``'O'``, ``'NoLabel'``, ...) closes the open chunk.
    """
    start = None
    current = None
    for idx, value in enumerate(tags):
        if value.startswith('I-') and start is not None and value[2:] == current:
            continue
        if start is not None:
            yield start, idx - 1, current
            start = current = None
        if value.startswith(('B-', 'I-')):
            start, current = idx, value[2:]
    if start is not None:
        yield start, len(tags) - 1, current


def conllu2flair_slot(sentences, label=None):
    """Convert BIO-tagged slot annotations into a Flair dataset of labelled spans.

    Args:
        sentences: Iterable of sentences; each sentence is a sequence of token
            mappings with a ``'form'`` key and, when *label* is given, a
            *label* key holding a BIO tag (presumably ``conllu`` TokenLists --
            confirm against callers).
        label: Annotation column holding BIO tags. When falsy, sentences are
            converted without any labels (as done for prediction input).

    Returns:
        ``FlairDatapointDataset`` with one ``Sentence`` per input sentence. Each
        BIO chunk becomes a span labelled with its type under *label*.
    """
    fsentences = []
    for sentence in sentences:
        # Build from the token list so forms containing whitespace cannot make
        # Flair's tokens drift out of alignment with the BIO tags.
        fsentence = Sentence([token['form'] for token in sentence], use_tokenizer=False)
        if label:
            tags = [token[label] for token in sentence]
            for start, end, tag in _bio_spans(tags):
                fsentence[start:end + 1].add_label(label, tag)
        fsentences.append(fsentence)
    return FlairDatapointDataset(fsentences)
def __predict(model, csentence):
    """Wrap *csentence* in an unlabelled Flair sentence, run *model* on it, and return it.

    ``model.predict`` is expected to attach its predictions to the sentence in
    place; the same ``Sentence`` object is returned to the caller.
    """
    dataset = conllu2flair([csentence])
    fsentence = dataset[0]
    model.predict(fsentence)
    return fsentence
def __csentence(sentence, label_type):
    """Build CoNLL-U-like token dicts for a list of words.

    For ``label_type == "frame"`` each dict has only ``'form'``. For any other
    label type each dict also carries ``'slot': 'O'``.
    """
    defaults = {} if label_type == "frame" else {'slot': 'O'}
    return [{'form': word, **defaults} for word in sentence]
def predict_single(model, sentence, label_type):
    """Return the label value predicted on the most spans of *sentence*.

    Args:
        model: Flair model whose ``predict`` adds *label_type* spans.
        sentence: Sequence of word strings.
        label_type: Label type to read from the predicted spans.

    Returns:
        The most frequent span label value. On a tie, the label seen first wins.
        Returns ``None`` when the model predicts no spans (e.g. empty input),
        instead of letting ``max`` raise on an empty sequence.
    """
    csentence = __csentence(sentence, label_type)
    fsentence = __predict(model, csentence)
    votes = {}
    for span in fsentence.get_spans(label_type):
        tag = span.get_label(label_type).value
        votes[tag] = votes.get(tag, 0) + 1
    return max(votes, key=votes.get, default=None)
def predict_multiple(model, sentence, label_type):
    """Return the set of distinct *label_type* values the model predicts for *sentence*."""
    fsentence = __predict(model, __csentence(sentence, label_type))
    labels = set()
    for span in fsentence.get_spans(label_type):
        labels.add(span.get_label(label_type).value)
    return labels
def predict_and_annotate(model, sentence, label_type):
    """Run *model* on *sentence* and write its predictions back into token dicts.

    Args:
        model: Flair model whose ``predict`` adds *label_type* spans.
        sentence: Sequence of word strings.
        label_type: ``"frame"``, or any other value for BIO slot tagging.

    Returns:
        List of token dicts, one per word, each with a ``'form'`` key. For
        ``"frame"``, every token covered by a predicted span gets ``'frame'``
        set to that span's label, and uncovered tokens have no ``'frame'`` key.
        Otherwise, ``'slot'`` is written in BIO form: ``B-tag`` on a span's first
        token and ``I-tag`` on the rest. Uncovered tokens keep ``'O'``.
    """
    csentence = __csentence(sentence, label_type)
    fsentence = __predict(model, csentence)
    for span in fsentence.get_spans(label_type):
        tag = span.get_label(label_type).value
        # Token.idx is treated as 1-based here (hence the -1).
        positions = [token.idx - 1 for token in span.tokens]
        if label_type == "frame":
            # conllu2flair_frame gives every token its own frame label, so a
            # multi-token span labels all of its tokens, not only the first.
            for pos in positions:
                csentence[pos]['frame'] = tag
        else:
            first, *rest = positions
            csentence[first]['slot'] = f'B-{tag}'
            for pos in rest:
                csentence[pos]['slot'] = f'I-{tag}'
    return csentence