diff --git a/src/GEval/Core.hs b/src/GEval/Core.hs
index bc78759..aeaa1b4 100644
--- a/src/GEval/Core.hs
+++ b/src/GEval/Core.hs
@@ -50,7 +50,10 @@ module GEval.Core

 import Debug.Trace

+import Data.Singletons.TH
+
 import GEval.Metric
+import GEval.MetricsMechanics
 import GEval.EvaluationScheme

 import Data.Conduit
@@ -509,13 +512,13 @@ gevalCoreOnSources Pearson _ = gevalCoreByCorrelationMeasure pearson
 gevalCoreOnSources Spearman _ = gevalCoreByCorrelationMeasure spearman

 gevalCoreOnSources (ProbabilisticMultiLabelFMeasure beta) _ = generalizedProbabilisticFMeasure beta
-                                                                intoWords
-                                                                (Right . (\(ProbList es) -> es) . parseIntoProbList)
+                                                                SAProbabilisticMultiLabelFMeasure
+                                                                (Right . (\(ProbList es) -> es) . parseIntoProbList)
   where intoWords = Right . Data.Text.words

 gevalCoreOnSources (ProbabilisticSoftFMeasure beta) _ = generalizedProbabilisticFMeasure beta
-                                                          parseAnnotations
-                                                          parseObtainedAnnotations
+                                                          SAProbabilisticSoftFMeasure
+                                                          parseObtainedAnnotations

 -- and now more typical metrics, which:
 -- 1) parse the expected output
@@ -524,29 +527,29 @@ gevalCoreOnSources (ProbabilisticSoftFMeasure beta) _ = generalizedProbabilistic
 -- 4) aggregate the results
 -- 5) apply some final funtion on the aggregate
 -- 6) create a graph using the aggregate (applicable only to some metrics)

-gevalCoreOnSources Likelihood _ = gevalCoreWithoutInput doubleParser doubleParser itemLogLossError averageC logLossToLikehood noGraph
+gevalCoreOnSources Likelihood _ = gevalCoreWithoutInput SALikelihood doubleParser itemLogLossError averageC logLossToLikehood noGraph

-gevalCoreOnSources MultiLabelLikelihood _ = gevalCoreWithoutInput intoWords
-                                                                  (Right . parseIntoProbList)
-                                                                  (uncurry countLogLossOnProbList)
-                                                                  averageC
-                                                                  logLossToLikehood
-                                                                  noGraph
+gevalCoreOnSources MultiLabelLikelihood _ = gevalCoreWithoutInput SAMultiLabelLikelihood
+                                                                  (Right . parseIntoProbList)
+                                                                  (uncurry countLogLossOnProbList)
+                                                                  averageC
+                                                                  logLossToLikehood
+                                                                  noGraph
   where intoWords = Right . Data.Text.words

-gevalCoreOnSources MSE _ = gevalCoreWithoutInput doubleParser doubleParser itemSquaredError averageC id noGraph
+gevalCoreOnSources MSE _ = gevalCoreWithoutInput SAMSE doubleParser itemSquaredError averageC id noGraph

-gevalCoreOnSources RMSE _ = gevalCoreWithoutInput doubleParser doubleParser itemSquaredError averageC (** 0.5) noGraph
+gevalCoreOnSources RMSE _ = gevalCoreWithoutInput SARMSE doubleParser itemSquaredError averageC (** 0.5) noGraph

-gevalCoreOnSources MAE _ = gevalCoreWithoutInput doubleParser doubleParser itemAbsoluteError averageC id noGraph
+gevalCoreOnSources MAE _ = gevalCoreWithoutInput SAMAE doubleParser itemAbsoluteError averageC id noGraph

-gevalCoreOnSources SMAPE _ = gevalCoreWithoutInput doubleParser doubleParser smape averageC (* 100.0) noGraph
+gevalCoreOnSources SMAPE _ = gevalCoreWithoutInput SASMAPE doubleParser smape averageC (* 100.0) noGraph
   where smape (exp, out) = (abs (exp-out)) `safeDoubleDiv` ((abs exp) + (abs out))

-gevalCoreOnSources LogLoss _ = gevalCoreWithoutInput doubleParser doubleParser itemLogLossError averageC id noGraph
+gevalCoreOnSources LogLoss _ = gevalCoreWithoutInput SALogLoss doubleParser itemLogLossError averageC id noGraph

-gevalCoreOnSources BLEU _ = gevalCoreWithoutInput (Right . Prelude.map Prelude.words . DLS.splitOn "\t" . unpack) (Right . Prelude.words . unpack) bleuCombine bleuAgg bleuFinal noGraph
+gevalCoreOnSources BLEU _ = gevalCoreWithoutInput SABLEU (Right . Prelude.words . unpack) bleuCombine bleuAgg bleuFinal noGraph
   where bleuFinal (p1, p2, p3, p4, rl, l1, l2, l3, l4) = ((p1 /. l1) * (p2 /. l2) * (p3 /. l3) * (p4 /. l4)) ** 0.25 * (brevityPenalty l1 rl)
         bleuCombine (refs, sen) = bleuStep refs sen
         bleuAgg = CC.foldl bleuFuse (0, 0, 0, 0, 0, 0, 0, 0, 0)
@@ -556,15 +559,15 @@ gevalCoreOnSources BLEU _ = gevalCoreWithoutInput (Right . Prelude.map Prelude.w
           | c == 0 && r > 0 = 0.0
           | otherwise = exp (1.0 - (r /. c))

-gevalCoreOnSources GLEU _ = gevalCoreWithoutInput (Right . Prelude.map Prelude.words . DLS.splitOn "\t" . unpack) (Right . Prelude.words . unpack) gleuCombine gleuAgg gleuFinal noGraph
+gevalCoreOnSources GLEU _ = gevalCoreWithoutInput SAGLEU (Right . Prelude.words . unpack) gleuCombine gleuAgg gleuFinal noGraph
   where gleuFinal (m, t) = m /. t
         gleuCombine (refs, sen) = gleuStep refs sen
         gleuAgg = CC.foldl gleuFuse (0, 0)
         gleuFuse (a1, a2) (b1, b2) = (a1+b1, a2+b2)

-gevalCoreOnSources WER _ = gevalCoreWithoutInput (Right . Prelude.words . unpack) (Right . Prelude.words . unpack) (uncurry werStep) averageC id noGraph
+gevalCoreOnSources WER _ = gevalCoreWithoutInput SAWER (Right . Prelude.words . unpack) (uncurry werStep) averageC id noGraph

-gevalCoreOnSources Accuracy _ = gevalCoreWithoutInput (Right . strip) (Right . strip) hitOrMiss averageC id noGraph
+gevalCoreOnSources Accuracy _ = gevalCoreWithoutInput SAAccuracy (Right . strip) hitOrMiss averageC id noGraph
   where hitOrMiss (exp, got) =
           -- first try to parse what we got as a probability distribution
           -- (like the one used for Likelikehood/LogLossHashed metric)
@@ -589,12 +592,8 @@ gevalCoreOnSources Accuracy _ = gevalCoreWithoutInput (Right . strip) (Right . s
         tryReadingAsFloat :: Text -> Maybe Float
         tryReadingAsFloat = readMaybe . unpack

-gevalCoreOnSources (FMeasure beta) _ = gevalCoreWithoutInput outParser outParser getCount countAgg (fMeasureOnCounts beta) noGraph
+gevalCoreOnSources (FMeasure beta) _ = gevalCoreWithoutInput SAFMeasure outParser getCount countAgg (fMeasureOnCounts beta) noGraph
   where outParser = detected <=< (getValue . TR.double)
-        expParser = expected <=< (getValue . TR.decimal)
-        expected 1 = Right True
-        expected 0 = Right False
-        expected _ = Left "expected 0 or 1"
         -- output value could be a probability (for compatibility with other measures)
         detected prob
           | prob >= 0.0 && prob < detectionThreshold = Right False
@@ -607,7 +606,7 @@ gevalCoreOnSources (FMeasure beta) _ = gevalCoreWithoutInput outParser outParser
         getCount (False, True) = (0, 0, 1)
         getCount (False, False) = (0, 0, 0)

-gevalCoreOnSources (MacroFMeasure beta) _ = gevalCoreWithoutInput (Right . Just . strip) (Right . predicted . strip) getClassesInvolved gatherClassC macroAverageOnCounts noGraph
+gevalCoreOnSources (MacroFMeasure beta) _ = gevalCoreWithoutInput SAMacroFMeasure (Right . predicted . strip) getClassesInvolved gatherClassC macroAverageOnCounts noGraph
   where predicted got =
           -- first try to parse what we got as a probability distribution
           -- (like the one used for Likelikehood/LogLossHashed metric)
@@ -636,22 +635,22 @@ gevalCoreOnSources (MacroFMeasure beta) _ = gevalCoreWithoutInput (Right . Just
                                                       M.lookupDefault 0 c gotMap))
            $ M.keys expectedMap) / (fromIntegral $ Prelude.length $ M.keys expectedMap)

-gevalCoreOnSources (SoftFMeasure beta) _ = gevalCoreWithoutInput parseAnnotations
-                                                                 parseObtainedAnnotations
-                                                                 getSoftCounts
-                                                                 countAgg
-                                                                 (fMeasureOnCounts beta)
-                                                                 noGraph
+gevalCoreOnSources (SoftFMeasure beta) _ = gevalCoreWithoutInput SASoftFMeasure
+                                                                 parseObtainedAnnotations
+                                                                 getSoftCounts
+                                                                 countAgg
+                                                                 (fMeasureOnCounts beta)
+                                                                 noGraph
   where getSoftCounts (expected, got) = (weightedMaxMatch matchScore expected got,
                                          Prelude.length expected,
                                          Prelude.length got)

-gevalCoreOnSources (Soft2DFMeasure beta) _ = gevalCoreWithoutInput parseLabeledClippings
-                                                                   parseLabeledClippings
-                                                                   count2DFScore
-                                                                   averageC
-                                                                   id
-                                                                   noGraph
+gevalCoreOnSources (Soft2DFMeasure beta) _ = gevalCoreWithoutInput SASoft2DFMeasure
+                                                                   parseLabeledClippings
+                                                                   count2DFScore
+                                                                   averageC
+                                                                   id
+                                                                   noGraph
   where parseLabeledClippings = controlledParse lineLabeledClippingsParser
         count2DFScore (expected, got) = fMeasureOnCounts beta (tpArea, expArea, gotArea)
@@ -659,35 +658,34 @@ gevalCoreOnSources (Soft2DFMeasure beta) _ = gevalCoreWithoutInput parseLabeledC
           expArea = totalArea expected
           gotArea = totalArea got

-gevalCoreOnSources ClippEU _ = gevalCoreWithoutInput parseClippingSpecs parseClippings matchStep clippeuAgg finalStep noGraph
+gevalCoreOnSources ClippEU _ = gevalCoreWithoutInput SAClippEU parseClippings matchStep clippeuAgg finalStep noGraph
   where parseClippings = controlledParse lineClippingsParser
-        parseClippingSpecs = controlledParse lineClippingSpecsParser
         matchStep (clippingSpecs, clippings) = (maxMatch matchClippingToSpec clippingSpecs clippings,
                                                 Prelude.length clippingSpecs,
                                                 Prelude.length clippings)
         clippeuAgg = CC.foldl countFolder (0, 0, 0)
         finalStep counts = f2MeasureOnCounts counts

-gevalCoreOnSources NMI _ = gevalCoreWithoutInput (Right . id) (Right . id) id (CC.foldl updateConfusionMatrix M.empty) normalizedMutualInformationFromConfusionMatrix noGraph
+gevalCoreOnSources NMI _ = gevalCoreWithoutInput SANMI (Right . id) id (CC.foldl updateConfusionMatrix M.empty) normalizedMutualInformationFromConfusionMatrix noGraph

-gevalCoreOnSources MAP _ = gevalCoreWithoutInput (Right . DLS.splitOn "\t" . unpack)
-                                                 (Right . DLS.splitOn "\t" . unpack)
-                                                 (\(e,g) -> calculateMAPForOneResult e g)
-                                                 averageC
-                                                 id
-                                                 noGraph
+gevalCoreOnSources MAP _ = gevalCoreWithoutInput SAMAP
+                                                 (Right . DLS.splitOn "\t" . unpack)
+                                                 (\(e,g) -> calculateMAPForOneResult e g)
+                                                 averageC
+                                                 id
+                                                 noGraph

-gevalCoreOnSources BIOF1 _ = gevalCoreWithoutInput parseBioSequenceIntoEntities parseBioSequenceIntoEntities (uncurry gatherCountsForBIO) countAgg f1MeasureOnCounts noGraph
+gevalCoreOnSources BIOF1 _ = gevalCoreWithoutInput SABIOF1 parseBioSequenceIntoEntities (uncurry gatherCountsForBIO) countAgg f1MeasureOnCounts noGraph

-gevalCoreOnSources BIOF1Labels _ = gevalCoreWithoutInput parseBioSequenceIntoEntitiesWithoutNormalization parseBioSequenceIntoEntitiesWithoutNormalization (uncurry gatherCountsForBIO) countAgg f1MeasureOnCounts noGraph
+gevalCoreOnSources BIOF1Labels _ = gevalCoreWithoutInput SABIOF1Labels parseBioSequenceIntoEntitiesWithoutNormalization (uncurry gatherCountsForBIO) countAgg f1MeasureOnCounts noGraph

-gevalCoreOnSources TokenAccuracy _ = gevalCoreWithoutInput intoTokens
-                                                           intoTokens
-                                                           countHitsAndTotals
-                                                           hitsAndTotalsAgg
-                                                           (\(hits, total) -> hits /. total)
-                                                           noGraph
+gevalCoreOnSources TokenAccuracy _ = gevalCoreWithoutInput SATokenAccuracy
+                                                           intoTokens
+                                                           countHitsAndTotals
+                                                           hitsAndTotalsAgg
+                                                           (\(hits, total) -> hits /. total)
+                                                           noGraph
   where intoTokens = Right . Data.Text.words
         countHitsAndTotals :: ([Text], [Text]) -> (Int, Int)
         countHitsAndTotals (es, os) =
@@ -703,17 +701,15 @@ gevalCoreOnSources TokenAccuracy _ = gevalCoreWithoutInput intoTokens
             | otherwise = (h, t + 1)
         hitsAndTotalsAgg = CC.foldl (\(h1, t1) (h2, t2) -> (h1 + h2, t1 + t2)) (0, 0)

-gevalCoreOnSources MultiLabelLogLoss _ = gevalCoreWithoutInput intoWords
-                                                               (Right . parseIntoProbList)
-                                                               (uncurry countLogLossOnProbList)
-                                                               averageC
-                                                               id
-                                                               noGraph
+gevalCoreOnSources MultiLabelLogLoss _ = gevalCoreWithoutInput SAMultiLabelLogLoss
+                                                               (Right . parseIntoProbList)
+                                                               (uncurry countLogLossOnProbList)
+                                                               averageC
+                                                               id
+                                                               noGraph
   where intoWords = Right . Data.Text.words

-doubleParser = getValue . TR.double
-
 helperLogLossHashed nbOfBits finalStep expectedLineSource outLineSource =
   gevalCore''' (ParserSpecWithoutInput (liftOp (Right . id)) (liftOp tentativeParser)) (\(lineNo, (t,d)) -> calculateLogLoss nbOfBits lineNo t (parseDistributionWrapper nbOfBits lineNo d)) averageC (finalStep . negate) noGraph (WithoutInput expectedLineSource outLineSource)
   where
     -- Unfortunately, we're parsing the distribution twice. We need to
@@ -724,12 +720,12 @@ helperLogLossHashed nbOfBits finalStep expectedLineSource outLineSource =
           Right _ -> Right t
           Left m -> Left m

-generalizedProbabilisticFMeasure beta parseBareEntities parseEntities = gevalCoreWithoutInput parseBareEntities
-                                                                                              parseEntities
-                                                                                              getProbabilisticCounts
-                                                                                              probabilisticSoftAgg
-                                                                                              (fMeasureOnProbabilisticCounts beta)
-                                                                                              loessGraph
+generalizedProbabilisticFMeasure beta metric parseEntities = gevalCoreWithoutInput metric
+                                                                                   parseEntities
+                                                                                   getProbabilisticCounts
+                                                                                   probabilisticSoftAgg
+                                                                                   (fMeasureOnProbabilisticCounts beta)
+                                                                                   loessGraph
   where probabilisticSoftAgg :: Monad m => ConduitM ([Double], [Double], Double, Int) o m ([Double], [Double], Double, Int)
         probabilisticSoftAgg = CC.foldl probabilisticSoftFolder ([], [], fromInteger 0, 0)
         probabilisticSoftFolder (r1, p1, g1, e1) (r2, p2, g2, e2) = (r1 ++ r2, p1 ++ p2, g1 + g2, e1 + e2)
@@ -754,7 +750,7 @@ gevalCoreByCorrelationMeasure :: (MonadUnliftIO m, MonadThrow m, MonadIO m) =>
                                  LineSource (ResourceT m) ->  -- ^ source to read the output
                                  m (MetricOutput)             -- ^ metric values for the output against the expected output
 gevalCoreByCorrelationMeasure correlationFunction =
-  gevalCoreWithoutInput doubleParser doubleParser id correlationC finalStep noGraph
+  gevalCoreWithoutInput SAPearson doubleParser id correlationC finalStep noGraph
   where correlationC = CC.foldl (flip (:)) []
         finalStep pairs = correlationFunction $ V.fromList pairs
@@ -770,9 +766,9 @@ skipLineNumber fun = fun . snd

 -- | A helper function to run evaluation when the input is not needed to calculate the metric value.
 gevalCoreWithoutInput :: (MonadUnliftIO m, MonadThrow m, MonadIO m)
-                      => (Text -> Either String a) -- ^ parser for values in the expected output
+                      => SAMetric t
                       -> (Text -> Either String b) -- ^ parser for values in the actual output
-                      -> ((a, b) -> c)             -- ^ function which combines parsed values into a single value
+                      -> ((ParsedExpectedType t, b) -> c) -- ^ function which combines parsed values into a single value
                                                    --   (will be launched for each item, e.g. an error/cost function
                                                    --   could be calculated here)
                       -> (ConduitT c Void (ResourceT m) d) -- ^ a Conduit which aggregates all the combined values into
@@ -782,8 +778,9 @@ gevalCoreWithoutInput :: (MonadUnliftIO m, MonadThrow m, MonadIO m)
                       -> LineSource (ResourceT m)  -- ^ source to read the expected output
                       -> LineSource (ResourceT m)  -- ^ source to read the output
                       -> m (MetricOutput)          -- ^ metric values for the output against the expected output
-gevalCoreWithoutInput expParser outParser itemStep aggregator finalStep generateGraph expectedLineStream outLineStream =
+gevalCoreWithoutInput smetric outParser itemStep aggregator finalStep generateGraph expectedLineStream outLineStream =
   gevalCoreWithoutInputOnItemTargets (liftOp expParser) (liftOp outParser) itemStep aggregator finalStep generateGraph expectedLineStream outLineStream
+  where expParser = expectedParser smetric

 gevalCoreWithoutInputOnItemTargets :: (MonadUnliftIO m, MonadThrow m, MonadIO m)
                       => (ItemTarget -> Either String a) -- ^ parser for values in the expected output
@@ -961,15 +958,3 @@ itemLogLossError (exp, out)
   | v >= 1.0 = 1.0
   | v <= 0.0 = 0.0
   | otherwise = v
-
-getValue :: Num a => Either String (a, Text) -> Either String a
-getValue (Right (x, reminder)) =
-  if Data.Text.null reminder || Data.Text.head reminder == '\t'
-    then Right x
-    else Left "number expected"
-getValue (Left s) = Left s
-
-controlledParse parser t =
-  case parseOnly parser t of
-    (Right v) -> Right v
-    (Left _) -> Left "cannot parse line"
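The shape of this Core.hs change is the same for every metric: gevalCoreOnSources no longer passes a dedicated parser for the expected output, it passes a singleton value (SALikelihood, SABLEU, and so on), and gevalCoreWithoutInput recovers the parser with expectedParser from GEval.MetricsMechanics, while the ParsedExpectedType type family keeps the per-item step function typed against whatever that parser produces. The following is a minimal, self-contained sketch of the singleton-plus-type-family technique; the metric names, parsers and evalItem are invented for illustration and are not the GEval definitions (GEval appears to generate its SAMetric singleton with the singletons machinery, note the Data.Singletons.TH import, whereas the sketch writes one by hand).

{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-}

import Data.Kind (Type)
import Data.Text (Text)
import qualified Data.Text as T
import qualified Data.Text.Read as TR

-- A toy metric kind and a hand-written singleton for it.
data AMetric = AMSE | AAccuracy

data SAMetric (t :: AMetric) where
  SAMSE      :: SAMetric 'AMSE
  SAAccuracy :: SAMetric 'AAccuracy

-- The type of the parsed expected output is a function of the metric.
type family ParsedExpectedType (t :: AMetric) :: Type where
  ParsedExpectedType 'AMSE      = Double
  ParsedExpectedType 'AAccuracy = Text

-- The singleton picks the parser, and the parser comes back at the right type.
expectedParser :: SAMetric t -> Text -> Either String (ParsedExpectedType t)
expectedParser SAMSE      = fmap fst . TR.double
expectedParser SAAccuracy = Right . T.strip

-- A caller only hands over the singleton; the per-item step function is then
-- checked against the metric-specific expected type.
evalItem :: SAMetric t
         -> ((ParsedExpectedType t, Text) -> Double)
         -> Text -> Text
         -> Either String Double
evalItem smetric itemStep expectedLine outLine = do
  e <- expectedParser smetric expectedLine
  return (itemStep (e, outLine))

main :: IO ()
main = do
  print (evalItem SAMSE      (\(e, _) -> e * e)                           "2.0" "1.5")
  print (evalItem SAAccuracy (\(e, o) -> if e == T.strip o then 1 else 0) "cat" "cat ")

The payoff visible in the hunks above is that every gevalCoreOnSources equation drops one argument, and the expected-output parsers now live in a single place (GEval.MetricsMechanics) instead of being repeated at each call site.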
diff --git a/src/GEval/MetricsMechanics.hs b/src/GEval/MetricsMechanics.hs
index f4ace4b..3ccafb7 100644
--- a/src/GEval/MetricsMechanics.hs
+++ b/src/GEval/MetricsMechanics.hs
@@ -4,6 +4,7 @@
 {-# LANGUAGE GADTs #-}
 {-# LANGUAGE TemplateHaskell #-}
 {-# LANGUAGE EmptyCase #-}
+{-# LANGUAGE UndecidableInstances #-}

 module GEval.MetricsMechanics
   where
@@ -16,12 +17,15 @@ import Data.Text
 import Data.Text.Read as TR
 import qualified Data.List.Split as DLS
 import Data.Attoparsec.Text (parseOnly)
+import Data.Maybe (catMaybes)

 import Control.Monad ((<=<))

-import GEval.Annotation (Annotation, parseAnnotations)
-import GEval.Clippings (ClippingSpec, LabeledClipping, lineClippingsParser, lineClippingSpecsParser, lineLabeledClippingsParser)
+import GEval.Annotation (Annotation, ObtainedAnnotation, parseAnnotations, parseObtainedAnnotations)
+import GEval.Clippings (Clipping, ClippingSpec, LabeledClipping, lineClippingsParser, lineClippingSpecsParser, lineLabeledClippingsParser)
 import GEval.BIO (TaggedEntity, parseBioSequenceIntoEntities, parseBioSequenceIntoEntitiesWithoutNormalization)
+import GEval.LogLossHashed (parseWordSpecs, wordSpecToPair)
+import GEval.ProbList (ProbList(..), parseIntoProbList, WordWithProb(..))

 -- | Helper type so that singleton can be used.
 -- | (The problem is that some metrics are parametrized by Double
@@ -120,7 +124,7 @@ expectedParser SASoftFMeasure = parseAnnotations
 expectedParser SAProbabilisticMultiLabelFMeasure = intoWords
 expectedParser SAProbabilisticSoftFMeasure = parseAnnotations
 expectedParser SASoft2DFMeasure = controlledParse lineLabeledClippingsParser
-expectedParser SANMI = onlyStrip
+expectedParser SANMI = Right . id
 expectedParser SALogLossHashed = onlyStrip
 expectedParser SALikelihoodHashed = onlyStrip
 expectedParser SACharMatch = Right
@@ -136,6 +140,49 @@ expectedParser SAMultiLabelFMeasure = intoWords
 expectedParser SAMultiLabelLogLoss = intoWords
 expectedParser SAMultiLabelLikelihood = intoWords

+type family ParsedOutputType (t :: AMetric) :: * where
+  ParsedOutputType ABLEU = [Text]
+  ParsedOutputType AGLEU = [Text]
+  ParsedOutputType AClippEU = [Clipping]
+  ParsedOutputType AMacroFMeasure = Maybe Text
+  ParsedOutputType ASoftFMeasure = [ObtainedAnnotation]
+  ParsedOutputType AProbabilisticSoftFMeasure = [ObtainedAnnotation]
+  ParsedOutputType AProbabilisticMultiLabelFMeasure = [WordWithProb]
+  ParsedOutputType t = ParsedExpectedType t
+
+outputParser :: SAMetric t -> Text -> Either String (ParsedOutputType t)
+outputParser SARMSE = expectedParser SARMSE
+outputParser SAMSE = expectedParser SARMSE
+outputParser SAPearson = expectedParser SAPearson
+outputParser SASpearman = expectedParser SASpearman
+outputParser SABLEU = intoWords
+outputParser SAGLEU = intoWords
+outputParser SAWER = expectedParser SAWER
+outputParser SAAccuracy = expectedParser SAAccuracy
+outputParser SAClippEU = controlledParse lineClippingsParser
+outputParser SAFMeasure = probToZeroOneParser
+outputParser SAMacroFMeasure = Right . predictedParser . strip
+outputParser SASoftFMeasure = parseObtainedAnnotations
+outputParser SAProbabilisticMultiLabelFMeasure = (Right . (\(ProbList es) -> es) . parseIntoProbList)
+outputParser SAProbabilisticSoftFMeasure = parseObtainedAnnotations
+outputParser SASoft2DFMeasure = expectedParser SASoft2DFMeasure
+outputParser SANMI = expectedParser SANMI
+outputParser SALogLossHashed = onlyStrip
+outputParser SALikelihoodHashed = onlyStrip
+outputParser SACharMatch = Right
+outputParser SAMAP = splitByTabs
+outputParser SALogLoss = doubleParser
+outputParser SALikelihood = doubleParser
+outputParser SABIOF1 = parseBioSequenceIntoEntities
+outputParser SABIOF1Labels = parseBioSequenceIntoEntitiesWithoutNormalization
+outputParser SATokenAccuracy = intoWords
+outputParser SAMAE = doubleParser
+outputParser SASMAPE = doubleParser
+outputParser SAMultiLabelFMeasure = intoWords
+outputParser SAMultiLabelLogLoss = intoWords
+outputParser SAMultiLabelLikelihood = intoWords
+
+
 doubleParser = getValue . TR.double

 intoWords = Right . Data.Text.words
@@ -148,6 +195,16 @@ onlyStrip = Right . strip

 justStrip = Right . Just . strip

+predictedParser got =
+  -- first try to parse what we got as a probability distribution
+  --   (like the one used for Likelikehood/LogLossHashed metric)
+  case parseWordSpecs got of
+    Right wordSpecs -> if Prelude.null pairs
+                       then Nothing
+                       else Just $ snd $ Prelude.maximum pairs
+      where pairs = catMaybes $ Prelude.map wordSpecToPair wordSpecs
+    Left _ -> Just got
+
 splitByTabs = Right . DLS.splitOn "\t" . unpack

 zeroOneParser = expected <=< (getValue . TR.decimal)
@@ -155,6 +212,14 @@ zeroOneParser = expected <=< (getValue . TR.decimal)
         expected 0 = Right False
         expected _ = Left "expected 0 or 1"

+probToZeroOneParser = detected <=< (getValue . TR.double)
+  where -- output value could be a probability (for compatibility with other measures)
+        detected prob
+          | prob >= 0.0 && prob < detectionThreshold = Right False
+          | prob >= detectionThreshold && prob <= 1.0 = Right True
+          | otherwise = Left "expected probability"
+        detectionThreshold = 0.5
+
 getValue :: Num a => Either String (a, Text) -> Either String a
 getValue (Right (x, reminder)) =
   if Data.Text.null reminder || Data.Text.head reminder == '\t'
diff --git a/src/GEval/ProbList.hs b/src/GEval/ProbList.hs
index 549c64c..843af5a 100644
--- a/src/GEval/ProbList.hs
+++ b/src/GEval/ProbList.hs
@@ -2,7 +2,7 @@
 {-# LANGUAGE TypeFamilies #-}

 module GEval.ProbList
-       (parseIntoProbList, selectByStandardThreshold, countLogLossOnProbList, ProbList(..))
+       (parseIntoProbList, selectByStandardThreshold, countLogLossOnProbList, ProbList(..), WordWithProb(..))
        where

 import qualified Data.Text as T