Start using dependent types

This commit is contained in:
Filip Gralinski 2019-09-21 15:27:57 +02:00 committed by Filip Graliński
parent 6e20e79f5b
commit d5d177bc8d
3 changed files with 139 additions and 89 deletions

View File

@ -50,7 +50,10 @@ module GEval.Core
import Debug.Trace import Debug.Trace
import Data.Singletons.TH
import GEval.Metric import GEval.Metric
import GEval.MetricsMechanics
import GEval.EvaluationScheme import GEval.EvaluationScheme
import Data.Conduit import Data.Conduit
@ -509,12 +512,12 @@ gevalCoreOnSources Pearson _ = gevalCoreByCorrelationMeasure pearson
gevalCoreOnSources Spearman _ = gevalCoreByCorrelationMeasure spearman gevalCoreOnSources Spearman _ = gevalCoreByCorrelationMeasure spearman
gevalCoreOnSources (ProbabilisticMultiLabelFMeasure beta) _ = generalizedProbabilisticFMeasure beta gevalCoreOnSources (ProbabilisticMultiLabelFMeasure beta) _ = generalizedProbabilisticFMeasure beta
intoWords SAProbabilisticMultiLabelFMeasure
(Right . (\(ProbList es) -> es) . parseIntoProbList) (Right . (\(ProbList es) -> es) . parseIntoProbList)
where intoWords = Right . Data.Text.words where intoWords = Right . Data.Text.words
gevalCoreOnSources (ProbabilisticSoftFMeasure beta) _ = generalizedProbabilisticFMeasure beta gevalCoreOnSources (ProbabilisticSoftFMeasure beta) _ = generalizedProbabilisticFMeasure beta
parseAnnotations SAProbabilisticSoftFMeasure
parseObtainedAnnotations parseObtainedAnnotations
-- and now more typical metrics, which: -- and now more typical metrics, which:
@ -524,9 +527,9 @@ gevalCoreOnSources (ProbabilisticSoftFMeasure beta) _ = generalizedProbabilistic
-- 4) aggregate the results -- 4) aggregate the results
-- 5) apply some final funtion on the aggregate -- 5) apply some final funtion on the aggregate
-- 6) create a graph using the aggregate (applicable only to some metrics) -- 6) create a graph using the aggregate (applicable only to some metrics)
gevalCoreOnSources Likelihood _ = gevalCoreWithoutInput doubleParser doubleParser itemLogLossError averageC logLossToLikehood noGraph gevalCoreOnSources Likelihood _ = gevalCoreWithoutInput SALikelihood doubleParser itemLogLossError averageC logLossToLikehood noGraph
gevalCoreOnSources MultiLabelLikelihood _ = gevalCoreWithoutInput intoWords gevalCoreOnSources MultiLabelLikelihood _ = gevalCoreWithoutInput SAMultiLabelLikelihood
(Right . parseIntoProbList) (Right . parseIntoProbList)
(uncurry countLogLossOnProbList) (uncurry countLogLossOnProbList)
averageC averageC
@ -535,18 +538,18 @@ gevalCoreOnSources MultiLabelLikelihood _ = gevalCoreWithoutInput intoWords
where where
intoWords = Right . Data.Text.words intoWords = Right . Data.Text.words
gevalCoreOnSources MSE _ = gevalCoreWithoutInput doubleParser doubleParser itemSquaredError averageC id noGraph gevalCoreOnSources MSE _ = gevalCoreWithoutInput SAMSE doubleParser itemSquaredError averageC id noGraph
gevalCoreOnSources RMSE _ = gevalCoreWithoutInput doubleParser doubleParser itemSquaredError averageC (** 0.5) noGraph gevalCoreOnSources RMSE _ = gevalCoreWithoutInput SARMSE doubleParser itemSquaredError averageC (** 0.5) noGraph
gevalCoreOnSources MAE _ = gevalCoreWithoutInput doubleParser doubleParser itemAbsoluteError averageC id noGraph gevalCoreOnSources MAE _ = gevalCoreWithoutInput SAMAE doubleParser itemAbsoluteError averageC id noGraph
gevalCoreOnSources SMAPE _ = gevalCoreWithoutInput doubleParser doubleParser smape averageC (* 100.0) noGraph gevalCoreOnSources SMAPE _ = gevalCoreWithoutInput SASMAPE doubleParser smape averageC (* 100.0) noGraph
where smape (exp, out) = (abs (exp-out)) `safeDoubleDiv` ((abs exp) + (abs out)) where smape (exp, out) = (abs (exp-out)) `safeDoubleDiv` ((abs exp) + (abs out))
gevalCoreOnSources LogLoss _ = gevalCoreWithoutInput doubleParser doubleParser itemLogLossError averageC id noGraph gevalCoreOnSources LogLoss _ = gevalCoreWithoutInput SALogLoss doubleParser itemLogLossError averageC id noGraph
gevalCoreOnSources BLEU _ = gevalCoreWithoutInput (Right . Prelude.map Prelude.words . DLS.splitOn "\t" . unpack) (Right . Prelude.words . unpack) bleuCombine bleuAgg bleuFinal noGraph gevalCoreOnSources BLEU _ = gevalCoreWithoutInput SABLEU (Right . Prelude.words . unpack) bleuCombine bleuAgg bleuFinal noGraph
where bleuFinal (p1, p2, p3, p4, rl, l1, l2, l3, l4) = ((p1 /. l1) * (p2 /. l2) * (p3 /. l3) * (p4 /. l4)) ** 0.25 * (brevityPenalty l1 rl) where bleuFinal (p1, p2, p3, p4, rl, l1, l2, l3, l4) = ((p1 /. l1) * (p2 /. l2) * (p3 /. l3) * (p4 /. l4)) ** 0.25 * (brevityPenalty l1 rl)
bleuCombine (refs, sen) = bleuStep refs sen bleuCombine (refs, sen) = bleuStep refs sen
bleuAgg = CC.foldl bleuFuse (0, 0, 0, 0, 0, 0, 0, 0, 0) bleuAgg = CC.foldl bleuFuse (0, 0, 0, 0, 0, 0, 0, 0, 0)
@ -556,15 +559,15 @@ gevalCoreOnSources BLEU _ = gevalCoreWithoutInput (Right . Prelude.map Prelude.w
| c == 0 && r > 0 = 0.0 | c == 0 && r > 0 = 0.0
| otherwise = exp (1.0 - (r /. c)) | otherwise = exp (1.0 - (r /. c))
gevalCoreOnSources GLEU _ = gevalCoreWithoutInput (Right . Prelude.map Prelude.words . DLS.splitOn "\t" . unpack) (Right . Prelude.words . unpack) gleuCombine gleuAgg gleuFinal noGraph gevalCoreOnSources GLEU _ = gevalCoreWithoutInput SAGLEU (Right . Prelude.words . unpack) gleuCombine gleuAgg gleuFinal noGraph
where gleuFinal (m, t) = m /. t where gleuFinal (m, t) = m /. t
gleuCombine (refs, sen) = gleuStep refs sen gleuCombine (refs, sen) = gleuStep refs sen
gleuAgg = CC.foldl gleuFuse (0, 0) gleuAgg = CC.foldl gleuFuse (0, 0)
gleuFuse (a1, a2) (b1, b2) = (a1+b1, a2+b2) gleuFuse (a1, a2) (b1, b2) = (a1+b1, a2+b2)
gevalCoreOnSources WER _ = gevalCoreWithoutInput (Right . Prelude.words . unpack) (Right . Prelude.words . unpack) (uncurry werStep) averageC id noGraph gevalCoreOnSources WER _ = gevalCoreWithoutInput SAWER (Right . Prelude.words . unpack) (uncurry werStep) averageC id noGraph
gevalCoreOnSources Accuracy _ = gevalCoreWithoutInput (Right . strip) (Right . strip) hitOrMiss averageC id noGraph gevalCoreOnSources Accuracy _ = gevalCoreWithoutInput SAAccuracy (Right . strip) hitOrMiss averageC id noGraph
where hitOrMiss (exp, got) = where hitOrMiss (exp, got) =
-- first try to parse what we got as a probability distribution -- first try to parse what we got as a probability distribution
-- (like the one used for Likelikehood/LogLossHashed metric) -- (like the one used for Likelikehood/LogLossHashed metric)
@ -589,12 +592,8 @@ gevalCoreOnSources Accuracy _ = gevalCoreWithoutInput (Right . strip) (Right . s
tryReadingAsFloat :: Text -> Maybe Float tryReadingAsFloat :: Text -> Maybe Float
tryReadingAsFloat = readMaybe . unpack tryReadingAsFloat = readMaybe . unpack
gevalCoreOnSources (FMeasure beta) _ = gevalCoreWithoutInput outParser outParser getCount countAgg (fMeasureOnCounts beta) noGraph gevalCoreOnSources (FMeasure beta) _ = gevalCoreWithoutInput SAFMeasure outParser getCount countAgg (fMeasureOnCounts beta) noGraph
where outParser = detected <=< (getValue . TR.double) where outParser = detected <=< (getValue . TR.double)
expParser = expected <=< (getValue . TR.decimal)
expected 1 = Right True
expected 0 = Right False
expected _ = Left "expected 0 or 1"
-- output value could be a probability (for compatibility with other measures) -- output value could be a probability (for compatibility with other measures)
detected prob detected prob
| prob >= 0.0 && prob < detectionThreshold = Right False | prob >= 0.0 && prob < detectionThreshold = Right False
@ -607,7 +606,7 @@ gevalCoreOnSources (FMeasure beta) _ = gevalCoreWithoutInput outParser outParser
getCount (False, True) = (0, 0, 1) getCount (False, True) = (0, 0, 1)
getCount (False, False) = (0, 0, 0) getCount (False, False) = (0, 0, 0)
gevalCoreOnSources (MacroFMeasure beta) _ = gevalCoreWithoutInput (Right . Just . strip) (Right . predicted . strip) getClassesInvolved gatherClassC macroAverageOnCounts noGraph gevalCoreOnSources (MacroFMeasure beta) _ = gevalCoreWithoutInput SAMacroFMeasure (Right . predicted . strip) getClassesInvolved gatherClassC macroAverageOnCounts noGraph
where predicted got = where predicted got =
-- first try to parse what we got as a probability distribution -- first try to parse what we got as a probability distribution
-- (like the one used for Likelikehood/LogLossHashed metric) -- (like the one used for Likelikehood/LogLossHashed metric)
@ -636,7 +635,7 @@ gevalCoreOnSources (MacroFMeasure beta) _ = gevalCoreWithoutInput (Right . Just
M.lookupDefault 0 c gotMap)) M.lookupDefault 0 c gotMap))
$ M.keys expectedMap) / (fromIntegral $ Prelude.length $ M.keys expectedMap) $ M.keys expectedMap) / (fromIntegral $ Prelude.length $ M.keys expectedMap)
gevalCoreOnSources (SoftFMeasure beta) _ = gevalCoreWithoutInput parseAnnotations gevalCoreOnSources (SoftFMeasure beta) _ = gevalCoreWithoutInput SASoftFMeasure
parseObtainedAnnotations parseObtainedAnnotations
getSoftCounts getSoftCounts
countAgg countAgg
@ -646,7 +645,7 @@ gevalCoreOnSources (SoftFMeasure beta) _ = gevalCoreWithoutInput parseAnnotation
Prelude.length expected, Prelude.length expected,
Prelude.length got) Prelude.length got)
gevalCoreOnSources (Soft2DFMeasure beta) _ = gevalCoreWithoutInput parseLabeledClippings gevalCoreOnSources (Soft2DFMeasure beta) _ = gevalCoreWithoutInput SASoft2DFMeasure
parseLabeledClippings parseLabeledClippings
count2DFScore count2DFScore
averageC averageC
@ -659,30 +658,29 @@ gevalCoreOnSources (Soft2DFMeasure beta) _ = gevalCoreWithoutInput parseLabeledC
expArea = totalArea expected expArea = totalArea expected
gotArea = totalArea got gotArea = totalArea got
gevalCoreOnSources ClippEU _ = gevalCoreWithoutInput parseClippingSpecs parseClippings matchStep clippeuAgg finalStep noGraph gevalCoreOnSources ClippEU _ = gevalCoreWithoutInput SAClippEU parseClippings matchStep clippeuAgg finalStep noGraph
where where
parseClippings = controlledParse lineClippingsParser parseClippings = controlledParse lineClippingsParser
parseClippingSpecs = controlledParse lineClippingSpecsParser
matchStep (clippingSpecs, clippings) = (maxMatch matchClippingToSpec clippingSpecs clippings, matchStep (clippingSpecs, clippings) = (maxMatch matchClippingToSpec clippingSpecs clippings,
Prelude.length clippingSpecs, Prelude.length clippingSpecs,
Prelude.length clippings) Prelude.length clippings)
clippeuAgg = CC.foldl countFolder (0, 0, 0) clippeuAgg = CC.foldl countFolder (0, 0, 0)
finalStep counts = f2MeasureOnCounts counts finalStep counts = f2MeasureOnCounts counts
gevalCoreOnSources NMI _ = gevalCoreWithoutInput (Right . id) (Right . id) id (CC.foldl updateConfusionMatrix M.empty) normalizedMutualInformationFromConfusionMatrix noGraph gevalCoreOnSources NMI _ = gevalCoreWithoutInput SANMI (Right . id) id (CC.foldl updateConfusionMatrix M.empty) normalizedMutualInformationFromConfusionMatrix noGraph
gevalCoreOnSources MAP _ = gevalCoreWithoutInput (Right . DLS.splitOn "\t" . unpack) gevalCoreOnSources MAP _ = gevalCoreWithoutInput SAMAP
(Right . DLS.splitOn "\t" . unpack) (Right . DLS.splitOn "\t" . unpack)
(\(e,g) -> calculateMAPForOneResult e g) (\(e,g) -> calculateMAPForOneResult e g)
averageC averageC
id id
noGraph noGraph
gevalCoreOnSources BIOF1 _ = gevalCoreWithoutInput parseBioSequenceIntoEntities parseBioSequenceIntoEntities (uncurry gatherCountsForBIO) countAgg f1MeasureOnCounts noGraph gevalCoreOnSources BIOF1 _ = gevalCoreWithoutInput SABIOF1 parseBioSequenceIntoEntities (uncurry gatherCountsForBIO) countAgg f1MeasureOnCounts noGraph
gevalCoreOnSources BIOF1Labels _ = gevalCoreWithoutInput parseBioSequenceIntoEntitiesWithoutNormalization parseBioSequenceIntoEntitiesWithoutNormalization (uncurry gatherCountsForBIO) countAgg f1MeasureOnCounts noGraph gevalCoreOnSources BIOF1Labels _ = gevalCoreWithoutInput SABIOF1Labels parseBioSequenceIntoEntitiesWithoutNormalization (uncurry gatherCountsForBIO) countAgg f1MeasureOnCounts noGraph
gevalCoreOnSources TokenAccuracy _ = gevalCoreWithoutInput intoTokens gevalCoreOnSources TokenAccuracy _ = gevalCoreWithoutInput SATokenAccuracy
intoTokens intoTokens
countHitsAndTotals countHitsAndTotals
hitsAndTotalsAgg hitsAndTotalsAgg
@ -703,7 +701,7 @@ gevalCoreOnSources TokenAccuracy _ = gevalCoreWithoutInput intoTokens
| otherwise = (h, t + 1) | otherwise = (h, t + 1)
hitsAndTotalsAgg = CC.foldl (\(h1, t1) (h2, t2) -> (h1 + h2, t1 + t2)) (0, 0) hitsAndTotalsAgg = CC.foldl (\(h1, t1) (h2, t2) -> (h1 + h2, t1 + t2)) (0, 0)
gevalCoreOnSources MultiLabelLogLoss _ = gevalCoreWithoutInput intoWords gevalCoreOnSources MultiLabelLogLoss _ = gevalCoreWithoutInput SAMultiLabelLogLoss
(Right . parseIntoProbList) (Right . parseIntoProbList)
(uncurry countLogLossOnProbList) (uncurry countLogLossOnProbList)
averageC averageC
@ -712,8 +710,6 @@ gevalCoreOnSources MultiLabelLogLoss _ = gevalCoreWithoutInput intoWords
where where
intoWords = Right . Data.Text.words intoWords = Right . Data.Text.words
doubleParser = getValue . TR.double
helperLogLossHashed nbOfBits finalStep expectedLineSource outLineSource = helperLogLossHashed nbOfBits finalStep expectedLineSource outLineSource =
gevalCore''' (ParserSpecWithoutInput (liftOp (Right . id)) (liftOp tentativeParser)) (\(lineNo, (t,d)) -> calculateLogLoss nbOfBits lineNo t (parseDistributionWrapper nbOfBits lineNo d)) averageC (finalStep . negate) noGraph (WithoutInput expectedLineSource outLineSource) gevalCore''' (ParserSpecWithoutInput (liftOp (Right . id)) (liftOp tentativeParser)) (\(lineNo, (t,d)) -> calculateLogLoss nbOfBits lineNo t (parseDistributionWrapper nbOfBits lineNo d)) averageC (finalStep . negate) noGraph (WithoutInput expectedLineSource outLineSource)
where -- Unfortunately, we're parsing the distribution twice. We need to where -- Unfortunately, we're parsing the distribution twice. We need to
@ -724,7 +720,7 @@ helperLogLossHashed nbOfBits finalStep expectedLineSource outLineSource =
Right _ -> Right t Right _ -> Right t
Left m -> Left m Left m -> Left m
generalizedProbabilisticFMeasure beta parseBareEntities parseEntities = gevalCoreWithoutInput parseBareEntities generalizedProbabilisticFMeasure beta metric parseEntities = gevalCoreWithoutInput metric
parseEntities parseEntities
getProbabilisticCounts getProbabilisticCounts
probabilisticSoftAgg probabilisticSoftAgg
@ -754,7 +750,7 @@ gevalCoreByCorrelationMeasure :: (MonadUnliftIO m, MonadThrow m, MonadIO m) =>
LineSource (ResourceT m) -> -- ^ source to read the output LineSource (ResourceT m) -> -- ^ source to read the output
m (MetricOutput) -- ^ metric values for the output against the expected output m (MetricOutput) -- ^ metric values for the output against the expected output
gevalCoreByCorrelationMeasure correlationFunction = gevalCoreByCorrelationMeasure correlationFunction =
gevalCoreWithoutInput doubleParser doubleParser id correlationC finalStep noGraph gevalCoreWithoutInput SAPearson doubleParser id correlationC finalStep noGraph
where correlationC = CC.foldl (flip (:)) [] where correlationC = CC.foldl (flip (:)) []
finalStep pairs = correlationFunction $ V.fromList pairs finalStep pairs = correlationFunction $ V.fromList pairs
@ -770,9 +766,9 @@ skipLineNumber fun = fun . snd
-- | A helper function to run evaluation when the input is not needed to calculate the metric value. -- | A helper function to run evaluation when the input is not needed to calculate the metric value.
gevalCoreWithoutInput :: (MonadUnliftIO m, MonadThrow m, MonadIO m) gevalCoreWithoutInput :: (MonadUnliftIO m, MonadThrow m, MonadIO m)
=> (Text -> Either String a) -- ^ parser for values in the expected output => SAMetric t
-> (Text -> Either String b) -- ^ parser for values in the actual output -> (Text -> Either String b) -- ^ parser for values in the actual output
-> ((a, b) -> c) -- ^ function which combines parsed values into a single value -> ((ParsedExpectedType t, b) -> c) -- ^ function which combines parsed values into a single value
-- (will be launched for each item, e.g. an error/cost function -- (will be launched for each item, e.g. an error/cost function
-- could be calculated here) -- could be calculated here)
-> (ConduitT c Void (ResourceT m) d) -- ^ a Conduit which aggregates all the combined values into -> (ConduitT c Void (ResourceT m) d) -- ^ a Conduit which aggregates all the combined values into
@ -782,8 +778,9 @@ gevalCoreWithoutInput :: (MonadUnliftIO m, MonadThrow m, MonadIO m)
-> LineSource (ResourceT m) -- ^ source to read the expected output -> LineSource (ResourceT m) -- ^ source to read the expected output
-> LineSource (ResourceT m) -- ^ source to read the output -> LineSource (ResourceT m) -- ^ source to read the output
-> m (MetricOutput) -- ^ metric values for the output against the expected output -> m (MetricOutput) -- ^ metric values for the output against the expected output
gevalCoreWithoutInput expParser outParser itemStep aggregator finalStep generateGraph expectedLineStream outLineStream = gevalCoreWithoutInput smetric outParser itemStep aggregator finalStep generateGraph expectedLineStream outLineStream =
gevalCoreWithoutInputOnItemTargets (liftOp expParser) (liftOp outParser) itemStep aggregator finalStep generateGraph expectedLineStream outLineStream gevalCoreWithoutInputOnItemTargets (liftOp expParser) (liftOp outParser) itemStep aggregator finalStep generateGraph expectedLineStream outLineStream
where expParser = expectedParser smetric
gevalCoreWithoutInputOnItemTargets :: (MonadUnliftIO m, MonadThrow m, MonadIO m) gevalCoreWithoutInputOnItemTargets :: (MonadUnliftIO m, MonadThrow m, MonadIO m)
=> (ItemTarget -> Either String a) -- ^ parser for values in the expected output => (ItemTarget -> Either String a) -- ^ parser for values in the expected output
@ -961,15 +958,3 @@ itemLogLossError (exp, out)
| v >= 1.0 = 1.0 | v >= 1.0 = 1.0
| v <= 0.0 = 0.0 | v <= 0.0 = 0.0
| otherwise = v | otherwise = v
getValue :: Num a => Either String (a, Text) -> Either String a
getValue (Right (x, reminder)) =
if Data.Text.null reminder || Data.Text.head reminder == '\t'
then Right x
else Left "number expected"
getValue (Left s) = Left s
controlledParse parser t =
case parseOnly parser t of
(Right v) -> Right v
(Left _) -> Left "cannot parse line"

View File

@ -4,6 +4,7 @@
{-# LANGUAGE GADTs #-} {-# LANGUAGE GADTs #-}
{-# LANGUAGE TemplateHaskell #-} {-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE EmptyCase #-} {-# LANGUAGE EmptyCase #-}
{-# LANGUAGE UndecidableInstances #-}
module GEval.MetricsMechanics module GEval.MetricsMechanics
where where
@ -16,12 +17,15 @@ import Data.Text
import Data.Text.Read as TR import Data.Text.Read as TR
import qualified Data.List.Split as DLS import qualified Data.List.Split as DLS
import Data.Attoparsec.Text (parseOnly) import Data.Attoparsec.Text (parseOnly)
import Data.Maybe (catMaybes)
import Control.Monad ((<=<)) import Control.Monad ((<=<))
import GEval.Annotation (Annotation, parseAnnotations) import GEval.Annotation (Annotation, ObtainedAnnotation, parseAnnotations, parseObtainedAnnotations)
import GEval.Clippings (ClippingSpec, LabeledClipping, lineClippingsParser, lineClippingSpecsParser, lineLabeledClippingsParser) import GEval.Clippings (Clipping, ClippingSpec, LabeledClipping, lineClippingsParser, lineClippingSpecsParser, lineLabeledClippingsParser)
import GEval.BIO (TaggedEntity, parseBioSequenceIntoEntities, parseBioSequenceIntoEntitiesWithoutNormalization) import GEval.BIO (TaggedEntity, parseBioSequenceIntoEntities, parseBioSequenceIntoEntitiesWithoutNormalization)
import GEval.LogLossHashed (parseWordSpecs, wordSpecToPair)
import GEval.ProbList (ProbList(..), parseIntoProbList, WordWithProb(..))
-- | Helper type so that singleton can be used. -- | Helper type so that singleton can be used.
-- | (The problem is that some metrics are parametrized by Double -- | (The problem is that some metrics are parametrized by Double
@ -120,7 +124,7 @@ expectedParser SASoftFMeasure = parseAnnotations
expectedParser SAProbabilisticMultiLabelFMeasure = intoWords expectedParser SAProbabilisticMultiLabelFMeasure = intoWords
expectedParser SAProbabilisticSoftFMeasure = parseAnnotations expectedParser SAProbabilisticSoftFMeasure = parseAnnotations
expectedParser SASoft2DFMeasure = controlledParse lineLabeledClippingsParser expectedParser SASoft2DFMeasure = controlledParse lineLabeledClippingsParser
expectedParser SANMI = onlyStrip expectedParser SANMI = Right . id
expectedParser SALogLossHashed = onlyStrip expectedParser SALogLossHashed = onlyStrip
expectedParser SALikelihoodHashed = onlyStrip expectedParser SALikelihoodHashed = onlyStrip
expectedParser SACharMatch = Right expectedParser SACharMatch = Right
@ -136,6 +140,49 @@ expectedParser SAMultiLabelFMeasure = intoWords
expectedParser SAMultiLabelLogLoss = intoWords expectedParser SAMultiLabelLogLoss = intoWords
expectedParser SAMultiLabelLikelihood = intoWords expectedParser SAMultiLabelLikelihood = intoWords
type family ParsedOutputType (t :: AMetric) :: * where
ParsedOutputType ABLEU = [Text]
ParsedOutputType AGLEU = [Text]
ParsedOutputType AClippEU = [Clipping]
ParsedOutputType AMacroFMeasure = Maybe Text
ParsedOutputType ASoftFMeasure = [ObtainedAnnotation]
ParsedOutputType AProbabilisticSoftFMeasure = [ObtainedAnnotation]
ParsedOutputType AProbabilisticMultiLabelFMeasure = [WordWithProb]
ParsedOutputType t = ParsedExpectedType t
outputParser :: SAMetric t -> Text -> Either String (ParsedOutputType t)
outputParser SARMSE = expectedParser SARMSE
outputParser SAMSE = expectedParser SARMSE
outputParser SAPearson = expectedParser SAPearson
outputParser SASpearman = expectedParser SASpearman
outputParser SABLEU = intoWords
outputParser SAGLEU = intoWords
outputParser SAWER = expectedParser SAWER
outputParser SAAccuracy = expectedParser SAAccuracy
outputParser SAClippEU = controlledParse lineClippingsParser
outputParser SAFMeasure = probToZeroOneParser
outputParser SAMacroFMeasure = Right . predictedParser . strip
outputParser SASoftFMeasure = parseObtainedAnnotations
outputParser SAProbabilisticMultiLabelFMeasure = (Right . (\(ProbList es) -> es) . parseIntoProbList)
outputParser SAProbabilisticSoftFMeasure = parseObtainedAnnotations
outputParser SASoft2DFMeasure = expectedParser SASoft2DFMeasure
outputParser SANMI = expectedParser SANMI
outputParser SALogLossHashed = onlyStrip
outputParser SALikelihoodHashed = onlyStrip
outputParser SACharMatch = Right
outputParser SAMAP = splitByTabs
outputParser SALogLoss = doubleParser
outputParser SALikelihood = doubleParser
outputParser SABIOF1 = parseBioSequenceIntoEntities
outputParser SABIOF1Labels = parseBioSequenceIntoEntitiesWithoutNormalization
outputParser SATokenAccuracy = intoWords
outputParser SAMAE = doubleParser
outputParser SASMAPE = doubleParser
outputParser SAMultiLabelFMeasure = intoWords
outputParser SAMultiLabelLogLoss = intoWords
outputParser SAMultiLabelLikelihood = intoWords
doubleParser = getValue . TR.double doubleParser = getValue . TR.double
intoWords = Right . Data.Text.words intoWords = Right . Data.Text.words
@ -148,6 +195,16 @@ onlyStrip = Right . strip
justStrip = Right . Just . strip justStrip = Right . Just . strip
predictedParser got =
-- first try to parse what we got as a probability distribution
-- (like the one used for Likelikehood/LogLossHashed metric)
case parseWordSpecs got of
Right wordSpecs -> if Prelude.null pairs
then Nothing
else Just $ snd $ Prelude.maximum pairs
where pairs = catMaybes $ Prelude.map wordSpecToPair wordSpecs
Left _ -> Just got
splitByTabs = Right . DLS.splitOn "\t" . unpack splitByTabs = Right . DLS.splitOn "\t" . unpack
zeroOneParser = expected <=< (getValue . TR.decimal) zeroOneParser = expected <=< (getValue . TR.decimal)
@ -155,6 +212,14 @@ zeroOneParser = expected <=< (getValue . TR.decimal)
expected 0 = Right False expected 0 = Right False
expected _ = Left "expected 0 or 1" expected _ = Left "expected 0 or 1"
probToZeroOneParser = detected <=< (getValue . TR.double)
where -- output value could be a probability (for compatibility with other measures)
detected prob
| prob >= 0.0 && prob < detectionThreshold = Right False
| prob >= detectionThreshold && prob <= 1.0 = Right True
| otherwise = Left "expected probability"
detectionThreshold = 0.5
getValue :: Num a => Either String (a, Text) -> Either String a getValue :: Num a => Either String (a, Text) -> Either String a
getValue (Right (x, reminder)) = getValue (Right (x, reminder)) =
if Data.Text.null reminder || Data.Text.head reminder == '\t' if Data.Text.null reminder || Data.Text.head reminder == '\t'

View File

@ -2,7 +2,7 @@
{-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeFamilies #-}
module GEval.ProbList module GEval.ProbList
(parseIntoProbList, selectByStandardThreshold, countLogLossOnProbList, ProbList(..)) (parseIntoProbList, selectByStandardThreshold, countLogLossOnProbList, ProbList(..), WordWithProb(..))
where where
import qualified Data.Text as T import qualified Data.Text as T