Towards generalised soft F-Measure

This commit is contained in:
Filip Gralinski 2019-09-07 13:07:14 +02:00
parent e5efd90bc3
commit c011ba3962
5 changed files with 53 additions and 26 deletions

View File

@ -1,9 +1,10 @@
{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-}
module GEval.Annotation module GEval.Annotation
(parseAnnotations, Annotation(..), (parseAnnotations, Annotation(..),
parseObtainedAnnotations, ObtainedAnnotation(..), parseObtainedAnnotations, ObtainedAnnotation(..),
matchScore, getProbabilisticSoftCounts, intSetParser) matchScore, intSetParser)
where where
import qualified Data.IntSet as IS import qualified Data.IntSet as IS
@ -13,8 +14,8 @@ import Data.Attoparsec.Text
import Data.Attoparsec.Combinator import Data.Attoparsec.Combinator
import Control.Applicative import Control.Applicative
import GEval.Common (sepByWhitespaces, (/.)) import GEval.Common (sepByWhitespaces, (/.))
import GEval.Probability
import Data.Char import Data.Char
import Data.List (find)
import Data.Maybe (fromMaybe) import Data.Maybe (fromMaybe)
import GEval.PrecisionRecall(weightedMaxMatching) import GEval.PrecisionRecall(weightedMaxMatching)
@ -25,8 +26,16 @@ data Annotation = Annotation T.Text IS.IntSet
data ObtainedAnnotation = ObtainedAnnotation Annotation Double data ObtainedAnnotation = ObtainedAnnotation Annotation Double
deriving (Eq, Show) deriving (Eq, Show)
-- | Probability attached to an obtained annotation.
getProb :: ObtainedAnnotation -> Double
getProb (ObtainedAnnotation _ p) = p

-- | An 'ObtainedAnnotation' is an 'Annotation' paired with a probability
-- stored as a plain 'Double'.
instance EntityWithProbability ObtainedAnnotation where
  type BareEntity ObtainedAnnotation = Annotation
  getBareEntity (ObtainedAnnotation annotation _) = annotation
  getProbabilityAsDouble (ObtainedAnnotation _ p) = p
  -- Overlap score: size of the intersection of the two index sets over the
  -- size of their union when the labels agree, 0.0 when the labels differ.
  -- The probability of the obtained annotation is ignored here.
  matchScore (Annotation labelA intSetA) (ObtainedAnnotation (Annotation labelB intSetB) _)
    | labelA == labelB = IS.size intersect /. IS.size (intSetA `IS.union` intSetB)
    | otherwise = 0.0
    where
      -- IS.size counts the elements directly; no intermediate list is built.
      intersect = intSetA `IS.intersection` intSetB
-- | Parse a text into a list of obtained annotations. The parser is run with
-- 'parseOnly' and must reach end of input; the result is 'Left' with an
-- error message on failure.
parseObtainedAnnotations :: T.Text -> Either String [ObtainedAnnotation] parseObtainedAnnotations :: T.Text -> Either String [ObtainedAnnotation]
parseObtainedAnnotations t = parseOnly (obtainedAnnotationsParser <* endOfInput) t parseObtainedAnnotations t = parseOnly (obtainedAnnotationsParser <* endOfInput) t
@ -61,21 +70,3 @@ intervalParser = do
startIx <- decimal startIx <- decimal
endIx <- (string "-" *> decimal <|> pure startIx) endIx <- (string "-" *> decimal <|> pure startIx)
pure $ IS.fromList [startIx..endIx] pure $ IS.fromList [startIx..endIx]
-- | Overlap score between an expected annotation and an obtained one.
--
-- When the labels agree, the score is the size of the intersection of the
-- two index sets over the size of their union; when the labels differ it is
-- 0.0. The probability carried by the obtained annotation is ignored.
matchScore :: Annotation -> ObtainedAnnotation -> Double
matchScore (Annotation labelA intSetA) (ObtainedAnnotation (Annotation labelB intSetB) _)
  | labelA == labelB = IS.size common /. IS.size (intSetA `IS.union` intSetB)
  | otherwise = 0.0
  where
    -- IS.size counts the elements directly; no intermediate list is built.
    common = intSetA `IS.intersection` intSetB
-- | Compute the per-line counts used for the probabilistic soft F-measure.
--
-- The expected and obtained annotations are paired by 'weightedMaxMatching'
-- using 'matchScore' as the weight. The result is a 4-tuple of:
--
--   1. for every obtained annotation (in order), the match score of the
--      pair it was matched in, or 0.0 if it was not matched;
--   2. the probability of every obtained annotation (in order);
--   3. the sum over matched pairs of match score times probability;
--   4. the number of expected annotations.
getProbabilisticSoftCounts :: ([Annotation], [ObtainedAnnotation]) -> ([Double], [Double], Double, Int)
getProbabilisticSoftCounts (expected, got) =
  (results, map getProb got, gotMass, length expected)
  where
    -- The matching uses 1-based indices into both lists.
    (matching, _) = weightedMaxMatching matchScore expected got
    -- Score every matched pair exactly once, keyed by the index into got,
    -- so that gotMass and results share the same computation.
    matched =
      [ (j, (matchScore (expected !! (i - 1)) gotItem, getProb gotItem))
      | (i, j) <- matching
      , let gotItem = got !! (j - 1)
      ]
    gotMass = sum [score * prob | (_, (score, prob)) <- matched]
    results = [maybe 0.0 fst (lookup j matched) | j <- [1 .. length got]]

View File

@ -631,7 +631,7 @@ gevalCore' (SoftFMeasure beta) _ = gevalCoreWithoutInput parseAnnotations
gevalCore' (ProbabilisticSoftFMeasure beta) _ = gevalCoreWithoutInput parseAnnotations gevalCore' (ProbabilisticSoftFMeasure beta) _ = gevalCoreWithoutInput parseAnnotations
parseObtainedAnnotations parseObtainedAnnotations
getProbabilisticSoftCounts getProbabilisticCounts
probabilisticSoftAgg probabilisticSoftAgg
(fMeasureOnProbabilisticCounts beta) (fMeasureOnProbabilisticCounts beta)
loessGraph loessGraph

View File

@ -5,15 +5,17 @@ module GEval.PrecisionRecall(calculateMAPForOneResult,
weightedHarmonicMean, fMeasure, f1Measure, f2Measure, precision, recall, weightedHarmonicMean, fMeasure, f1Measure, f2Measure, precision, recall,
fMeasureOnCounts, f1MeasureOnCounts, f2MeasureOnCounts, countFolder, fMeasureOnCounts, f1MeasureOnCounts, f2MeasureOnCounts, countFolder,
precisionAndRecall, precisionAndRecallFromCounts, precisionAndRecall, precisionAndRecallFromCounts,
maxMatch, maxMatchOnOrdered, getCounts, weightedMaxMatch, weightedMaxMatching) maxMatch, maxMatchOnOrdered, getCounts, weightedMaxMatch, weightedMaxMatching,
getProbabilisticCounts)
where where
import GEval.Common import GEval.Common
import GEval.Probability
import Data.Graph.Inductive import Data.Graph.Inductive
import Data.Graph.Inductive.Query.MaxFlow import Data.Graph.Inductive.Query.MaxFlow
import Data.List (nub, foldl') import Data.List (find, foldl', nub)
import Data.Algorithm.Munkres import Data.Algorithm.Munkres
import qualified Data.Array.IArray as DAI import qualified Data.Array.IArray as DAI
@ -140,3 +142,16 @@ weightedMaxMatching matchFun expected got = hungarianMethodDouble complementWeig
n = length got n = length got
weightList = [((i, j), 1.0 - (matchFun x y)) | (i, x) <- zip [1..m] expected, weightList = [((i, j), 1.0 - (matchFun x y)) | (i, x) <- zip [1..m] expected,
(j, y) <- zip [1..n] got] (j, y) <- zip [1..n] got]
-- | Compute the per-line counts for probabilistic (soft) F-measures over any
-- 'EntityWithProbability'.
--
-- The expected bare entities and the obtained entities are paired by
-- 'weightedMaxMatching' using 'matchScore' as the weight. The result is a
-- 4-tuple of:
--
--   1. for every obtained entity (in order), the match score of the pair it
--      was matched in, or 0.0 if it was not matched;
--   2. the probability of every obtained entity (in order);
--   3. the sum over matched pairs of match score times probability;
--   4. the number of expected entities.
getProbabilisticCounts :: EntityWithProbability e => ([BareEntity e], [e]) -> ([Double], [Double], Double, Int)
getProbabilisticCounts (expected, got) =
  (results, map getProbabilityAsDouble got, gotMass, length expected)
  where
    -- The matching uses 1-based indices into both lists.
    (matching, _) = weightedMaxMatching matchScore expected got

    -- 1-based arrays give constant-time access in place of repeated (!!).
    toArray :: [a] -> DAI.Array Int a
    toArray xs = DAI.listArray (1, length xs) xs
    expectedArr = toArray expected
    gotArr = toArray got

    -- Score every matched pair exactly once, keyed by the index into got,
    -- so that gotMass and results share the same computation.
    matched =
      [ (j, (matchScore (expectedArr DAI.! i) gotItem, getProbabilityAsDouble gotItem))
      | (i, j) <- matching
      , let gotItem = gotArr DAI.! j
      ]
    gotMass = sum [score * prob | (_, (score, prob)) <- matched]
    results = [maybe 0.0 fst (lookup j matched) | j <- [1 .. length got]]

View File

@ -1,4 +1,5 @@
{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-}
module GEval.ProbList module GEval.ProbList
(parseIntoProbList, selectByStandardThreshold, countLogLossOnProbList) (parseIntoProbList, selectByStandardThreshold, countLogLossOnProbList)
@ -9,12 +10,21 @@ import qualified Data.Text as T
import GEval.Common import GEval.Common
import GEval.Probability import GEval.Probability
data ProbList = ProbList [WordWithProb] data ProbList = ProbList [WordWithProb]
deriving (Show) deriving (Show)
data WordWithProb = WordWithProb T.Text Probability data WordWithProb = WordWithProb T.Text Probability
deriving (Show) deriving (Show)
-- | A 'WordWithProb' is a word ('T.Text') paired with a 'Probability'; the
-- bare entity is the word itself.
instance EntityWithProbability WordWithProb where
  type BareEntity WordWithProb = T.Text
  getBareEntity (WordWithProb word _) = word
  getProbability (WordWithProb _ prob) = prob
  -- Exact matching: 1.0 when the expected word equals the obtained word,
  -- 0.0 otherwise. The probability is not taken into account.
  matchScore expectedWord (WordWithProb gotWord _) =
    if expectedWord == gotWord then 1.0 else 0.0
parseIntoWordWithProb :: T.Text -> WordWithProb parseIntoWordWithProb :: T.Text -> WordWithProb
parseIntoWordWithProb t = parseIntoWordWithProb t =
-- If anything is wrong the whole word is treated as a label, -- If anything is wrong the whole word is treated as a label,

View File

@ -1,6 +1,8 @@
{-# LANGUAGE TypeFamilies #-}
module GEval.Probability module GEval.Probability
(Probability, getP, isProbability, mkProbability, probabilityOne, probabilityZero) (Probability, getP, isProbability, mkProbability, probabilityOne, probabilityZero,
EntityWithProbability, getProbability, getProbabilityAsDouble, getBareEntity, matchScore, BareEntity)
where where
newtype Probability = P { getP :: Double } newtype Probability = P { getP :: Double }
@ -19,3 +21,12 @@ probabilityOne = mkProbability 1.0
probabilityZero :: Probability probabilityZero :: Probability
probabilityZero = mkProbability 0.0 probabilityZero = mkProbability 0.0
-- | Entities that carry a probability and can be scored against a "bare"
-- (probability-free) counterpart.
--
-- 'getProbability' and 'getProbabilityAsDouble' are defined in terms of each
-- other by default, so an instance must provide at least one of them;
-- otherwise both would loop forever. The MINIMAL pragma makes the compiler
-- warn about such incomplete instances.
class EntityWithProbability e where
  {-# MINIMAL (getProbability | getProbabilityAsDouble), getBareEntity, matchScore #-}

  -- | The entity type with the probability stripped off.
  type BareEntity e :: *

  -- | The probability of the entity as a 'Probability' value.
  getProbability :: e -> Probability
  getProbability = mkProbability . getProbabilityAsDouble

  -- | The probability of the entity as a plain 'Double'.
  getProbabilityAsDouble :: e -> Double
  getProbabilityAsDouble = getP . getProbability

  -- | Strip the probability, leaving the bare entity.
  getBareEntity :: e -> BareEntity e

  -- | Score how well a bare entity (first argument) matches an entity with
  -- probability (second argument).
  matchScore :: BareEntity e -> e -> Double