Define FLC-F-score
This commit is contained in:
parent
00b02680cc
commit
169931be8f
@ -28,6 +28,15 @@ data Annotation = Annotation T.Text IS.IntSet
|
||||
data ObtainedAnnotation = ObtainedAnnotation Annotation Double
|
||||
deriving (Eq, Show)
|
||||
|
||||
-- Coverage semantics for labelled annotations (a label plus an 'IS.IntSet').
instance Coverable Annotation where
  -- Share of the first annotation's elements that also occur in the second one.
  -- Annotations carrying different labels do not cover each other at all, so
  -- the score is 0 for them. Without the `otherwise` branch the guard was
  -- non-exhaustive and scoring a pair with different labels failed at run time.
  -- NOTE(review): for an empty first set the denominator is 0; the result then
  -- depends on how `/.` treats a zero divisor -- confirm.
  coveredScore (Annotation labelA intSetA) (Annotation labelB intSetB)
    | labelA == labelB = IS.size overlap /. IS.size intSetA
    | otherwise = 0
    where overlap = intSetA `IS.intersection` intSetB
  -- Annotations are disjoint when their labels differ or, for equal labels,
  -- when their sets have no element in common.
  disjoint (Annotation labelA intSetA) (Annotation labelB intSetB)
    | labelA == labelB = intSetA `IS.disjoint` intSetB
    | otherwise = True
|
||||
|
||||
instance EntityWithProbability ObtainedAnnotation where
|
||||
type BareEntity ObtainedAnnotation = Annotation
|
||||
getBareEntity (ObtainedAnnotation annotation _) = annotation
|
||||
@ -38,6 +47,7 @@ instance EntityWithProbability ObtainedAnnotation where
|
||||
where intSetLength = Prelude.length . IS.toList
|
||||
intersect = intSetA `IS.intersection` intSetB
|
||||
|
||||
-- No method is overridden here: the instance has no body, so every method
-- falls back to the default given in the class declaration.
instance CoverableEntityWithProbability ObtainedAnnotation
|
||||
|
||||
-- | Run 'obtainedAnnotationsParser' over the given text, additionally requiring
-- 'endOfInput' right after it; the outcome of 'parseOnly' is returned as is.
parseObtainedAnnotations :: T.Text -> Either String [ObtainedAnnotation]
parseObtainedAnnotations = parseOnly wholeInput
  where wholeInput = obtainedAnnotationsParser <* endOfInput
|
||||
|
@ -638,6 +638,9 @@ generalizedProbabilisticFMeasure beta metric = gevalCoreWithoutInput metric
|
||||
-- | Conduit that folds the incoming triples with 'countFolder', starting from
-- an all-zero triple, and yields the final accumulated triple as its result.
countAgg :: (Num n, Num v, Monad m) => ConduitM (n, v, v) o m (n, v, v)
countAgg = CC.foldl countFolder zeroTriple
  where zeroTriple = (0, 0, 0)
|
||||
|
||||
-- | Conduit that folds the incoming 4-tuples with 'countFragFolder', starting
-- from an all-zero 4-tuple, and yields the final accumulated 4-tuple.
countFragAgg :: (Num n, Num v, Monad m) => ConduitM (n, n, v, v) o m (n, n, v, v)
countFragAgg = CC.foldl countFragFolder zeroQuad
  where zeroQuad = (0, 0, 0, 0)
|
||||
|
||||
gevalCoreByCorrelationMeasure :: (MonadUnliftIO m, MonadThrow m, MonadIO m) =>
|
||||
(V.Vector (Double, Double) -> Double) -> -- ^ correlation function
|
||||
LineSource (ResourceT m) -> -- ^ source to read the expected output
|
||||
@ -813,6 +816,10 @@ continueGEvalCalculations SASoftFMeasure (SoftFMeasure beta) = defineContinuatio
|
||||
(fMeasureOnCounts beta)
|
||||
noGraph
|
||||
|
||||
continueGEvalCalculations SAFLCFMeasure (FLCFMeasure beta) = defineContinuation countFragAgg
|
||||
(fMeasureOnFragCounts beta)
|
||||
noGraph
|
||||
|
||||
continueGEvalCalculations SASoft2DFMeasure (Soft2DFMeasure beta) = defineContinuation (CC.map (fMeasureOnCounts beta) .| averageC)
|
||||
id
|
||||
noGraph
|
||||
|
@ -31,6 +31,7 @@ data Metric = RMSE | MSE | Pearson | Spearman | BLEU | GLEU | WER | Accuracy | C
|
||||
| MultiLabelLogLoss | MultiLabelLikelihood
|
||||
| SoftFMeasure Double | ProbabilisticMultiLabelFMeasure Double
|
||||
| ProbabilisticSoftFMeasure Double | Soft2DFMeasure Double
|
||||
| FLCFMeasure Double
|
||||
-- it would be better to avoid infinite recursion here
|
||||
-- `Mean (Mean BLEU)` is not useful, but as it would mean
|
||||
-- a larger refactor, we will postpone this
|
||||
@ -53,6 +54,7 @@ instance Show Metric where
|
||||
show (ProbabilisticMultiLabelFMeasure beta) = "Probabilistic-MultiLabel-F" ++ (show beta)
|
||||
show (ProbabilisticSoftFMeasure beta) = "Probabilistic-Soft-F" ++ (show beta)
|
||||
show (Soft2DFMeasure beta) = "Soft2D-F" ++ (show beta)
|
||||
show (FLCFMeasure beta) = "FLC-F" ++ (show beta)
|
||||
show NMI = "NMI"
|
||||
show (LogLossHashed nbOfBits) = "LogLossHashed" ++ (if
|
||||
nbOfBits == defaultLogLossHashedSize
|
||||
@ -95,6 +97,9 @@ instance Read Metric where
|
||||
readsPrec _ ('A':'c':'c':'u':'r':'a':'c':'y':theRest) = [(Accuracy, theRest)]
|
||||
readsPrec _ ('C':'l':'i':'p':'p':'E':'U':theRest) = [(ClippEU, theRest)]
|
||||
readsPrec _ ('N':'M':'I':theRest) = [(NMI, theRest)]
|
||||
readsPrec p ('F':'L':'C':'-':'F':theRest) = case readsPrec p theRest of
|
||||
[(beta, theRest)] -> [(FLCFMeasure beta, theRest)]
|
||||
_ -> []
|
||||
readsPrec p ('F':theRest) = case readsPrec p theRest of
|
||||
[(beta, theRest)] -> [(FMeasure beta, theRest)]
|
||||
_ -> []
|
||||
@ -156,6 +161,7 @@ getMetricOrdering (SoftFMeasure _) = TheHigherTheBetter
|
||||
getMetricOrdering (ProbabilisticMultiLabelFMeasure _) = TheHigherTheBetter
|
||||
getMetricOrdering (ProbabilisticSoftFMeasure _) = TheHigherTheBetter
|
||||
getMetricOrdering (Soft2DFMeasure _) = TheHigherTheBetter
|
||||
getMetricOrdering (FLCFMeasure _) = TheHigherTheBetter
|
||||
getMetricOrdering NMI = TheHigherTheBetter
|
||||
getMetricOrdering (LogLossHashed _) = TheLowerTheBetter
|
||||
getMetricOrdering (LikelihoodHashed _) = TheHigherTheBetter
|
||||
|
@ -50,6 +50,7 @@ singletons [d|data AMetric = ARMSE | AMSE | APearson | ASpearman | ABLEU | AGLEU
|
||||
| ABIOF1 | ABIOF1Labels | ATokenAccuracy | ASegmentAccuracy | ALikelihoodHashed | AMAE | ASMAPE | AMultiLabelFMeasure
|
||||
| AMultiLabelLogLoss | AMultiLabelLikelihood
|
||||
| ASoftFMeasure | AProbabilisticMultiLabelFMeasure | AProbabilisticSoftFMeasure | ASoft2DFMeasure
|
||||
| AFLCFMeasure
|
||||
deriving (Eq)
|
||||
|]
|
||||
|
||||
@ -83,6 +84,7 @@ toHelper (MultiLabelFMeasure _) = AMultiLabelFMeasure
|
||||
toHelper MultiLabelLogLoss = AMultiLabelLogLoss
|
||||
toHelper MultiLabelLikelihood = AMultiLabelLikelihood
|
||||
toHelper (SoftFMeasure _) = ASoftFMeasure
|
||||
toHelper (FLCFMeasure _) = AFLCFMeasure
|
||||
toHelper (ProbabilisticMultiLabelFMeasure _) = AProbabilisticMultiLabelFMeasure
|
||||
toHelper (ProbabilisticSoftFMeasure _) = AProbabilisticSoftFMeasure
|
||||
toHelper (Soft2DFMeasure _) = ASoft2DFMeasure
|
||||
@ -104,6 +106,7 @@ type family ParsedExpectedType (t :: AMetric) :: * where
|
||||
ParsedExpectedType AFMeasure = Bool
|
||||
ParsedExpectedType AMacroFMeasure = Maybe Text
|
||||
ParsedExpectedType ASoftFMeasure = [Annotation]
|
||||
ParsedExpectedType AFLCFMeasure = [Annotation]
|
||||
ParsedExpectedType AProbabilisticMultiLabelFMeasure = [Text]
|
||||
ParsedExpectedType AProbabilisticSoftFMeasure = [Annotation]
|
||||
ParsedExpectedType ASoft2DFMeasure = [LabeledClipping]
|
||||
@ -137,6 +140,7 @@ expectedParser SAClippEU = controlledParse lineClippingSpecsParser
|
||||
expectedParser SAFMeasure = zeroOneParser
|
||||
expectedParser SAMacroFMeasure = justStrip
|
||||
expectedParser SASoftFMeasure = parseAnnotations
|
||||
expectedParser SAFLCFMeasure = parseAnnotations
|
||||
expectedParser SAProbabilisticMultiLabelFMeasure = intoWords
|
||||
expectedParser SAProbabilisticSoftFMeasure = parseAnnotations
|
||||
expectedParser SASoft2DFMeasure = controlledParse lineLabeledClippingsParser
|
||||
@ -163,6 +167,7 @@ type family ParsedOutputType (t :: AMetric) :: * where
|
||||
ParsedOutputType AClippEU = [Clipping]
|
||||
ParsedOutputType AMacroFMeasure = Maybe Text
|
||||
ParsedOutputType ASoftFMeasure = [ObtainedAnnotation]
|
||||
ParsedOutputType AFLCFMeasure = [ObtainedAnnotation]
|
||||
ParsedOutputType AProbabilisticSoftFMeasure = [ObtainedAnnotation]
|
||||
ParsedOutputType AProbabilisticMultiLabelFMeasure = [WordWithProb]
|
||||
ParsedOutputType AMultiLabelLikelihood = ProbList
|
||||
@ -182,6 +187,7 @@ outputParser SAClippEU = controlledParse lineClippingsParser
|
||||
outputParser SAFMeasure = probToZeroOneParser
|
||||
outputParser SAMacroFMeasure = Right . predictedParser . strip
|
||||
outputParser SASoftFMeasure = parseObtainedAnnotations
|
||||
outputParser SAFLCFMeasure = parseObtainedAnnotations
|
||||
outputParser SAProbabilisticMultiLabelFMeasure = (Right . (\(ProbList es) -> es) . parseIntoProbList)
|
||||
outputParser SAProbabilisticSoftFMeasure = parseObtainedAnnotations
|
||||
outputParser SASoft2DFMeasure = expectedParser SASoft2DFMeasure
|
||||
@ -208,6 +214,7 @@ type family ItemIntermediateRepresentationType (t :: AMetric) :: * where
|
||||
ItemIntermediateRepresentationType AFMeasure = (Int, Int, Int)
|
||||
ItemIntermediateRepresentationType AMacroFMeasure = (Maybe Text, Maybe Text, Maybe Text)
|
||||
ItemIntermediateRepresentationType ASoftFMeasure = (Double, Int, Int)
|
||||
ItemIntermediateRepresentationType AFLCFMeasure = (Double, Double, Int, Int)
|
||||
ItemIntermediateRepresentationType ASoft2DFMeasure = (Integer, Integer, Integer)
|
||||
ItemIntermediateRepresentationType AClippEU = (Int, Int, Int)
|
||||
ItemIntermediateRepresentationType ANMI = (Text, Text)
|
||||
@ -238,6 +245,7 @@ itemStep SAClippEU = clippEUMatchStep
|
||||
itemStep SAFMeasure = getCount
|
||||
itemStep SAMacroFMeasure = getClassesInvolved
|
||||
itemStep SASoftFMeasure = getSoftCounts
|
||||
itemStep SAFLCFMeasure = getFragCounts
|
||||
itemStep SAProbabilisticMultiLabelFMeasure = getProbabilisticCounts
|
||||
itemStep SAProbabilisticSoftFMeasure = getProbabilisticCounts
|
||||
itemStep SASoft2DFMeasure = getSoft2DCounts
|
||||
@ -355,6 +363,15 @@ getSoft2DCounts (expected, got) = (tpArea, expArea, gotArea)
|
||||
expArea = totalArea expected
|
||||
gotArea = totalArea got
|
||||
|
||||
-- | Per-item counts for the fragment-based F-measure.
--
-- The result is @(recall total, precision total, number of expected items,
-- number of obtained items)@. When the obtained entities are not pairwise
-- disjoint (see 'allDisjoint'), both totals are reported as 0.
getFragCounts :: CoverableEntityWithProbability e => ([BareEntity e], [e]) -> (Double, Double, Int, Int)
getFragCounts (expected, got)
  | overlapping = (0, 0, nbExpected, nbGot)
  | otherwise = (recallScoreTotal expected got, precisionScoreTotal got expected, nbExpected, nbGot)
  where
    overlapping = not (allDisjoint (Prelude.map getBareEntity got))
    nbExpected = Prelude.length expected
    nbGot = Prelude.length got
|
||||
|
||||
countHitsAndTotals :: ([Text], [Text]) -> (Int, Int)
|
||||
countHitsAndTotals (es, os) =
|
||||
if Prelude.length os /= Prelude.length es
|
||||
|
@ -3,10 +3,11 @@
|
||||
|
||||
module GEval.PrecisionRecall(calculateMAPForOneResult,
|
||||
weightedHarmonicMean, fMeasure, f1Measure, f2Measure, precision, recall,
|
||||
fMeasureOnCounts, f1MeasureOnCounts, f2MeasureOnCounts, countFolder,
|
||||
fMeasureOnCounts, fMeasureOnFragCounts, f1MeasureOnCounts, f2MeasureOnCounts, countFolder,
|
||||
precisionAndRecall, precisionAndRecallFromCounts,
|
||||
maxMatch, maxMatchOnOrdered, getCounts, weightedMaxMatch, weightedMaxMatching,
|
||||
getProbabilisticCounts)
|
||||
getProbabilisticCounts,
|
||||
countFragFolder)
|
||||
where
|
||||
|
||||
import GEval.Common
|
||||
@ -59,13 +60,21 @@ f1MeasureOnCounts = fMeasureOnCounts 1.0
|
||||
|
||||
fMeasureOnCounts :: (ConvertibleToDouble n, Integral v) => Double -> (n, v, v) -> Double
|
||||
fMeasureOnCounts beta (tp, nbExpected, nbGot) =
|
||||
(1 + betaSquared) * p * r `safeDoubleDiv` (betaSquared * p + r)
|
||||
where betaSquared = beta ^ 2
|
||||
(p, r) = precisionAndRecallFromCounts (tp, nbExpected, nbGot)
|
||||
weightedHarmonicMean beta p r
|
||||
where (p, r) = precisionAndRecallFromCounts (tp, nbExpected, nbGot)
|
||||
|
||||
-- | F-measure computed from aggregated fragment counts.
--
-- The 4-tuple holds the recall total, the precision total, the number of
-- expected items and the number of obtained items. Recall is the recall total
-- divided (with '/.') by the number of expected items, precision is the
-- precision total divided by the number of obtained items; both are combined
-- with 'weightedHarmonicMean' for the given @beta@.
fMeasureOnFragCounts :: (ConvertibleToDouble n, Integral v) => Double -> (n, n, v, v) -> Double
fMeasureOnFragCounts beta (recallTotal, precisionTotal, nbExpected, nbGot) =
  let fragRecall = recallTotal /. nbExpected
      fragPrecision = precisionTotal /. nbGot
  in weightedHarmonicMean beta fragPrecision fragRecall
|
||||
|
||||
-- | Component-wise sum of two triples.
countFolder :: (Num n, Num v) => (n, v, v) -> (n, v, v) -> (n, v, v)
countFolder (accA, accB, accC) (a, b, c) =
  ( accA + a
  , accB + b
  , accC + c
  )
|
||||
|
||||
-- | Component-wise sum of two 4-tuples.
countFragFolder :: (Num n, Num v) => (n, n, v, v) -> (n, n, v, v) -> (n, n, v, v)
countFragFolder (accA, accB, accC, accD) (a, b, c, d) =
  ( accA + a
  , accB + b
  , accC + c
  , accD + d
  )
|
||||
|
||||
getCounts :: (a -> b -> Bool) -> ([a], [b]) -> (Int, Int, Int)
|
||||
getCounts matchingFun (expected, got) = (maxMatch matchingFun expected got,
|
||||
length expected,
|
||||
|
@ -1,8 +1,11 @@
|
||||
{-# LANGUAGE TypeFamilies #-}
|
||||
{-# LANGUAGE FlexibleContexts #-}
|
||||
|
||||
module GEval.Probability
|
||||
(Probability, getP, isProbability, mkProbability, probabilityOne, probabilityZero,
|
||||
EntityWithProbability, getProbability, getProbabilityAsDouble, getBareEntity, matchScore, BareEntity)
|
||||
Coverable, coveredScore, coveredScoreTotal, disjoint, allDisjoint, CoverableEntityWithProbability,
|
||||
EntityWithProbability, getProbability, getProbabilityAsDouble, getBareEntity, matchScore, BareEntity,
|
||||
precisionScoreTotal, recallScoreTotal)
|
||||
where
|
||||
|
||||
newtype Probability = P { getP :: Double }
|
||||
@ -30,3 +33,24 @@ class EntityWithProbability e where
|
||||
getProbabilityAsDouble = getP . getProbability
|
||||
getBareEntity :: e -> BareEntity e
|
||||
matchScore :: BareEntity e -> e -> Double
|
||||
|
||||
-- | Things for which one can ask how much one item is covered by another and
-- whether two items are disjoint.
class Coverable b where
  -- | Score of covering the first item by the second one.
  coveredScore :: b -> b -> Double
  -- | Sum of 'coveredScore' over every pair taken from the two lists
  -- (first list varying slowest).
  coveredScoreTotal :: [b] -> [b] -> Double
  coveredScoreTotal items otherItems = sum (coveredScore <$> items <*> otherItems)
  -- | Whether the two items are disjoint.
  disjoint :: b -> b -> Bool
|
||||
|
||||
|
||||
-- | True when every item is 'disjoint' from every item that follows it in the
-- list; trivially true for the empty list.
allDisjoint :: Coverable b => [b] -> Bool
allDisjoint items = case items of
  [] -> True
  (first : others) -> all (disjoint first) others && allDisjoint others
|
||||
|
||||
-- | Entities with probability whose bare entities are 'Coverable'. All methods
-- have defaults that strip the entity down with 'getBareEntity' and delegate to
-- the 'Coverable' methods of the bare entity.
class (Show e, EntityWithProbability e, Coverable (BareEntity e)) => CoverableEntityWithProbability e where
  -- | 'coveredScore' of the obtained entity (bare) with respect to an expected one.
  precisionScore :: e -> BareEntity e -> Double
  precisionScore = coveredScore . getBareEntity
  -- | 'coveredScoreTotal' of the obtained entities (bare) with respect to the expected ones.
  precisionScoreTotal :: [e] -> [BareEntity e] -> Double
  precisionScoreTotal gots = coveredScoreTotal (map getBareEntity gots)
  -- | 'coveredScore' of an expected entity with respect to the obtained one (bare).
  recallScore :: BareEntity e -> e -> Double
  recallScore expected = coveredScore expected . getBareEntity
  -- | 'coveredScoreTotal' of the expected entities with respect to the obtained ones (bare).
  recallScoreTotal :: [BareEntity e] -> [e] -> Double
  recallScoreTotal expecteds = coveredScoreTotal expecteds . map getBareEntity
|
||||
|
Loading…
Reference in New Issue
Block a user