Implement Probabilistic-Soft-F1

This commit is contained in:
Filip Graliński 2019-03-12 09:14:50 +01:00 committed by Filip Gralinski
parent 8393bec3ae
commit ae27029f61
13 changed files with 105 additions and 23 deletions

View File

@ -87,7 +87,6 @@ library
, vector-algorithms , vector-algorithms
, aeson , aeson
, aeson-pretty , aeson-pretty
, numeric-tools
, integration , integration
default-language: Haskell2010 default-language: Haskell2010

View File

@ -2,7 +2,6 @@ module Data.Statistics.Calibration
(calibration, softCalibration) where (calibration, softCalibration) where
import Data.Statistics.Loess(loess) import Data.Statistics.Loess(loess)
import Numeric.Tools.Integration
import Numeric.Integration.TanhSinh import Numeric.Integration.TanhSinh
import Data.List(minimum, maximum) import Data.List(minimum, maximum)
import qualified Data.Vector.Unboxed as DVU import qualified Data.Vector.Unboxed as DVU

View File

@ -1,7 +1,9 @@
{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE OverloadedStrings #-}
module GEval.Annotation module GEval.Annotation
(parseAnnotations, Annotation(..), matchScore) (parseAnnotations, Annotation(..),
parseObtainedAnnotations, ObtainedAnnotation(..),
matchScore, getProbabilisticSoftCounts)
where where
import qualified Data.IntSet as IS import qualified Data.IntSet as IS
@ -12,10 +14,32 @@ import Data.Attoparsec.Combinator
import Control.Applicative import Control.Applicative
import GEval.Common (sepByWhitespaces, (/.)) import GEval.Common (sepByWhitespaces, (/.))
import Data.Char import Data.Char
import Data.List (find)
import Data.Maybe (fromMaybe)
import GEval.PrecisionRecall(weightedMaxMatching)
data Annotation = Annotation T.Text IS.IntSet data Annotation = Annotation T.Text IS.IntSet
deriving (Eq, Show) deriving (Eq, Show)
data ObtainedAnnotation = ObtainedAnnotation Annotation Double
deriving (Eq, Show)
getProb :: ObtainedAnnotation -> Double
getProb (ObtainedAnnotation _ p) = p
parseObtainedAnnotations :: T.Text -> Either String [ObtainedAnnotation]
parseObtainedAnnotations t = parseOnly (obtainedAnnotationsParser <* endOfInput) t
obtainedAnnotationsParser :: Parser [ObtainedAnnotation]
obtainedAnnotationsParser = sepByWhitespaces obtainedAnnotationParser
obtainedAnnotationParser :: Parser ObtainedAnnotation
obtainedAnnotationParser = do
annotation <- annotationParser
mProb <- (string ":" *> double) <|> (return 1.0)
return $ ObtainedAnnotation annotation mProb
parseAnnotations :: T.Text -> Either String [Annotation] parseAnnotations :: T.Text -> Either String [Annotation]
parseAnnotations t = parseOnly (annotationsParser <* endOfInput) t parseAnnotations t = parseOnly (annotationsParser <* endOfInput) t
@ -38,9 +62,20 @@ intervalParser = do
endIx <- (string "-" *> decimal <|> pure startIx) endIx <- (string "-" *> decimal <|> pure startIx)
pure $ IS.fromList [startIx..endIx] pure $ IS.fromList [startIx..endIx]
matchScore :: Annotation -> Annotation -> Double matchScore :: Annotation -> ObtainedAnnotation -> Double
matchScore (Annotation labelA intSetA) (Annotation labelB intSetB) matchScore (Annotation labelA intSetA) (ObtainedAnnotation (Annotation labelB intSetB) _)
| labelA == labelB = (intSetLength intersect) /. (intSetLength $ intSetA `IS.union` intSetB) | labelA == labelB = (intSetLength intersect) /. (intSetLength $ intSetA `IS.union` intSetB)
| otherwise = 0.0 | otherwise = 0.0
where intSetLength = Prelude.length . IS.toList where intSetLength = Prelude.length . IS.toList
intersect = intSetA `IS.intersection` intSetB intersect = intSetA `IS.intersection` intSetB
getProbabilisticSoftCounts :: ([Annotation], [ObtainedAnnotation]) -> ([Double], [Double], Double, Int)
getProbabilisticSoftCounts (expected, got) = (results, (map getProb got),
gotMass,
length expected)
where gotMass = sum $ map (\(i, j) -> (matchScore (expected !! (i - 1)) (got !! (j - 1))) * (getProb (got !! (j - 1)))) matching
results = map findResult [1..(length got)]
findResult j = case find (\(i, j') -> j' == j) $ matching of
Just (i, _) -> matchScore (expected !! (i - 1)) (got !! (j - 1))
Nothing -> 0.0
(matching, _) = weightedMaxMatching matchScore expected got

View File

@ -100,6 +100,8 @@ import qualified Data.Vector.Generic as VG
import Statistics.Correlation import Statistics.Correlation
import Data.Statistics.Calibration(softCalibration)
import Data.Proxy import Data.Proxy
import Data.Word import Data.Word
@ -115,7 +117,7 @@ data Metric = RMSE | MSE | Pearson | Spearman | BLEU | GLEU | WER | Accuracy | C
| LogLossHashed Word32 | CharMatch | MAP | LogLoss | Likelihood | LogLossHashed Word32 | CharMatch | MAP | LogLoss | Likelihood
| BIOF1 | BIOF1Labels | TokenAccuracy | LikelihoodHashed Word32 | MAE | SMAPE | MultiLabelFMeasure Double | BIOF1 | BIOF1Labels | TokenAccuracy | LikelihoodHashed Word32 | MAE | SMAPE | MultiLabelFMeasure Double
| MultiLabelLogLoss | MultiLabelLikelihood | MultiLabelLogLoss | MultiLabelLikelihood
| SoftFMeasure Double | SoftFMeasure Double | ProbabilisticSoftFMeasure Double
deriving (Eq) deriving (Eq)
instance Show Metric where instance Show Metric where
@ -131,6 +133,7 @@ instance Show Metric where
show (FMeasure beta) = "F" ++ (show beta) show (FMeasure beta) = "F" ++ (show beta)
show (MacroFMeasure beta) = "Macro-F" ++ (show beta) show (MacroFMeasure beta) = "Macro-F" ++ (show beta)
show (SoftFMeasure beta) = "Soft-F" ++ (show beta) show (SoftFMeasure beta) = "Soft-F" ++ (show beta)
show (ProbabilisticSoftFMeasure beta) = "Probabilistic-Soft-F" ++ (show beta)
show NMI = "NMI" show NMI = "NMI"
show (LogLossHashed nbOfBits) = "LogLossHashed" ++ (if show (LogLossHashed nbOfBits) = "LogLossHashed" ++ (if
nbOfBits == defaultLogLossHashedSize nbOfBits == defaultLogLossHashedSize
@ -180,6 +183,9 @@ instance Read Metric where
readsPrec p ('S':'o':'f':'t':'-':'F':theRest) = case readsPrec p theRest of readsPrec p ('S':'o':'f':'t':'-':'F':theRest) = case readsPrec p theRest of
[(beta, theRest)] -> [(SoftFMeasure beta, theRest)] [(beta, theRest)] -> [(SoftFMeasure beta, theRest)]
_ -> [] _ -> []
readsPrec p ('P':'r':'o':'b':'a':'b':'i':'l':'i':'s':'t':'i':'c':'-':'S':'o':'f':'t':'-':'F':theRest) = case readsPrec p theRest of
[(beta, theRest)] -> [(ProbabilisticSoftFMeasure beta, theRest)]
_ -> []
readsPrec p ('L':'o':'g':'L':'o':'s':'s':'H':'a':'s':'h':'e':'d':theRest) = case readsPrec p theRest of readsPrec p ('L':'o':'g':'L':'o':'s':'s':'H':'a':'s':'h':'e':'d':theRest) = case readsPrec p theRest of
[(nbOfBits, theRest)] -> [(LogLossHashed nbOfBits, theRest)] [(nbOfBits, theRest)] -> [(LogLossHashed nbOfBits, theRest)]
_ -> [(LogLossHashed defaultLogLossHashedSize, theRest)] _ -> [(LogLossHashed defaultLogLossHashedSize, theRest)]
@ -216,6 +222,7 @@ getMetricOrdering ClippEU = TheHigherTheBetter
getMetricOrdering (FMeasure _) = TheHigherTheBetter getMetricOrdering (FMeasure _) = TheHigherTheBetter
getMetricOrdering (MacroFMeasure _) = TheHigherTheBetter getMetricOrdering (MacroFMeasure _) = TheHigherTheBetter
getMetricOrdering (SoftFMeasure _) = TheHigherTheBetter getMetricOrdering (SoftFMeasure _) = TheHigherTheBetter
getMetricOrdering (ProbabilisticSoftFMeasure _) = TheHigherTheBetter
getMetricOrdering NMI = TheHigherTheBetter getMetricOrdering NMI = TheHigherTheBetter
getMetricOrdering (LogLossHashed _) = TheLowerTheBetter getMetricOrdering (LogLossHashed _) = TheLowerTheBetter
getMetricOrdering (LikelihoodHashed _) = TheHigherTheBetter getMetricOrdering (LikelihoodHashed _) = TheHigherTheBetter
@ -721,7 +728,7 @@ gevalCore' (MacroFMeasure beta) _ = gevalCoreWithoutInput (Right . Just . strip)
$ M.keys expectedMap) / (fromIntegral $ Prelude.length $ M.keys expectedMap) $ M.keys expectedMap) / (fromIntegral $ Prelude.length $ M.keys expectedMap)
gevalCore' (SoftFMeasure beta) _ = gevalCoreWithoutInput parseAnnotations gevalCore' (SoftFMeasure beta) _ = gevalCoreWithoutInput parseAnnotations
parseAnnotations parseObtainedAnnotations
getSoftCounts getSoftCounts
countAgg countAgg
(fMeasureOnCounts beta) (fMeasureOnCounts beta)
@ -729,6 +736,19 @@ gevalCore' (SoftFMeasure beta) _ = gevalCoreWithoutInput parseAnnotations
Prelude.length expected, Prelude.length expected,
Prelude.length got) Prelude.length got)
gevalCore' (ProbabilisticSoftFMeasure beta) _ = gevalCoreWithoutInput parseAnnotations
parseObtainedAnnotations
getProbabilisticSoftCounts
probabilisticSoftAgg
(fMeasureOnProbabilisticCounts beta)
where probabilisticSoftAgg :: Monad m => ConduitM ([Double], [Double], Double, Int) o m ([Double], [Double], Double, Int)
probabilisticSoftAgg = CC.foldl probabilisticSoftFolder ([], [], fromInteger 0, 0)
probabilisticSoftFolder (r1, p1, g1, e1) (r2, p2, g2, e2) = (r1 ++ r2, p1 ++ p2, g1 + g2, e1 + e2)
fMeasureOnProbabilisticCounts :: Double -> ([Double], [Double], Double, Int) -> Double
fMeasureOnProbabilisticCounts beta (results, probs, got, nbExpected) = weightedHarmonicMean beta calibrationMeasure recall
where calibrationMeasure = softCalibration results probs
recall = got /. nbExpected
gevalCore' ClippEU _ = gevalCoreWithoutInput parseClippingSpecs parseClippings matchStep clippeuAgg finalStep gevalCore' ClippEU _ = gevalCoreWithoutInput parseClippingSpecs parseClippings matchStep clippeuAgg finalStep
where where
parseClippings = controlledParse lineClippingsParser parseClippings = controlledParse lineClippingsParser

View File

@ -219,7 +219,7 @@ metricReader = many $ option auto -- actually `some` should be used inst
( long "metric" -- --metric might be in the config.txt file... ( long "metric" -- --metric might be in the config.txt file...
<> short 'm' <> short 'm'
<> metavar "METRIC" <> metavar "METRIC"
<> help "Metric to be used - RMSE, MSE, MAE, SMAPE, Pearson, Spearman, Accuracy, LogLoss, Likelihood, F-measure (specify as F1, F2, F0.25, etc.), macro F-measure (specify as Macro-F1, Macro-F2, Macro-F0.25, etc.), multi-label F-measure (specify as MultiLabel-F1, MultiLabel-F2, MultiLabel-F0.25, etc.), MultiLabel-Likelihood, MAP, BLEU, GLEU (\"Google GLEU\" not the grammar correction metric), WER, NMI, ClippEU, LogLossHashed, LikelihoodHashed, BIO-F1, BIO-F1-Labels, TokenAccuracy, soft F-measure (specify as Soft-F1, Soft-F2, Soft-F0.25) or CharMatch" ) <> help "Metric to be used - RMSE, MSE, MAE, SMAPE, Pearson, Spearman, Accuracy, LogLoss, Likelihood, F-measure (specify as F1, F2, F0.25, etc.), macro F-measure (specify as Macro-F1, Macro-F2, Macro-F0.25, etc.), multi-label F-measure (specify as MultiLabel-F1, MultiLabel-F2, MultiLabel-F0.25, etc.), MultiLabel-Likelihood, MAP, BLEU, GLEU (\"Google GLEU\" not the grammar correction metric), WER, NMI, ClippEU, LogLossHashed, LikelihoodHashed, BIO-F1, BIO-F1-Labels, TokenAccuracy, soft F-measure (specify as Soft-F1, Soft-F2, Soft-F0.25), probabilistic soft F-measure (specify as Probabilistic-Soft-F1, Probabilistic-Soft-F2, Probabilistic-Soft-F0.25) or CharMatch" )
altMetricReader :: Parser (Maybe Metric) altMetricReader :: Parser (Maybe Metric)
altMetricReader = optional $ option auto altMetricReader = optional $ option auto

View File

@ -2,10 +2,10 @@
{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE ScopedTypeVariables #-}
module GEval.PrecisionRecall(calculateMAPForOneResult, module GEval.PrecisionRecall(calculateMAPForOneResult,
fMeasure, f1Measure, f2Measure, precision, recall, weightedHarmonicMean, fMeasure, f1Measure, f2Measure, precision, recall,
fMeasureOnCounts, f1MeasureOnCounts, f2MeasureOnCounts, countFolder, fMeasureOnCounts, f1MeasureOnCounts, f2MeasureOnCounts, countFolder,
precisionAndRecall, precisionAndRecallFromCounts, precisionAndRecall, precisionAndRecallFromCounts,
maxMatch, maxMatchOnOrdered, getCounts, weightedMaxMatch) maxMatch, maxMatchOnOrdered, getCounts, weightedMaxMatch, weightedMaxMatching)
where where
import GEval.Common import GEval.Common
@ -41,10 +41,13 @@ fMeasure :: Double -- ^ beta parameter
-> [a] -- ^ the ground truth -> [a] -- ^ the ground truth
-> [b] -- ^ what we got -> [b] -- ^ what we got
-> Double -- ^ f-Measure -> Double -- ^ f-Measure
fMeasure beta matchingFun expected got = fMeasure beta matchingFun expected got = weightedHarmonicMean beta p r
(1 + betaSquared) * p * r `safeDoubleDiv` (betaSquared * p + r) where (p, r) = precisionAndRecall matchingFun expected got
weightedHarmonicMean :: Double -> Double -> Double -> Double
weightedHarmonicMean beta x y =
(1 + betaSquared) * x * y `safeDoubleDiv` (betaSquared * x + y)
where betaSquared = beta ^ 2 where betaSquared = beta ^ 2
(p, r) = precisionAndRecall matchingFun expected got
f2MeasureOnCounts :: ConvertibleToDouble n => (n, Int, Int) -> Double f2MeasureOnCounts :: ConvertibleToDouble n => (n, Int, Int) -> Double
f2MeasureOnCounts = fMeasureOnCounts 2.0 f2MeasureOnCounts = fMeasureOnCounts 2.0
@ -125,11 +128,14 @@ buildGraph matchFun expected got = (b, e, g)
-- the weight are assumed to be between 0.0 and 1.0 -- the weight are assumed to be between 0.0 and 1.0
weightedMaxMatch :: (a -> b -> Double) -> [a] -> [b] -> Double weightedMaxMatch :: (a -> b -> Double) -> [a] -> [b] -> Double
weightedMaxMatch matchFun expected got = (fromIntegral $ length matching) - score weightedMaxMatch matchFun expected got = (fromIntegral $ length matching) - score
where (matching, score) = hungarianMethodDouble complementWeightArray where (matching, score) = weightedMaxMatching matchFun expected got
-- unfortunately `hungarianMethodDouble` looks
-- for minimal bipartite matching weightedMaxMatching :: (a -> b -> Double) -> [a] -> [b] -> ([(Int, Int)], Double)
-- rather than the maximal one weightedMaxMatching matchFun expected got = hungarianMethodDouble complementWeightArray
complementWeightArray = DAI.array ((1, 1), (m, n)) weightList -- unfortunately `hungarianMethodDouble` looks
-- for minimal bipartite matching
-- rather than the maximal one
where complementWeightArray = DAI.array ((1, 1), (m, n)) weightList
m = length expected m = length expected
n = length got n = length got
weightList = [((i, j), 1.0 - (matchFun x y)) | (i, x) <- zip [1..m] expected, weightList = [((i, j), 1.0 - (matchFun x y)) | (i, x) <- zip [1..m] expected,

View File

@ -248,6 +248,11 @@ main = hspec $ do
runGEvalTest "soft-f1-simple" `shouldReturnAlmost` 0.33333333333333 runGEvalTest "soft-f1-simple" `shouldReturnAlmost` 0.33333333333333
it "perfect test" $ do it "perfect test" $ do
runGEvalTest "soft-f1-perfect" `shouldReturnAlmost` 1.0 runGEvalTest "soft-f1-perfect" `shouldReturnAlmost` 1.0
describe "Probabilistic-Soft-F1" $ do
it "simple test" $ do
runGEvalTest "probabilistic-soft-f1-simple" `shouldReturnAlmost` 0.33333333333333
it "simple test with perfect calibration" $ do
runGEvalTest "probabilistic-soft-f1-calibrated" `shouldReturnAlmost` 0.88888888888
describe "test edit-distance library" $ do describe "test edit-distance library" $ do
it "for handling UTF8" $ do it "for handling UTF8" $ do
levenshteinDistance defaultEditCosts "źdźbło" "źd好bło" `shouldBe` 1 levenshteinDistance defaultEditCosts "źdźbło" "źd好bło" `shouldBe` 1
@ -301,16 +306,16 @@ main = hspec $ do
describe "Annotation format" $ do describe "Annotation format" $ do
it "just parse" $ do it "just parse" $ do
parseAnnotations "foo:3,7-10 baz:4-6" `shouldBe` Right [Annotation "foo" (IS.fromList [3,7,8,9,10]), parseAnnotations "foo:3,7-10 baz:4-6" `shouldBe` Right [Annotation "foo" (IS.fromList [3,7,8,9,10]),
Annotation "baz" (IS.fromList [4,5,6])] Annotation "baz" (IS.fromList [4,5,6])]
it "empty" $ do it "empty" $ do
parseAnnotations "" `shouldBe` Right [] parseAnnotations "" `shouldBe` Right []
it "empty (just spaces)" $ do it "empty (just spaces)" $ do
parseAnnotations " " `shouldBe` Right [] parseAnnotations " " `shouldBe` Right []
it "match score" $ do it "match score" $ do
matchScore (Annotation "x" (IS.fromList [3..6])) (Annotation "y" (IS.fromList [3..6])) `shouldBeAlmost` 0.0 matchScore (Annotation "x" (IS.fromList [3..6])) (ObtainedAnnotation (Annotation "y" (IS.fromList [3..6])) 1.0) `shouldBeAlmost` 0.0
matchScore (Annotation "x" (IS.fromList [3..6])) (Annotation "x" (IS.fromList [3..6])) `shouldBeAlmost` 1.0 matchScore (Annotation "x" (IS.fromList [3..6])) (ObtainedAnnotation (Annotation "x" (IS.fromList [3..6])) 1.0) `shouldBeAlmost` 1.0
matchScore (Annotation "x" (IS.fromList [123..140])) (Annotation "x" (IS.fromList [125..130])) `shouldBeAlmost` 0.33333 matchScore (Annotation "x" (IS.fromList [123..140])) (ObtainedAnnotation (Annotation "x" (IS.fromList [125..130])) 1.0) `shouldBeAlmost` 0.33333
matchScore (Annotation "x" (IS.fromList [3..4])) (Annotation "x" (IS.fromList [2..13])) `shouldBeAlmost` 0.1666666 matchScore (Annotation "x" (IS.fromList [3..4])) (ObtainedAnnotation (Annotation "x" (IS.fromList [2..13])) 1.0) `shouldBeAlmost` 0.1666666
describe "BIO format" $ do describe "BIO format" $ do
it "just parse" $ do it "just parse" $ do
let (Right r) = parseOnly (bioSequenceParser <* endOfInput) "O B-city/NEW_YORK I-city B-city/KALISZ I-city O B-name" let (Right r) = parseOnly (bioSequenceParser <* endOfInput) "O B-city/NEW_YORK I-city B-city/KALISZ I-city O B-name"

View File

@ -0,0 +1,5 @@
foo:1:0.8
foo:1:0.8
foo:1:0.8
bar:1:0.8
foo:1:0.8
1 foo:1:0.8
2 foo:1:0.8
3 foo:1:0.8
4 bar:1:0.8
5 foo:1:0.8

View File

@ -0,0 +1 @@
--metric Probabilistic-Soft-F1

View File

@ -0,0 +1,5 @@
foo:1
foo:1
bar:1
foo:1
1 foo:1
2 foo:1
3 bar:1
4 foo:1

View File

@ -0,0 +1,3 @@
foo:2-13:1.0
bar:3-7:1.0
baz:10-16:1.0 foo:123-140:1.0
1 foo:2-13:1.0
2 bar:3-7:1.0
3 baz:10-16:1.0 foo:123-140:1.0

View File

@ -0,0 +1 @@
--metric Probabilistic-Soft-F1

View File

@ -0,0 +1,3 @@
foo:3-4 baz:3-7 foo:12
baz:10-16 foo:125-130
1 foo:3-4 baz:3-7 foo:12
2 baz:10-16 foo:125-130