Implement Probabilistic-Soft-F1
This commit is contained in:
parent
8393bec3ae
commit
ae27029f61
@ -87,7 +87,6 @@ library
|
||||
, vector-algorithms
|
||||
, aeson
|
||||
, aeson-pretty
|
||||
, numeric-tools
|
||||
, integration
|
||||
default-language: Haskell2010
|
||||
|
||||
|
@ -2,7 +2,6 @@ module Data.Statistics.Calibration
|
||||
(calibration, softCalibration) where
|
||||
|
||||
import Data.Statistics.Loess(loess)
|
||||
import Numeric.Tools.Integration
|
||||
import Numeric.Integration.TanhSinh
|
||||
import Data.List(minimum, maximum)
|
||||
import qualified Data.Vector.Unboxed as DVU
|
||||
|
@ -1,7 +1,9 @@
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
|
||||
module GEval.Annotation
|
||||
(parseAnnotations, Annotation(..), matchScore)
|
||||
(parseAnnotations, Annotation(..),
|
||||
parseObtainedAnnotations, ObtainedAnnotation(..),
|
||||
matchScore, getProbabilisticSoftCounts)
|
||||
where
|
||||
|
||||
import qualified Data.IntSet as IS
|
||||
@ -12,10 +14,32 @@ import Data.Attoparsec.Combinator
|
||||
import Control.Applicative
|
||||
import GEval.Common (sepByWhitespaces, (/.))
|
||||
import Data.Char
|
||||
import Data.List (find)
|
||||
import Data.Maybe (fromMaybe)
|
||||
|
||||
import GEval.PrecisionRecall(weightedMaxMatching)
|
||||
|
||||
data Annotation = Annotation T.Text IS.IntSet
|
||||
deriving (Eq, Show)
|
||||
|
||||
data ObtainedAnnotation = ObtainedAnnotation Annotation Double
|
||||
deriving (Eq, Show)
|
||||
|
||||
getProb :: ObtainedAnnotation -> Double
|
||||
getProb (ObtainedAnnotation _ p) = p
|
||||
|
||||
parseObtainedAnnotations :: T.Text -> Either String [ObtainedAnnotation]
|
||||
parseObtainedAnnotations t = parseOnly (obtainedAnnotationsParser <* endOfInput) t
|
||||
|
||||
obtainedAnnotationsParser :: Parser [ObtainedAnnotation]
|
||||
obtainedAnnotationsParser = sepByWhitespaces obtainedAnnotationParser
|
||||
|
||||
obtainedAnnotationParser :: Parser ObtainedAnnotation
|
||||
obtainedAnnotationParser = do
|
||||
annotation <- annotationParser
|
||||
mProb <- (string ":" *> double) <|> (return 1.0)
|
||||
return $ ObtainedAnnotation annotation mProb
|
||||
|
||||
parseAnnotations :: T.Text -> Either String [Annotation]
|
||||
parseAnnotations t = parseOnly (annotationsParser <* endOfInput) t
|
||||
|
||||
@ -38,9 +62,20 @@ intervalParser = do
|
||||
endIx <- (string "-" *> decimal <|> pure startIx)
|
||||
pure $ IS.fromList [startIx..endIx]
|
||||
|
||||
matchScore :: Annotation -> Annotation -> Double
|
||||
matchScore (Annotation labelA intSetA) (Annotation labelB intSetB)
|
||||
matchScore :: Annotation -> ObtainedAnnotation -> Double
|
||||
matchScore (Annotation labelA intSetA) (ObtainedAnnotation (Annotation labelB intSetB) _)
|
||||
| labelA == labelB = (intSetLength intersect) /. (intSetLength $ intSetA `IS.union` intSetB)
|
||||
| otherwise = 0.0
|
||||
where intSetLength = Prelude.length . IS.toList
|
||||
intersect = intSetA `IS.intersection` intSetB
|
||||
|
||||
getProbabilisticSoftCounts :: ([Annotation], [ObtainedAnnotation]) -> ([Double], [Double], Double, Int)
|
||||
getProbabilisticSoftCounts (expected, got) = (results, (map getProb got),
|
||||
gotMass,
|
||||
length expected)
|
||||
where gotMass = sum $ map (\(i, j) -> (matchScore (expected !! (i - 1)) (got !! (j - 1))) * (getProb (got !! (j - 1)))) matching
|
||||
results = map findResult [1..(length got)]
|
||||
findResult j = case find (\(i, j') -> j' == j) $ matching of
|
||||
Just (i, _) -> matchScore (expected !! (i - 1)) (got !! (j - 1))
|
||||
Nothing -> 0.0
|
||||
(matching, _) = weightedMaxMatching matchScore expected got
|
||||
|
@ -100,6 +100,8 @@ import qualified Data.Vector.Generic as VG
|
||||
|
||||
import Statistics.Correlation
|
||||
|
||||
import Data.Statistics.Calibration(softCalibration)
|
||||
|
||||
import Data.Proxy
|
||||
|
||||
import Data.Word
|
||||
@ -115,7 +117,7 @@ data Metric = RMSE | MSE | Pearson | Spearman | BLEU | GLEU | WER | Accuracy | C
|
||||
| LogLossHashed Word32 | CharMatch | MAP | LogLoss | Likelihood
|
||||
| BIOF1 | BIOF1Labels | TokenAccuracy | LikelihoodHashed Word32 | MAE | SMAPE | MultiLabelFMeasure Double
|
||||
| MultiLabelLogLoss | MultiLabelLikelihood
|
||||
| SoftFMeasure Double
|
||||
| SoftFMeasure Double | ProbabilisticSoftFMeasure Double
|
||||
deriving (Eq)
|
||||
|
||||
instance Show Metric where
|
||||
@ -131,6 +133,7 @@ instance Show Metric where
|
||||
show (FMeasure beta) = "F" ++ (show beta)
|
||||
show (MacroFMeasure beta) = "Macro-F" ++ (show beta)
|
||||
show (SoftFMeasure beta) = "Soft-F" ++ (show beta)
|
||||
show (ProbabilisticSoftFMeasure beta) = "Probabilistic-Soft-F" ++ (show beta)
|
||||
show NMI = "NMI"
|
||||
show (LogLossHashed nbOfBits) = "LogLossHashed" ++ (if
|
||||
nbOfBits == defaultLogLossHashedSize
|
||||
@ -180,6 +183,9 @@ instance Read Metric where
|
||||
readsPrec p ('S':'o':'f':'t':'-':'F':theRest) = case readsPrec p theRest of
|
||||
[(beta, theRest)] -> [(SoftFMeasure beta, theRest)]
|
||||
_ -> []
|
||||
readsPrec p ('P':'r':'o':'b':'a':'b':'i':'l':'i':'s':'t':'i':'c':'-':'S':'o':'f':'t':'-':'F':theRest) = case readsPrec p theRest of
|
||||
[(beta, theRest)] -> [(ProbabilisticSoftFMeasure beta, theRest)]
|
||||
_ -> []
|
||||
readsPrec p ('L':'o':'g':'L':'o':'s':'s':'H':'a':'s':'h':'e':'d':theRest) = case readsPrec p theRest of
|
||||
[(nbOfBits, theRest)] -> [(LogLossHashed nbOfBits, theRest)]
|
||||
_ -> [(LogLossHashed defaultLogLossHashedSize, theRest)]
|
||||
@ -216,6 +222,7 @@ getMetricOrdering ClippEU = TheHigherTheBetter
|
||||
getMetricOrdering (FMeasure _) = TheHigherTheBetter
|
||||
getMetricOrdering (MacroFMeasure _) = TheHigherTheBetter
|
||||
getMetricOrdering (SoftFMeasure _) = TheHigherTheBetter
|
||||
getMetricOrdering (ProbabilisticSoftFMeasure _) = TheHigherTheBetter
|
||||
getMetricOrdering NMI = TheHigherTheBetter
|
||||
getMetricOrdering (LogLossHashed _) = TheLowerTheBetter
|
||||
getMetricOrdering (LikelihoodHashed _) = TheHigherTheBetter
|
||||
@ -721,7 +728,7 @@ gevalCore' (MacroFMeasure beta) _ = gevalCoreWithoutInput (Right . Just . strip)
|
||||
$ M.keys expectedMap) / (fromIntegral $ Prelude.length $ M.keys expectedMap)
|
||||
|
||||
gevalCore' (SoftFMeasure beta) _ = gevalCoreWithoutInput parseAnnotations
|
||||
parseAnnotations
|
||||
parseObtainedAnnotations
|
||||
getSoftCounts
|
||||
countAgg
|
||||
(fMeasureOnCounts beta)
|
||||
@ -729,6 +736,19 @@ gevalCore' (SoftFMeasure beta) _ = gevalCoreWithoutInput parseAnnotations
|
||||
Prelude.length expected,
|
||||
Prelude.length got)
|
||||
|
||||
gevalCore' (ProbabilisticSoftFMeasure beta) _ = gevalCoreWithoutInput parseAnnotations
|
||||
parseObtainedAnnotations
|
||||
getProbabilisticSoftCounts
|
||||
probabilisticSoftAgg
|
||||
(fMeasureOnProbabilisticCounts beta)
|
||||
where probabilisticSoftAgg :: Monad m => ConduitM ([Double], [Double], Double, Int) o m ([Double], [Double], Double, Int)
|
||||
probabilisticSoftAgg = CC.foldl probabilisticSoftFolder ([], [], fromInteger 0, 0)
|
||||
probabilisticSoftFolder (r1, p1, g1, e1) (r2, p2, g2, e2) = (r1 ++ r2, p1 ++ p2, g1 + g2, e1 + e2)
|
||||
fMeasureOnProbabilisticCounts :: Double -> ([Double], [Double], Double, Int) -> Double
|
||||
fMeasureOnProbabilisticCounts beta (results, probs, got, nbExpected) = weightedHarmonicMean beta calibrationMeasure recall
|
||||
where calibrationMeasure = softCalibration results probs
|
||||
recall = got /. nbExpected
|
||||
|
||||
gevalCore' ClippEU _ = gevalCoreWithoutInput parseClippingSpecs parseClippings matchStep clippeuAgg finalStep
|
||||
where
|
||||
parseClippings = controlledParse lineClippingsParser
|
||||
|
@ -219,7 +219,7 @@ metricReader = many $ option auto -- actually `some` should be used inst
|
||||
( long "metric" -- --metric might be in the config.txt file...
|
||||
<> short 'm'
|
||||
<> metavar "METRIC"
|
||||
<> help "Metric to be used - RMSE, MSE, MAE, SMAPE, Pearson, Spearman, Accuracy, LogLoss, Likelihood, F-measure (specify as F1, F2, F0.25, etc.), macro F-measure (specify as Macro-F1, Macro-F2, Macro-F0.25, etc.), multi-label F-measure (specify as MultiLabel-F1, MultiLabel-F2, MultiLabel-F0.25, etc.), MultiLabel-Likelihood, MAP, BLEU, GLEU (\"Google GLEU\" not the grammar correction metric), WER, NMI, ClippEU, LogLossHashed, LikelihoodHashed, BIO-F1, BIO-F1-Labels, TokenAccuracy, soft F-measure (specify as Soft-F1, Soft-F2, Soft-F0.25) or CharMatch" )
|
||||
<> help "Metric to be used - RMSE, MSE, MAE, SMAPE, Pearson, Spearman, Accuracy, LogLoss, Likelihood, F-measure (specify as F1, F2, F0.25, etc.), macro F-measure (specify as Macro-F1, Macro-F2, Macro-F0.25, etc.), multi-label F-measure (specify as MultiLabel-F1, MultiLabel-F2, MultiLabel-F0.25, etc.), MultiLabel-Likelihood, MAP, BLEU, GLEU (\"Google GLEU\" not the grammar correction metric), WER, NMI, ClippEU, LogLossHashed, LikelihoodHashed, BIO-F1, BIO-F1-Labels, TokenAccuracy, soft F-measure (specify as Soft-F1, Soft-F2, Soft-F0.25), probabilistic soft F-measure (specify as Probabilistic-Soft-F1, Probabilistic-Soft-F2, Probabilistic-Soft-F0.25) or CharMatch" )
|
||||
|
||||
altMetricReader :: Parser (Maybe Metric)
|
||||
altMetricReader = optional $ option auto
|
||||
|
@ -2,10 +2,10 @@
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
|
||||
module GEval.PrecisionRecall(calculateMAPForOneResult,
|
||||
fMeasure, f1Measure, f2Measure, precision, recall,
|
||||
weightedHarmonicMean, fMeasure, f1Measure, f2Measure, precision, recall,
|
||||
fMeasureOnCounts, f1MeasureOnCounts, f2MeasureOnCounts, countFolder,
|
||||
precisionAndRecall, precisionAndRecallFromCounts,
|
||||
maxMatch, maxMatchOnOrdered, getCounts, weightedMaxMatch)
|
||||
maxMatch, maxMatchOnOrdered, getCounts, weightedMaxMatch, weightedMaxMatching)
|
||||
where
|
||||
|
||||
import GEval.Common
|
||||
@ -41,10 +41,13 @@ fMeasure :: Double -- ^ beta parameter
|
||||
-> [a] -- ^ the ground truth
|
||||
-> [b] -- ^ what we got
|
||||
-> Double -- ^ f-Measure
|
||||
fMeasure beta matchingFun expected got =
|
||||
(1 + betaSquared) * p * r `safeDoubleDiv` (betaSquared * p + r)
|
||||
fMeasure beta matchingFun expected got = weightedHarmonicMean beta p r
|
||||
where (p, r) = precisionAndRecall matchingFun expected got
|
||||
|
||||
weightedHarmonicMean :: Double -> Double -> Double -> Double
|
||||
weightedHarmonicMean beta x y =
|
||||
(1 + betaSquared) * x * y `safeDoubleDiv` (betaSquared * x + y)
|
||||
where betaSquared = beta ^ 2
|
||||
(p, r) = precisionAndRecall matchingFun expected got
|
||||
|
||||
f2MeasureOnCounts :: ConvertibleToDouble n => (n, Int, Int) -> Double
|
||||
f2MeasureOnCounts = fMeasureOnCounts 2.0
|
||||
@ -125,11 +128,14 @@ buildGraph matchFun expected got = (b, e, g)
|
||||
-- the weight are assumed to be between 0.0 and 1.0
|
||||
weightedMaxMatch :: (a -> b -> Double) -> [a] -> [b] -> Double
|
||||
weightedMaxMatch matchFun expected got = (fromIntegral $ length matching) - score
|
||||
where (matching, score) = hungarianMethodDouble complementWeightArray
|
||||
where (matching, score) = weightedMaxMatching matchFun expected got
|
||||
|
||||
weightedMaxMatching :: (a -> b -> Double) -> [a] -> [b] -> ([(Int, Int)], Double)
|
||||
weightedMaxMatching matchFun expected got = hungarianMethodDouble complementWeightArray
|
||||
-- unfortunately `hungarianMethodDouble` looks
|
||||
-- for minimal bipartite matching
|
||||
-- rather than the maximal one
|
||||
complementWeightArray = DAI.array ((1, 1), (m, n)) weightList
|
||||
where complementWeightArray = DAI.array ((1, 1), (m, n)) weightList
|
||||
m = length expected
|
||||
n = length got
|
||||
weightList = [((i, j), 1.0 - (matchFun x y)) | (i, x) <- zip [1..m] expected,
|
||||
|
13
test/Spec.hs
13
test/Spec.hs
@ -248,6 +248,11 @@ main = hspec $ do
|
||||
runGEvalTest "soft-f1-simple" `shouldReturnAlmost` 0.33333333333333
|
||||
it "perfect test" $ do
|
||||
runGEvalTest "soft-f1-perfect" `shouldReturnAlmost` 1.0
|
||||
describe "Probabilistic-Soft-F1" $ do
|
||||
it "simple test" $ do
|
||||
runGEvalTest "probabilistic-soft-f1-simple" `shouldReturnAlmost` 0.33333333333333
|
||||
it "simple test with perfect calibration" $ do
|
||||
runGEvalTest "probabilistic-soft-f1-calibrated" `shouldReturnAlmost` 0.88888888888
|
||||
describe "test edit-distance library" $ do
|
||||
it "for handling UTF8" $ do
|
||||
levenshteinDistance defaultEditCosts "źdźbło" "źd好bło" `shouldBe` 1
|
||||
@ -307,10 +312,10 @@ main = hspec $ do
|
||||
it "empty (just spaces)" $ do
|
||||
parseAnnotations " " `shouldBe` Right []
|
||||
it "match score" $ do
|
||||
matchScore (Annotation "x" (IS.fromList [3..6])) (Annotation "y" (IS.fromList [3..6])) `shouldBeAlmost` 0.0
|
||||
matchScore (Annotation "x" (IS.fromList [3..6])) (Annotation "x" (IS.fromList [3..6])) `shouldBeAlmost` 1.0
|
||||
matchScore (Annotation "x" (IS.fromList [123..140])) (Annotation "x" (IS.fromList [125..130])) `shouldBeAlmost` 0.33333
|
||||
matchScore (Annotation "x" (IS.fromList [3..4])) (Annotation "x" (IS.fromList [2..13])) `shouldBeAlmost` 0.1666666
|
||||
matchScore (Annotation "x" (IS.fromList [3..6])) (ObtainedAnnotation (Annotation "y" (IS.fromList [3..6])) 1.0) `shouldBeAlmost` 0.0
|
||||
matchScore (Annotation "x" (IS.fromList [3..6])) (ObtainedAnnotation (Annotation "x" (IS.fromList [3..6])) 1.0) `shouldBeAlmost` 1.0
|
||||
matchScore (Annotation "x" (IS.fromList [123..140])) (ObtainedAnnotation (Annotation "x" (IS.fromList [125..130])) 1.0) `shouldBeAlmost` 0.33333
|
||||
matchScore (Annotation "x" (IS.fromList [3..4])) (ObtainedAnnotation (Annotation "x" (IS.fromList [2..13])) 1.0) `shouldBeAlmost` 0.1666666
|
||||
describe "BIO format" $ do
|
||||
it "just parse" $ do
|
||||
let (Right r) = parseOnly (bioSequenceParser <* endOfInput) "O B-city/NEW_YORK I-city B-city/KALISZ I-city O B-name"
|
||||
|
@ -0,0 +1,5 @@
|
||||
foo:1:0.8
|
||||
foo:1:0.8
|
||||
foo:1:0.8
|
||||
bar:1:0.8
|
||||
foo:1:0.8
|
|
@ -0,0 +1 @@
|
||||
--metric Probabilistic-Soft-F1
|
@ -0,0 +1,5 @@
|
||||
foo:1
|
||||
|
||||
foo:1
|
||||
bar:1
|
||||
foo:1
|
|
@ -0,0 +1,3 @@
|
||||
foo:2-13:1.0
|
||||
bar:3-7:1.0
|
||||
baz:10-16:1.0 foo:123-140:1.0
|
|
@ -0,0 +1 @@
|
||||
--metric Probabilistic-Soft-F1
|
@ -0,0 +1,3 @@
|
||||
foo:3-4 baz:3-7 foo:12
|
||||
|
||||
baz:10-16 foo:125-130
|
|
Loading…
Reference in New Issue
Block a user