Towards generalised soft F-Measure

This commit is contained in:
Filip Gralinski 2019-09-07 13:07:14 +02:00
parent e5efd90bc3
commit c011ba3962
5 changed files with 53 additions and 26 deletions

View File

@ -1,9 +1,10 @@
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-}
module GEval.Annotation
(parseAnnotations, Annotation(..),
parseObtainedAnnotations, ObtainedAnnotation(..),
matchScore, getProbabilisticSoftCounts, intSetParser)
matchScore, intSetParser)
where
import qualified Data.IntSet as IS
@ -13,8 +14,8 @@ import Data.Attoparsec.Text
import Data.Attoparsec.Combinator
import Control.Applicative
import GEval.Common (sepByWhitespaces, (/.))
import GEval.Probability
import Data.Char
import Data.List (find)
import Data.Maybe (fromMaybe)
import GEval.PrecisionRecall(weightedMaxMatching)
@ -25,8 +26,16 @@ data Annotation = Annotation T.Text IS.IntSet
data ObtainedAnnotation = ObtainedAnnotation Annotation Double
deriving (Eq, Show)
getProb :: ObtainedAnnotation -> Double
getProb (ObtainedAnnotation _ p) = p
instance EntityWithProbability ObtainedAnnotation where
type BareEntity ObtainedAnnotation = Annotation
getBareEntity (ObtainedAnnotation annotation _) = annotation
getProbabilityAsDouble (ObtainedAnnotation _ p) = p
matchScore (Annotation labelA intSetA) (ObtainedAnnotation (Annotation labelB intSetB) _)
| labelA == labelB = (intSetLength intersect) /. (intSetLength $ intSetA `IS.union` intSetB)
| otherwise = 0.0
where intSetLength = Prelude.length . IS.toList
intersect = intSetA `IS.intersection` intSetB
parseObtainedAnnotations :: T.Text -> Either String [ObtainedAnnotation]
parseObtainedAnnotations t = parseOnly (obtainedAnnotationsParser <* endOfInput) t
@ -61,21 +70,3 @@ intervalParser = do
startIx <- decimal
endIx <- (string "-" *> decimal <|> pure startIx)
pure $ IS.fromList [startIx..endIx]
matchScore :: Annotation -> ObtainedAnnotation -> Double
matchScore (Annotation labelA intSetA) (ObtainedAnnotation (Annotation labelB intSetB) _)
| labelA == labelB = (intSetLength intersect) /. (intSetLength $ intSetA `IS.union` intSetB)
| otherwise = 0.0
where intSetLength = Prelude.length . IS.toList
intersect = intSetA `IS.intersection` intSetB
getProbabilisticSoftCounts :: ([Annotation], [ObtainedAnnotation]) -> ([Double], [Double], Double, Int)
getProbabilisticSoftCounts (expected, got) = (results, (map getProb got),
gotMass,
length expected)
where gotMass = sum $ map (\(i, j) -> (matchScore (expected !! (i - 1)) (got !! (j - 1))) * (getProb (got !! (j - 1)))) matching
results = map findResult [1..(length got)]
findResult j = case find (\(i, j') -> j' == j) $ matching of
Just (i, _) -> matchScore (expected !! (i - 1)) (got !! (j - 1))
Nothing -> 0.0
(matching, _) = weightedMaxMatching matchScore expected got

View File

@ -631,7 +631,7 @@ gevalCore' (SoftFMeasure beta) _ = gevalCoreWithoutInput parseAnnotations
gevalCore' (ProbabilisticSoftFMeasure beta) _ = gevalCoreWithoutInput parseAnnotations
parseObtainedAnnotations
getProbabilisticSoftCounts
getProbabilisticCounts
probabilisticSoftAgg
(fMeasureOnProbabilisticCounts beta)
loessGraph

View File

@ -5,15 +5,17 @@ module GEval.PrecisionRecall(calculateMAPForOneResult,
weightedHarmonicMean, fMeasure, f1Measure, f2Measure, precision, recall,
fMeasureOnCounts, f1MeasureOnCounts, f2MeasureOnCounts, countFolder,
precisionAndRecall, precisionAndRecallFromCounts,
maxMatch, maxMatchOnOrdered, getCounts, weightedMaxMatch, weightedMaxMatching)
maxMatch, maxMatchOnOrdered, getCounts, weightedMaxMatch, weightedMaxMatching,
getProbabilisticCounts)
where
import GEval.Common
import GEval.Probability
import Data.Graph.Inductive
import Data.Graph.Inductive.Query.MaxFlow
import Data.List (nub, foldl')
import Data.List (find, foldl', nub)
import Data.Algorithm.Munkres
import qualified Data.Array.IArray as DAI
@ -140,3 +142,16 @@ weightedMaxMatching matchFun expected got = hungarianMethodDouble complementWeig
n = length got
weightList = [((i, j), 1.0 - (matchFun x y)) | (i, x) <- zip [1..m] expected,
(j, y) <- zip [1..n] got]
getProbabilisticCounts :: EntityWithProbability e => ([BareEntity e], [e]) -> ([Double], [Double], Double, Int)
getProbabilisticCounts (expected, got) = (results, (map getProbabilityAsDouble got),
gotMass,
length expected)
where gotMass = sum $ map (\(i, j) -> (matchScore (expected !! (i - 1)) (got !! (j - 1))) * (getProbabilityAsDouble (got !! (j - 1)))) matching
results = map findResult [1..(length got)]
findResult j = case find (\(i, j') -> j' == j) $ matching of
Just (i, _) -> matchScore (expected !! (i - 1)) (got !! (j - 1))
Nothing -> 0.0
(matching, _) = weightedMaxMatching matchScore expected got

View File

@ -1,4 +1,5 @@
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-}
module GEval.ProbList
(parseIntoProbList, selectByStandardThreshold, countLogLossOnProbList)
@ -9,12 +10,21 @@ import qualified Data.Text as T
import GEval.Common
import GEval.Probability
data ProbList = ProbList [WordWithProb]
deriving (Show)
data WordWithProb = WordWithProb T.Text Probability
deriving (Show)
instance EntityWithProbability WordWithProb where
type BareEntity WordWithProb = T.Text
getBareEntity (WordWithProb w _) = w
getProbability (WordWithProb _ p) = p
matchScore w1 (WordWithProb w2 _)
| w1 == w2 = 1.0
| otherwise = 0.0
parseIntoWordWithProb :: T.Text -> WordWithProb
parseIntoWordWithProb t =
-- If anything is wrong the whole word is treated as a label,

View File

@ -1,6 +1,8 @@
{-# LANGUAGE TypeFamilies #-}
module GEval.Probability
(Probability, getP, isProbability, mkProbability, probabilityOne, probabilityZero)
(Probability, getP, isProbability, mkProbability, probabilityOne, probabilityZero,
EntityWithProbability, getProbability, getProbabilityAsDouble, getBareEntity, matchScore, BareEntity)
where
newtype Probability = P { getP :: Double }
@ -19,3 +21,12 @@ probabilityOne = mkProbability 1.0
probabilityZero :: Probability
probabilityZero = mkProbability 0.0
class EntityWithProbability e where
type BareEntity e :: *
getProbability :: e -> Probability
getProbability = mkProbability . getProbabilityAsDouble
getProbabilityAsDouble :: e -> Double
getProbabilityAsDouble = getP . getProbability
getBareEntity :: e -> BareEntity e
matchScore :: BareEntity e -> e -> Double