100 lines
3.3 KiB
Python
100 lines
3.3 KiB
Python
from flair.data import Sentence
|
|
from flair.datasets import FlairDatapointDataset
|
|
|
|
def nolabel2o(line, i):
    """Return ``line[i]``, replacing the placeholder ``'NoLabel'`` with ``'O'``.

    line: indexable collection of annotation values.
    i: index (or key) of the value to read.
    """
    value = line[i]
    if value == 'NoLabel':
        return 'O'
    return value
|
|
|
|
def conllu2flair(sentences, label=None):
    """Convert *sentences* to a flair dataset using the converter for *label*.

    The label ``"frame"`` selects ``conllu2flair_frame``; any other value
    (including ``None``) selects ``conllu2flair_slot``. Both arguments are
    passed through unchanged to the chosen converter.
    """
    converter = conllu2flair_frame if label == "frame" else conllu2flair_slot
    return converter(sentences, label)
|
|
|
|
def conllu2flair_frame(sentences, label=None):
    """Build a flair dataset where every token is labelled with its *label* value.

    sentences: iterable of sentences; each sentence is an indexable sequence
        of token mappings holding a ``"form"`` entry and a *label* entry.
    label: name of the token field to copy; also used as the flair label type.

    Returns a ``FlairDatapointDataset`` with one flair ``Sentence`` per input
    sentence.
    """
    def to_flair(tokens):
        # Forms are joined with single spaces and flair's tokenizer is
        # disabled, so the text is split on whitespace only.
        text = ' '.join(tok["form"] for tok in tokens)
        converted = Sentence(text, use_tokenizer=False)

        # Iterate by flair token count and label one single-token span per
        # position with the value stored on the matching input token.
        for pos in range(len(converted)):
            converted[pos:pos + 1].add_label(label, tokens[pos][label])

        return converted

    return FlairDatapointDataset([to_flair(sent) for sent in sentences])
|
|
|
|
def _iter_bio_spans(tags):
    """Yield ``(start, end, tag)`` (inclusive indices) for each span in BIO *tags*."""
    start = end = name = None
    for idx, value in enumerate(tags):
        if value.startswith('B-'):
            # A new span begins: emit the one still open first, otherwise
            # adjacent spans (``B-x I-x B-y``) would lose the earlier span.
            if start is not None:
                yield start, end, name
            start, end, name = idx, idx, value[2:]
        elif value.startswith('I-'):
            # Extends the open span; an ``I-`` with no open span is ignored
            # because ``start`` stays ``None``.
            end = idx
        elif value == 'O' and start is not None:
            yield start, end, name
            start = end = name = None
        # Any other value neither opens, extends nor closes a span.

    # A span that runs to the end of the sentence is still open here.
    if start is not None:
        yield start, end, name


def conllu2flair_slot(sentences, label=None):
    """Build a flair dataset with span labels decoded from BIO tags.

    sentences: iterable of sentences; each sentence is an iterable of token
        mappings holding a ``'form'`` entry and, when *label* is given, a BIO
        tag (``'B-<tag>'``, ``'I-<tag>'`` or ``'O'``) under the *label* key.
    label: name of the token field with the BIO tags; also used as the flair
        label type. When falsy, sentences are converted without any labels.

    Returns a ``FlairDatapointDataset`` with one flair ``Sentence`` per input
    sentence. Every ``B-``/``I-`` run becomes one labelled span, including
    runs that are immediately followed by another ``B-`` tag.
    """
    fsentences = []

    for sentence in sentences:
        fsentence = Sentence(' '.join(token['form'] for token in sentence), use_tokenizer=False)

        if label:
            # zip() stops at the shorter sequence, so span indices never run
            # past the number of flair tokens.
            tags = [token[label] for token, _ in zip(sentence, fsentence)]
            for start_idx, end_idx, tag in _iter_bio_spans(tags):
                fsentence[start_idx:end_idx + 1].add_label(label, tag)

        fsentences.append(fsentence)

    return FlairDatapointDataset(fsentences)
|
|
|
|
def __predict(model, csentence):
    """Run *model* on one sentence and return the resulting flair sentence.

    csentence: a single sentence as a list of token mappings; it is wrapped
        in a one-element list and converted with ``conllu2flair`` (no label).
    model: object exposing ``predict(sentence)``; it is called on the
        converted sentence and its return value is ignored.
    """
    dataset = conllu2flair([csentence])
    prepared = dataset[0]
    model.predict(prepared)
    return prepared
|
|
|
|
def __csentence(sentence, label_type):
    """Wrap each word of *sentence* in a token mapping.

    For ``label_type == "frame"`` every token is ``{'form': word}``; for any
    other label type each token also gets ``'slot': 'O'``.
    """
    extra = {} if label_type == "frame" else {'slot': 'O'}
    return [{'form': word, **extra} for word in sentence]
|
|
|
|
def predict_single(model, sentence, label_type):
    """Return the tag that *model* predicts most often for *sentence*.

    model: object exposing ``predict(sentence)`` that annotates a flair
        sentence in place.
    sentence: iterable of word strings.
    label_type: flair label type whose predicted spans are counted.

    Returns the value carried by the largest number of predicted spans; on a
    tie the tag seen first wins. Returns ``None`` when the model predicts no
    span of *label_type*.
    """
    from collections import Counter

    csentence = __csentence(sentence, label_type)
    fsentence = __predict(model, csentence)

    votes = Counter(
        span.get_label(label_type).value
        for span in fsentence.get_spans(label_type)
    )

    if not votes:
        # max() over an empty mapping raises ValueError; report "no
        # prediction" explicitly instead.
        return None

    # Counter keeps insertion order, so max() resolves ties in favour of the
    # tag that was encountered first.
    return max(votes, key=votes.get)
|
|
|
|
def predict_multiple(model, sentence, label_type):
    """Return the set of distinct *label_type* tags predicted for *sentence*.

    model: object exposing ``predict(sentence)`` that annotates a flair
        sentence in place.
    sentence: iterable of word strings.
    label_type: flair label type whose predicted spans are collected.
    """
    fsentence = __predict(model, __csentence(sentence, label_type))

    tags = set()
    for span in fsentence.get_spans(label_type):
        tags.add(span.get_label(label_type).value)
    return tags
|
|
|
|
def predict_and_annotate(model, sentence, label_type):
    """Predict *label_type* spans for *sentence* and write them onto its tokens.

    model: object exposing ``predict(sentence)`` that annotates a flair
        sentence in place.
    sentence: iterable of word strings.
    label_type: flair label type to read; ``"frame"`` selects frame
        annotation, anything else selects slot annotation.

    Returns the list of token mappings. For ``"frame"`` the first token of
    each predicted span gets a ``'frame'`` entry with the tag. Otherwise the
    first token's ``'slot'`` becomes ``'B-<tag>'`` and every following token
    of the span gets ``'I-<tag>'``.
    """
    csentence = __csentence(sentence, label_type)
    fsentence = __predict(model, csentence)
    is_frame = label_type == "frame"

    for span in fsentence.get_spans(label_type):
        tag = span.get_label(label_type).value
        first, continuation = span.tokens[0], span.tokens[1:]

        # Token ``idx`` values are shifted by one to index the token list.
        if is_frame:
            csentence[first.idx - 1]['frame'] = tag
            continue

        csentence[first.idx - 1]['slot'] = f'B-{tag}'
        for token in continuation:
            csentence[token.idx - 1]['slot'] = f'I-{tag}'

    return csentence