From 169931be8f2ba7864b06c3daec449fab986c87bb Mon Sep 17 00:00:00 2001
From: Filip Gralinski
Date: Thu, 6 Feb 2020 15:06:41 +0100
Subject: [PATCH] Define FLC-F-score

---
 src/GEval/Annotation.hs       | 10 ++++++++++
 src/GEval/Core.hs             |  7 +++++++
 src/GEval/Metric.hs           |  6 ++++++
 src/GEval/MetricsMechanics.hs | 17 +++++++++++++++++
 src/GEval/PrecisionRecall.hs  | 19 ++++++++++++++-----
 src/GEval/Probability.hs      | 26 +++++++++++++++++++++++++-
 6 files changed, 79 insertions(+), 6 deletions(-)

diff --git a/src/GEval/Annotation.hs b/src/GEval/Annotation.hs
index 950c93d..fbdc283 100644
--- a/src/GEval/Annotation.hs
+++ b/src/GEval/Annotation.hs
@@ -28,6 +28,15 @@ data Annotation = Annotation T.Text IS.IntSet
 data ObtainedAnnotation = ObtainedAnnotation Annotation Double
   deriving (Eq, Show)
 
+instance Coverable Annotation where
+  -- share of the first annotation's positions also present in the second; 0 when labels differ
+  coveredScore (Annotation labelA intSetA) (Annotation labelB intSetB)
+    | labelA == labelB = IS.size (intSetA `IS.intersection` intSetB) /. IS.size intSetA
+    | otherwise = 0 -- total guard: coveredScoreTotal pairs every item with every other item
+  disjoint (Annotation labelA intSetA) (Annotation labelB intSetB)
+    | labelA == labelB = intSetA `IS.disjoint` intSetB
+    | otherwise = True
+
 instance EntityWithProbability ObtainedAnnotation where
   type BareEntity ObtainedAnnotation = Annotation
   getBareEntity (ObtainedAnnotation annotation _) = annotation
@@ -38,6 +47,7 @@ instance EntityWithProbability ObtainedAnnotation where
     where intSetLength = Prelude.length . IS.toList
           intersect = intSetA `IS.intersection` intSetB
 
+instance CoverableEntityWithProbability ObtainedAnnotation
 
 parseObtainedAnnotations :: T.Text -> Either String [ObtainedAnnotation]
 parseObtainedAnnotations t = parseOnly (obtainedAnnotationsParser <* endOfInput) t
diff --git a/src/GEval/Core.hs b/src/GEval/Core.hs
index 4a9b033..c72697d 100644
--- a/src/GEval/Core.hs
+++ b/src/GEval/Core.hs
@@ -638,6 +638,9 @@ generalizedProbabilisticFMeasure beta metric = gevalCoreWithoutInput metric
 countAgg :: (Num n, Num v, Monad m) => ConduitM (n, v, v) o m (n, v, v)
 countAgg = CC.foldl countFolder (fromInteger 0, fromInteger 0, fromInteger 0)
 
+countFragAgg :: (Num n, Num v, Monad m) => ConduitM (n, n, v, v) o m (n, n, v, v)
+countFragAgg = CC.foldl countFragFolder (fromInteger 0, fromInteger 0, fromInteger 0, fromInteger 0)
+
 gevalCoreByCorrelationMeasure :: (MonadUnliftIO m, MonadThrow m, MonadIO m) =>
                                 (V.Vector (Double, Double) -> Double) -> -- ^ correlation function
                                 LineSource (ResourceT m) -> -- ^ source to read the expected output
@@ -813,6 +816,10 @@ continueGEvalCalculations SASoftFMeasure (SoftFMeasure beta) = defineContinuatio
                                                                                   (fMeasureOnCounts beta)
                                                                                   noGraph
 
+continueGEvalCalculations SAFLCFMeasure (FLCFMeasure beta) = defineContinuation countFragAgg
+                                                                                (fMeasureOnFragCounts beta)
+                                                                                noGraph
+
 continueGEvalCalculations SASoft2DFMeasure (Soft2DFMeasure beta) = defineContinuation (CC.map (fMeasureOnCounts beta) .| averageC)
                                                                                       id
                                                                                       noGraph
diff --git a/src/GEval/Metric.hs b/src/GEval/Metric.hs
index c5ac7ba..88bdfd4 100644
--- a/src/GEval/Metric.hs
+++ b/src/GEval/Metric.hs
@@ -31,6 +31,7 @@ data Metric = RMSE | MSE | Pearson | Spearman | BLEU | GLEU | WER | Accuracy | C
               | MultiLabelLogLoss | MultiLabelLikelihood
               | SoftFMeasure Double | ProbabilisticMultiLabelFMeasure Double
               | ProbabilisticSoftFMeasure Double | Soft2DFMeasure Double
+              | FLCFMeasure Double
               -- it would be better to avoid infinite recursion here
               -- `Mean (Mean BLEU)` is not useful, but as it would mean
               -- a larger refactor, we will postpone this
@@ -53,6 +54,7 @@ instance Show Metric where
   show (ProbabilisticMultiLabelFMeasure beta) = "Probabilistic-MultiLabel-F" ++ (show beta)
   show (ProbabilisticSoftFMeasure beta) = "Probabilistic-Soft-F" ++ (show beta)
   show (Soft2DFMeasure beta) = "Soft2D-F" ++ (show beta)
+  show (FLCFMeasure beta) = "FLC-F" ++ (show beta)
   show NMI = "NMI"
   show (LogLossHashed nbOfBits) = "LogLossHashed" ++ (if
                                                        nbOfBits == defaultLogLossHashedSize
@@ -95,6 +97,9 @@ instance Read Metric where
   readsPrec _ ('A':'c':'c':'u':'r':'a':'c':'y':theRest) = [(Accuracy, theRest)]
   readsPrec _ ('C':'l':'i':'p':'p':'E':'U':theRest) = [(ClippEU, theRest)]
   readsPrec _ ('N':'M':'I':theRest) = [(NMI, theRest)]
+  readsPrec p ('F':'L':'C':'-':'F':theRest) = case readsPrec p theRest of
+    [(beta, theRest)] -> [(FLCFMeasure beta, theRest)]
+    _ -> []
   readsPrec p ('F':theRest) = case readsPrec p theRest of
     [(beta, theRest)] -> [(FMeasure beta, theRest)]
     _ -> []
@@ -156,6 +161,7 @@ getMetricOrdering (SoftFMeasure _) = TheHigherTheBetter
 getMetricOrdering (ProbabilisticMultiLabelFMeasure _) = TheHigherTheBetter
 getMetricOrdering (ProbabilisticSoftFMeasure _) = TheHigherTheBetter
 getMetricOrdering (Soft2DFMeasure _) = TheHigherTheBetter
+getMetricOrdering (FLCFMeasure _) = TheHigherTheBetter
 getMetricOrdering NMI = TheHigherTheBetter
 getMetricOrdering (LogLossHashed _) = TheLowerTheBetter
 getMetricOrdering (LikelihoodHashed _) = TheHigherTheBetter
diff --git a/src/GEval/MetricsMechanics.hs b/src/GEval/MetricsMechanics.hs
index 72c27b1..0c15668 100644
--- a/src/GEval/MetricsMechanics.hs
+++ b/src/GEval/MetricsMechanics.hs
@@ -50,6 +50,7 @@ singletons [d|data AMetric = ARMSE | AMSE | APearson | ASpearman | ABLEU | AGLEU
                              | ABIOF1 | ABIOF1Labels | ATokenAccuracy | ASegmentAccuracy | ALikelihoodHashed | AMAE | ASMAPE | AMultiLabelFMeasure
                              | AMultiLabelLogLoss | AMultiLabelLikelihood
                              | ASoftFMeasure | AProbabilisticMultiLabelFMeasure | AProbabilisticSoftFMeasure | ASoft2DFMeasure
+                             | AFLCFMeasure
                              deriving (Eq)
              |]
 
@@ -83,6 +84,7 @@ toHelper (MultiLabelFMeasure _) = AMultiLabelFMeasure
 toHelper MultiLabelLogLoss = AMultiLabelLogLoss
 toHelper MultiLabelLikelihood = AMultiLabelLikelihood
 toHelper (SoftFMeasure _) = ASoftFMeasure
+toHelper (FLCFMeasure _) = AFLCFMeasure
 toHelper (ProbabilisticMultiLabelFMeasure _) = AProbabilisticMultiLabelFMeasure
 toHelper (ProbabilisticSoftFMeasure _) = AProbabilisticSoftFMeasure
 toHelper (Soft2DFMeasure _) = ASoft2DFMeasure
@@ -104,6 +106,7 @@ type family ParsedExpectedType (t :: AMetric) :: * where
   ParsedExpectedType AFMeasure = Bool
   ParsedExpectedType AMacroFMeasure = Maybe Text
   ParsedExpectedType ASoftFMeasure = [Annotation]
+  ParsedExpectedType AFLCFMeasure = [Annotation]
   ParsedExpectedType AProbabilisticMultiLabelFMeasure = [Text]
   ParsedExpectedType AProbabilisticSoftFMeasure = [Annotation]
   ParsedExpectedType ASoft2DFMeasure = [LabeledClipping]
@@ -137,6 +140,7 @@ expectedParser SAClippEU = controlledParse lineClippingSpecsParser
 expectedParser SAFMeasure = zeroOneParser
 expectedParser SAMacroFMeasure = justStrip
 expectedParser SASoftFMeasure = parseAnnotations
+expectedParser SAFLCFMeasure = parseAnnotations
 expectedParser SAProbabilisticMultiLabelFMeasure = intoWords
 expectedParser SAProbabilisticSoftFMeasure = parseAnnotations
 expectedParser SASoft2DFMeasure = controlledParse lineLabeledClippingsParser
@@ -163,6 +167,7 @@ type family ParsedOutputType (t :: AMetric) :: * where
   ParsedOutputType AClippEU = [Clipping]
   ParsedOutputType AMacroFMeasure = Maybe Text
   ParsedOutputType ASoftFMeasure = [ObtainedAnnotation]
+  ParsedOutputType AFLCFMeasure = [ObtainedAnnotation]
   ParsedOutputType AProbabilisticSoftFMeasure = [ObtainedAnnotation]
   ParsedOutputType AProbabilisticMultiLabelFMeasure = [WordWithProb]
   ParsedOutputType AMultiLabelLikelihood = ProbList
@@ -182,6 +187,7 @@ outputParser SAClippEU = controlledParse lineClippingsParser
 outputParser SAFMeasure = probToZeroOneParser
 outputParser SAMacroFMeasure = Right . predictedParser . strip
 outputParser SASoftFMeasure = parseObtainedAnnotations
+outputParser SAFLCFMeasure = parseObtainedAnnotations
 outputParser SAProbabilisticMultiLabelFMeasure = (Right . (\(ProbList es) -> es) . parseIntoProbList)
 outputParser SAProbabilisticSoftFMeasure = parseObtainedAnnotations
 outputParser SASoft2DFMeasure = expectedParser SASoft2DFMeasure
@@ -208,6 +214,7 @@ type family ItemIntermediateRepresentationType (t :: AMetric) :: * where
   ItemIntermediateRepresentationType AFMeasure = (Int, Int, Int)
   ItemIntermediateRepresentationType AMacroFMeasure = (Maybe Text, Maybe Text, Maybe Text)
   ItemIntermediateRepresentationType ASoftFMeasure = (Double, Int, Int)
+  ItemIntermediateRepresentationType AFLCFMeasure = (Double, Double, Int, Int)
   ItemIntermediateRepresentationType ASoft2DFMeasure = (Integer, Integer, Integer)
   ItemIntermediateRepresentationType AClippEU = (Int, Int, Int)
   ItemIntermediateRepresentationType ANMI = (Text, Text)
@@ -238,6 +245,7 @@ itemStep SAClippEU = clippEUMatchStep
 itemStep SAFMeasure = getCount
 itemStep SAMacroFMeasure = getClassesInvolved
 itemStep SASoftFMeasure = getSoftCounts
+itemStep SAFLCFMeasure = getFragCounts
 itemStep SAProbabilisticMultiLabelFMeasure = getProbabilisticCounts
 itemStep SAProbabilisticSoftFMeasure = getProbabilisticCounts
 itemStep SASoft2DFMeasure = getSoft2DCounts
@@ -355,6 +363,15 @@ getSoft2DCounts (expected, got) = (tpArea, expArea, gotArea)
         expArea = totalArea expected
         gotArea = totalArea got
 
+getFragCounts :: CoverableEntityWithProbability e => ([BareEntity e], [e]) -> (Double, Double, Int, Int)
+getFragCounts (expected, got)
+  | allDisjoint (Prelude.map getBareEntity got) = (
+      recallScoreTotal expected got,
+      precisionScoreTotal got expected,
+      Prelude.length expected,
+      Prelude.length got)
+  | otherwise = (0, 0, Prelude.length expected, Prelude.length got) -- overlapping output items: both score totals are zeroed
+
 countHitsAndTotals :: ([Text], [Text]) -> (Int, Int)
 countHitsAndTotals (es, os) =
     if Prelude.length os /= Prelude.length es
diff --git a/src/GEval/PrecisionRecall.hs b/src/GEval/PrecisionRecall.hs
index b754ea8..15dd33e 100644
--- a/src/GEval/PrecisionRecall.hs
+++ b/src/GEval/PrecisionRecall.hs
@@ -3,10 +3,11 @@
 
 module GEval.PrecisionRecall(calculateMAPForOneResult,
                              weightedHarmonicMean, fMeasure, f1Measure, f2Measure, precision, recall,
-                             fMeasureOnCounts, f1MeasureOnCounts, f2MeasureOnCounts, countFolder,
+                             fMeasureOnCounts, fMeasureOnFragCounts, f1MeasureOnCounts, f2MeasureOnCounts, countFolder,
                              precisionAndRecall, precisionAndRecallFromCounts,
                              maxMatch, maxMatchOnOrdered, getCounts, weightedMaxMatch, weightedMaxMatching,
-                             getProbabilisticCounts)
+                             getProbabilisticCounts,
+                             countFragFolder)
        where
 
 import GEval.Common
@@ -59,13 +60,21 @@ f1MeasureOnCounts = fMeasureOnCounts 1.0
 
 fMeasureOnCounts :: (ConvertibleToDouble n, Integral v) => Double -> (n, v, v) -> Double
 fMeasureOnCounts beta (tp, nbExpected, nbGot) =
-  (1 + betaSquared) * p * r `safeDoubleDiv` (betaSquared * p + r)
-  where betaSquared = beta ^ 2
-        (p, r) = precisionAndRecallFromCounts (tp, nbExpected, nbGot)
+  weightedHarmonicMean beta p r
+  where (p, r) = precisionAndRecallFromCounts (tp, nbExpected, nbGot)
+
+fMeasureOnFragCounts :: (ConvertibleToDouble n, Integral v) => Double -> (n, n, v, v) -> Double
+fMeasureOnFragCounts beta (rC, pC, nbExpected, nbGot) =
+  weightedHarmonicMean beta p r
+  where r = rC /. nbExpected
+        p = pC /. nbGot
 
 countFolder :: (Num n, Num v) => (n, v, v) -> (n, v, v) -> (n, v, v)
 countFolder (a1, a2, a3) (b1, b2, b3) = (a1+b1, a2+b2, a3+b3)
 
+countFragFolder :: (Num n, Num v) => (n, n, v, v) -> (n, n, v, v) -> (n, n, v, v)
+countFragFolder (a1, a2, a3, a4) (b1, b2, b3, b4) = (a1+b1, a2+b2, a3+b3, a4+b4)
+
 getCounts :: (a -> b -> Bool) -> ([a], [b]) -> (Int, Int, Int)
 getCounts matchingFun (expected, got) = (maxMatch matchingFun expected got,
                                          length expected,
diff --git a/src/GEval/Probability.hs b/src/GEval/Probability.hs
index c938953..c770bfa 100644
--- a/src/GEval/Probability.hs
+++ b/src/GEval/Probability.hs
@@ -1,8 +1,11 @@
 {-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE FlexibleContexts #-}
 
 module GEval.Probability
        (Probability, getP, isProbability, mkProbability, probabilityOne, probabilityZero,
-        EntityWithProbability, getProbability, getProbabilityAsDouble, getBareEntity, matchScore, BareEntity)
+        Coverable, coveredScore, coveredScoreTotal, disjoint, allDisjoint, CoverableEntityWithProbability,
+        EntityWithProbability, getProbability, getProbabilityAsDouble, getBareEntity, matchScore, BareEntity,
+        precisionScoreTotal, recallScoreTotal)
        where
 
 newtype Probability = P { getP :: Double }
@@ -30,3 +33,24 @@ class EntityWithProbability e where
   getProbabilityAsDouble = getP . getProbability
   getBareEntity :: e -> BareEntity e
   matchScore :: BareEntity e -> e -> Double
+
+class Coverable b where
+  coveredScore :: b -> b -> Double
+  coveredScoreTotal :: [b] -> [b] -> Double
+  coveredScoreTotal items otherItems = sum [coveredScore b b' | b <- items, b' <- otherItems] -- every (item, otherItem) pair
+  disjoint :: b -> b -> Bool
+
+-- | True when every pair of distinct list positions is 'disjoint' (vacuously True for [] and singletons).
+allDisjoint :: Coverable b => [b] -> Bool
+allDisjoint [] = True
+allDisjoint (x:xs) = all (disjoint x) xs && allDisjoint xs
+
+class (Show e, EntityWithProbability e, Coverable (BareEntity e)) => CoverableEntityWithProbability e where
+  precisionScore :: e -> BareEntity e -> Double
+  precisionScore got expected = coveredScore (getBareEntity got) expected
+  precisionScoreTotal :: [e] -> [BareEntity e] -> Double
+  precisionScoreTotal gots expecteds = coveredScoreTotal (map getBareEntity gots) expecteds
+  recallScore :: BareEntity e -> e -> Double
+  recallScore expected got = coveredScore expected (getBareEntity got)
+  recallScoreTotal :: [BareEntity e] -> [e] -> Double
+  recallScoreTotal expecteds gots = coveredScoreTotal expecteds (map getBareEntity gots)