eval
This commit is contained in:
parent
5027466e25
commit
42b333a405
@ -26,4 +26,4 @@ df = pd.DataFrame(data)
|
||||
|
||||
print(df.head(5))
|
||||
|
||||
df.to_csv(r'data.csv',index=False, sep='\t')
|
||||
df.to_csv(r'data.tsv',index=False, sep='\t')
|
@ -27,20 +27,14 @@ class ML_NLU:
|
||||
|
||||
def conllu2flair(self, sentences, label=None):
    """Convert CoNLL-U style sentences into a flair ``SentenceDataset``.

    Args:
        sentences: Iterable of sentences, each an iterable of token
            mappings that provide at least a ``'form'`` entry.
        label: Optional tag name. When truthy, every token is tagged with
            ``token[label]`` under that name.

    Returns:
        A ``SentenceDataset`` wrapping one flair ``Sentence`` per input
        sentence, in input order.
    """
    converted = []
    for conllu_sentence in sentences:
        flair_sentence = Sentence()
        for entry in conllu_sentence:
            flair_token = Token(entry['form'])
            # Gold tags are attached only when a label name was requested.
            if label:
                flair_token.add_tag(label, entry[label])
            flair_sentence.add_token(flair_token)
        converted.append(flair_sentence)
    return SentenceDataset(converted)
|
||||
|
||||
|
||||
|
48
eval.py
Normal file
48
eval.py
Normal file
@ -0,0 +1,48 @@
|
||||
import pandas as pd
|
||||
from tabulate import tabulate
|
||||
from flair.data import Sentence, Token
|
||||
from flair.datasets import SentenceDataset
|
||||
from flair.models import SequenceTagger
|
||||
|
||||
def conllu2flair(sentences, label=None):
    """Build a flair ``SentenceDataset`` from CoNLL-U style sentences.

    Args:
        sentences: Iterable of sentences; each sentence is an iterable of
            token mappings with at least a ``'form'`` key.
        label: Optional tag name. When truthy, each token receives the tag
            ``token[label]`` under that name.

    Returns:
        A ``SentenceDataset`` holding one flair ``Sentence`` per input
        sentence, in the same order.
    """
    def _to_flair_sentence(conllu_sentence):
        # Converts a single sentence, one flair Token per input token.
        flair_sentence = Sentence()
        for entry in conllu_sentence:
            flair_token = Token(entry['form'])
            if label:
                flair_token.add_tag(label, entry[label])
            flair_sentence.add_token(flair_token)
        return flair_sentence

    return SentenceDataset(
        [_to_flair_sentence(conllu_sentence) for conllu_sentence in sentences]
    )
|
||||
|
||||
def predict(frame_model, sentence):
    """Return the intent whose per-token scores sum highest for *sentence*.

    The words are wrapped into a one-sentence dataset via ``conllu2flair``,
    tagged in place by ``frame_model.predict``, and the scores of every
    ``"frame"`` annotation are accumulated per intent value.

    Args:
        frame_model: Tagger object exposing ``predict(sentence)``.
        sentence: Iterable of word strings.

    Returns:
        The intent value with the largest accumulated score. On a tie the
        intent encountered first wins.

    Raises:
        ValueError: If no ``"frame"`` annotation was collected at all
            (``max`` over an empty mapping).
    """
    csentence = [{'form': word} for word in sentence]
    fsentence = conllu2flair([csentence])[0]
    frame_model.predict(fsentence)

    possible_intents = {}
    for token in fsentence:
        for intent in token.annotation_layers["frame"]:
            possible_intents[intent.value] = (
                possible_intents.get(intent.value, 0) + intent.score
            )

    # Iterating a dict yields its keys, so a bare max(possible_intents)
    # would pick the alphabetically greatest label and ignore the scores.
    # Rank by the accumulated score instead.
    return max(possible_intents, key=possible_intents.get)
|
||||
|
||||
# Load the trained tagger once; every evaluation row reuses it.
frame_model = SequenceTagger.load('frame-model/final-model.pt')

# (sentence text, expected intent) pairs taken from the first two
# tab-separated columns of data.tsv.
data = []

# Explicit encoding so the result does not depend on the platform locale.
# NOTE(review): data.tsv is presumably written by pandas ``to_csv``, whose
# default encoding is UTF-8 - confirm against the exporting script.
with open('data.tsv', encoding='utf-8') as f:
    lines = f.readlines()

# The first line is skipped (assumed to be the header row).
for line in lines[1:]:
    # A blank line carries no columns and would raise IndexError below.
    if not line.strip():
        continue
    columns = line.rstrip('\n').split("\t")
    data.append((columns[0], columns[1]))

correct = 0
for text, expected in data:
    predicted_intent = predict(frame_model, text.split())
    if predicted_intent == expected:
        correct += 1
    else:
        # Report every misclassification as "<predicted> != <expected>".
        print(predicted_intent + " != " + expected)

# Guard against an empty dataset, which would divide by zero.
accuracy = correct / len(data) if data else 0.0
print(f"{accuracy} {correct}/{len(data)}")
|
12
train.py
12
train.py
@ -70,9 +70,9 @@ frame_tagger = SequenceTagger(hidden_size=256, embeddings=embeddings,
|
||||
# train_with_dev=False)
|
||||
|
||||
|
||||
frame_trainer = ModelTrainer(frame_tagger, frame_corpus)
|
||||
frame_trainer.train('frame-model',
|
||||
learning_rate=0.1,
|
||||
mini_batch_size=32,
|
||||
max_epochs=100,
|
||||
train_with_dev=False)
|
||||
# frame_trainer = ModelTrainer(frame_tagger, frame_corpus)
|
||||
# frame_trainer.train('frame-model',
|
||||
# learning_rate=0.1,
|
||||
# mini_batch_size=32,
|
||||
# max_epochs=100,
|
||||
# train_with_dev=False)
|
Loading…
Reference in New Issue
Block a user