from collections import Counter

from conllu import parse_incr
from flair.data import Corpus
from flair.data import Sentence
from flair.datasets import FlairDatapointDataset
from flair.embeddings import CharacterEmbeddings
from flair.embeddings import FlairEmbeddings
from flair.embeddings import StackedEmbeddings
from flair.embeddings import WordEmbeddings
from flair.models import SequenceTagger
from flair.trainers import ModelTrainer
import torch

# Column layout of the CoNLL-U-style train/test files read by Model.train_model.
_CONLLU_FIELDS = ['id', 'form', 'frame', 'slot']


def nolabel2o(line, i):
    """``conllu`` field parser that maps the value ``'NoLabel'`` to ``'O'``.

    Args:
        line: The token line split into columns, as passed by ``conllu``.
        i: Index of the column being parsed.

    Returns:
        ``'O'`` if the column value is ``'NoLabel'``, otherwise the raw value.
    """
    return 'O' if line[i] == 'NoLabel' else line[i]


def conllu2flair(sentences, label=None):
    """Convert parsed CoNLL-U sentences into a ``FlairDatapointDataset``.

    Args:
        sentences: Iterable of sentences, each an iterable of token dicts with
            at least a ``'form'`` key (plus *label* when a label is given).
        label: Name of the token field to turn into Flair labels. ``"frame"``
            labels every token individually; any other value is decoded as
            BIO tags into spans; ``None`` produces unlabeled sentences.

    Returns:
        A ``FlairDatapointDataset`` of ``Sentence`` objects.
    """
    if label == "frame":
        return conllu2flair_frame(sentences, label)
    return conllu2flair_slot(sentences, label)


def conllu2flair_frame(sentences, label=None):
    """Build Flair sentences where each token becomes a one-token span.

    Each single-token span is labelled ``label`` with the value stored in that
    token's ``label`` field.

    Args:
        sentences: Iterable of sentences of token dicts (``'form'`` and *label*).
        label: Token field whose value is attached to each token.

    Returns:
        A ``FlairDatapointDataset`` of labelled ``Sentence`` objects.
    """
    fsentences = []
    for sentence in sentences:
        tokens = [token['form'] for token in sentence]
        # NOTE(review): forms are re-split on whitespace by Flair; a form that
        # contains a space would desynchronise the two token sequences.
        fsentence = Sentence(' '.join(tokens), use_tokenizer=False)
        for idx, _ in enumerate(fsentence):
            fsentence[idx:idx + 1].add_label(label, sentence[idx][label])
        fsentences.append(fsentence)
    return FlairDatapointDataset(fsentences)


def conllu2flair_slot(sentences, label=None):
    """Build Flair sentences with spans decoded from BIO tags.

    A ``B-X`` tag opens a span of type ``X`` (closing any open span). ``I-``
    tags extend the open span; the span keeps the type of its ``B-`` tag. An
    ``O`` tag closes the open span. An ``I-`` tag with no open span is ignored.
    Other tag values leave the open span unchanged.

    Args:
        sentences: Iterable of sentences of token dicts (``'form'`` and *label*).
        label: Token field holding BIO tags; if falsy, no labels are added.

    Returns:
        A ``FlairDatapointDataset`` of ``Sentence`` objects.
    """
    fsentences = []
    for sentence in sentences:
        fsentence = Sentence(' '.join(token['form'] for token in sentence), use_tokenizer=False)
        start_idx = end_idx = tag = None
        if label:
            # zip() bounds the walk by the shorter of the two token sequences.
            for idx, (token, _ftoken) in enumerate(zip(sentence, fsentence)):
                value = token[label]
                if value.startswith('B-'):
                    if start_idx is not None:
                        fsentence[start_idx:end_idx + 1].add_label(label, tag)
                    start_idx = end_idx = idx
                    tag = value[2:]
                elif value.startswith('I-'):
                    end_idx = idx
                elif value == 'O':
                    if start_idx is not None:
                        fsentence[start_idx:end_idx + 1].add_label(label, tag)
                    start_idx = end_idx = tag = None
            # Flush a span that runs to the end of the sentence.
            if start_idx is not None:
                fsentence[start_idx:end_idx + 1].add_label(label, tag)
        fsentences.append(fsentence)
    return FlairDatapointDataset(fsentences)


def predict_frame(model, sentence, label_type):
    """Predict a single frame label for a tokenised sentence.

    Every predicted span contributes one vote for its label. The most frequent
    label wins; ties go to the label seen first.

    Args:
        model: A trained Flair tagger exposing ``predict``.
        sentence: Sequence of word strings.
        label_type: Label type the model predicts.

    Returns:
        The winning label value, or ``'unknown'`` if *sentence* is empty or
        the model predicts no spans.
    """
    if not sentence:
        return 'unknown'
    csentence = [{'form': word, 'slot': 'O'} for word in sentence]
    fsentence = conllu2flair([csentence])[0]
    model.predict(fsentence)
    counts = Counter(
        span.get_label(label_type).value for span in fsentence.get_spans(label_type)
    )
    if not counts:
        # Nothing was tagged; max() over an empty mapping would raise ValueError.
        return 'unknown'
    return counts.most_common(1)[0][0]


def predict_slot(model, sentence, label_type):
    """Predict BIO slot tags for a tokenised sentence.

    Args:
        model: A trained Flair tagger exposing ``predict``.
        sentence: Sequence of word strings.
        label_type: Label type the model predicts.

    Returns:
        A list of ``{'form': word, 'slot': tag}`` dicts, one per word. Tags are
        ``B-X``/``I-X`` for predicted spans and ``'O'`` elsewhere. For an empty
        *sentence*, returns ``[{'form': '', 'slot': 'unknown'}]``.
    """
    if not sentence:
        return [{'form': '', 'slot': 'unknown'}]
    csentence = [{'form': word, 'slot': 'O'} for word in sentence]
    fsentence = conllu2flair([csentence])[0]
    model.predict(fsentence)
    for span in fsentence.get_spans(label_type):
        tag = span.get_label(label_type).value
        first, *rest = span.tokens
        # Flair token indices are 1-based; csentence is 0-based.
        csentence[first.idx - 1]['slot'] = f'B-{tag}'
        for token in rest:
            csentence[token.idx - 1]['slot'] = f'I-{tag}'
    return csentence


class Model:
    """Train a Flair ``SequenceTagger`` from CoNLL-U-style train/test files."""

    def __init__(self, train_dataset, test_dataset):
        # Paths of the training and test files (read as UTF-8).
        self.train_dataset = train_dataset
        self.test_dataset = test_dataset

    @staticmethod
    def _load_conllu(path, field_parsers):
        """Parse all sentences of the file at *path* using the fixed column layout."""
        with open(path, encoding='utf-8') as f:
            return list(parse_incr(f, fields=_CONLLU_FIELDS, field_parsers=field_parsers))

    def train_model(self, label_type, field_parsers=None):
        """Train and save a CRF sequence tagger for one label column.

        Args:
            label_type: Column to learn (``'frame'`` or a BIO column such as
                ``'slot'``). It also names the output directory.
            field_parsers: Optional ``conllu`` field-parser mapping, e.g.
                ``{'slot': nolabel2o}``.

        Side effects:
            When CUDA is available, disables cuDNN and sets the deterministic
            and benchmark flags globally. Training output is written to the
            directory ``f'{label_type}-model'``.
        """
        if field_parsers is None:
            field_parsers = {}

        if torch.cuda.is_available():
            # NOTE(review): presumably for reproducibility / to avoid cuDNN RNN
            # issues — confirm intent.
            torch.backends.cudnn.enabled = False
            torch.backends.cudnn.benchmark = False
            torch.backends.cudnn.deterministic = True

        trainset = self._load_conllu(self.train_dataset, field_parsers)
        testset = self._load_conllu(self.test_dataset, field_parsers)

        corpus = Corpus(
            train=conllu2flair(trainset, label_type),
            test=conllu2flair(testset, label_type),
        )
        label_dictionary = corpus.make_label_dictionary(label_type=label_type)

        embedding_types = [
            WordEmbeddings('pl'),
            FlairEmbeddings('pl-forward'),
            FlairEmbeddings('pl-backward'),
            CharacterEmbeddings(),
        ]
        embeddings = StackedEmbeddings(embeddings=embedding_types)

        tagger = SequenceTagger(
            hidden_size=256,
            embeddings=embeddings,
            tag_dictionary=label_dictionary,
            tag_type=label_type,
            use_crf=True,
            tag_format="BIO",
        )

        trainer = ModelTrainer(tagger, corpus)
        trainer.train(
            f'{label_type}-model',
            learning_rate=0.01,
            mini_batch_size=16,
            max_epochs=75,
            train_with_dev=False,
        )


# Example usage (kept for reference):
# model = Model(train_dataset='../data/test_dialog.conllu', test_dataset='../data/test_dialog.conllu')
# model.train_model('frame')
# model2 = Model(train_dataset='../data/test_dialog.conllu', test_dataset='../data/test_dialog.conllu')
# model2.train_model('slot', field_parsers={'slot': nolabel2o})