Towards generalised soft F-Measure
This commit is contained in:
parent
e5efd90bc3
commit
c011ba3962
@ -1,9 +1,10 @@
|
|||||||
{-# LANGUAGE OverloadedStrings #-}
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
|
{-# LANGUAGE TypeFamilies #-}
|
||||||
|
|
||||||
module GEval.Annotation
|
module GEval.Annotation
|
||||||
(parseAnnotations, Annotation(..),
|
(parseAnnotations, Annotation(..),
|
||||||
parseObtainedAnnotations, ObtainedAnnotation(..),
|
parseObtainedAnnotations, ObtainedAnnotation(..),
|
||||||
matchScore, getProbabilisticSoftCounts, intSetParser)
|
matchScore, intSetParser)
|
||||||
where
|
where
|
||||||
|
|
||||||
import qualified Data.IntSet as IS
|
import qualified Data.IntSet as IS
|
||||||
@ -13,8 +14,8 @@ import Data.Attoparsec.Text
|
|||||||
import Data.Attoparsec.Combinator
|
import Data.Attoparsec.Combinator
|
||||||
import Control.Applicative
|
import Control.Applicative
|
||||||
import GEval.Common (sepByWhitespaces, (/.))
|
import GEval.Common (sepByWhitespaces, (/.))
|
||||||
|
import GEval.Probability
|
||||||
import Data.Char
|
import Data.Char
|
||||||
import Data.List (find)
|
|
||||||
import Data.Maybe (fromMaybe)
|
import Data.Maybe (fromMaybe)
|
||||||
|
|
||||||
import GEval.PrecisionRecall(weightedMaxMatching)
|
import GEval.PrecisionRecall(weightedMaxMatching)
|
||||||
@ -25,8 +26,16 @@ data Annotation = Annotation T.Text IS.IntSet
|
|||||||
data ObtainedAnnotation = ObtainedAnnotation Annotation Double
|
data ObtainedAnnotation = ObtainedAnnotation Annotation Double
|
||||||
deriving (Eq, Show)
|
deriving (Eq, Show)
|
||||||
|
|
||||||
getProb :: ObtainedAnnotation -> Double
|
instance EntityWithProbability ObtainedAnnotation where
|
||||||
getProb (ObtainedAnnotation _ p) = p
|
type BareEntity ObtainedAnnotation = Annotation
|
||||||
|
getBareEntity (ObtainedAnnotation annotation _) = annotation
|
||||||
|
getProbabilityAsDouble (ObtainedAnnotation _ p) = p
|
||||||
|
matchScore (Annotation labelA intSetA) (ObtainedAnnotation (Annotation labelB intSetB) _)
|
||||||
|
| labelA == labelB = (intSetLength intersect) /. (intSetLength $ intSetA `IS.union` intSetB)
|
||||||
|
| otherwise = 0.0
|
||||||
|
where intSetLength = Prelude.length . IS.toList
|
||||||
|
intersect = intSetA `IS.intersection` intSetB
|
||||||
|
|
||||||
|
|
||||||
parseObtainedAnnotations :: T.Text -> Either String [ObtainedAnnotation]
|
parseObtainedAnnotations :: T.Text -> Either String [ObtainedAnnotation]
|
||||||
parseObtainedAnnotations t = parseOnly (obtainedAnnotationsParser <* endOfInput) t
|
parseObtainedAnnotations t = parseOnly (obtainedAnnotationsParser <* endOfInput) t
|
||||||
@ -61,21 +70,3 @@ intervalParser = do
|
|||||||
startIx <- decimal
|
startIx <- decimal
|
||||||
endIx <- (string "-" *> decimal <|> pure startIx)
|
endIx <- (string "-" *> decimal <|> pure startIx)
|
||||||
pure $ IS.fromList [startIx..endIx]
|
pure $ IS.fromList [startIx..endIx]
|
||||||
|
|
||||||
matchScore :: Annotation -> ObtainedAnnotation -> Double
|
|
||||||
matchScore (Annotation labelA intSetA) (ObtainedAnnotation (Annotation labelB intSetB) _)
|
|
||||||
| labelA == labelB = (intSetLength intersect) /. (intSetLength $ intSetA `IS.union` intSetB)
|
|
||||||
| otherwise = 0.0
|
|
||||||
where intSetLength = Prelude.length . IS.toList
|
|
||||||
intersect = intSetA `IS.intersection` intSetB
|
|
||||||
|
|
||||||
getProbabilisticSoftCounts :: ([Annotation], [ObtainedAnnotation]) -> ([Double], [Double], Double, Int)
|
|
||||||
getProbabilisticSoftCounts (expected, got) = (results, (map getProb got),
|
|
||||||
gotMass,
|
|
||||||
length expected)
|
|
||||||
where gotMass = sum $ map (\(i, j) -> (matchScore (expected !! (i - 1)) (got !! (j - 1))) * (getProb (got !! (j - 1)))) matching
|
|
||||||
results = map findResult [1..(length got)]
|
|
||||||
findResult j = case find (\(i, j') -> j' == j) $ matching of
|
|
||||||
Just (i, _) -> matchScore (expected !! (i - 1)) (got !! (j - 1))
|
|
||||||
Nothing -> 0.0
|
|
||||||
(matching, _) = weightedMaxMatching matchScore expected got
|
|
||||||
|
@ -631,7 +631,7 @@ gevalCore' (SoftFMeasure beta) _ = gevalCoreWithoutInput parseAnnotations
|
|||||||
|
|
||||||
gevalCore' (ProbabilisticSoftFMeasure beta) _ = gevalCoreWithoutInput parseAnnotations
|
gevalCore' (ProbabilisticSoftFMeasure beta) _ = gevalCoreWithoutInput parseAnnotations
|
||||||
parseObtainedAnnotations
|
parseObtainedAnnotations
|
||||||
getProbabilisticSoftCounts
|
getProbabilisticCounts
|
||||||
probabilisticSoftAgg
|
probabilisticSoftAgg
|
||||||
(fMeasureOnProbabilisticCounts beta)
|
(fMeasureOnProbabilisticCounts beta)
|
||||||
loessGraph
|
loessGraph
|
||||||
|
@ -5,15 +5,17 @@ module GEval.PrecisionRecall(calculateMAPForOneResult,
|
|||||||
weightedHarmonicMean, fMeasure, f1Measure, f2Measure, precision, recall,
|
weightedHarmonicMean, fMeasure, f1Measure, f2Measure, precision, recall,
|
||||||
fMeasureOnCounts, f1MeasureOnCounts, f2MeasureOnCounts, countFolder,
|
fMeasureOnCounts, f1MeasureOnCounts, f2MeasureOnCounts, countFolder,
|
||||||
precisionAndRecall, precisionAndRecallFromCounts,
|
precisionAndRecall, precisionAndRecallFromCounts,
|
||||||
maxMatch, maxMatchOnOrdered, getCounts, weightedMaxMatch, weightedMaxMatching)
|
maxMatch, maxMatchOnOrdered, getCounts, weightedMaxMatch, weightedMaxMatching,
|
||||||
|
getProbabilisticCounts)
|
||||||
where
|
where
|
||||||
|
|
||||||
import GEval.Common
|
import GEval.Common
|
||||||
|
import GEval.Probability
|
||||||
|
|
||||||
import Data.Graph.Inductive
|
import Data.Graph.Inductive
|
||||||
import Data.Graph.Inductive.Query.MaxFlow
|
import Data.Graph.Inductive.Query.MaxFlow
|
||||||
|
|
||||||
import Data.List (nub, foldl')
|
import Data.List (find, foldl', nub)
|
||||||
|
|
||||||
import Data.Algorithm.Munkres
|
import Data.Algorithm.Munkres
|
||||||
import qualified Data.Array.IArray as DAI
|
import qualified Data.Array.IArray as DAI
|
||||||
@ -140,3 +142,16 @@ weightedMaxMatching matchFun expected got = hungarianMethodDouble complementWeig
|
|||||||
n = length got
|
n = length got
|
||||||
weightList = [((i, j), 1.0 - (matchFun x y)) | (i, x) <- zip [1..m] expected,
|
weightList = [((i, j), 1.0 - (matchFun x y)) | (i, x) <- zip [1..m] expected,
|
||||||
(j, y) <- zip [1..n] got]
|
(j, y) <- zip [1..n] got]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
getProbabilisticCounts :: EntityWithProbability e => ([BareEntity e], [e]) -> ([Double], [Double], Double, Int)
|
||||||
|
getProbabilisticCounts (expected, got) = (results, (map getProbabilityAsDouble got),
|
||||||
|
gotMass,
|
||||||
|
length expected)
|
||||||
|
where gotMass = sum $ map (\(i, j) -> (matchScore (expected !! (i - 1)) (got !! (j - 1))) * (getProbabilityAsDouble (got !! (j - 1)))) matching
|
||||||
|
results = map findResult [1..(length got)]
|
||||||
|
findResult j = case find (\(i, j') -> j' == j) $ matching of
|
||||||
|
Just (i, _) -> matchScore (expected !! (i - 1)) (got !! (j - 1))
|
||||||
|
Nothing -> 0.0
|
||||||
|
(matching, _) = weightedMaxMatching matchScore expected got
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
{-# LANGUAGE OverloadedStrings #-}
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
|
{-# LANGUAGE TypeFamilies #-}
|
||||||
|
|
||||||
module GEval.ProbList
|
module GEval.ProbList
|
||||||
(parseIntoProbList, selectByStandardThreshold, countLogLossOnProbList)
|
(parseIntoProbList, selectByStandardThreshold, countLogLossOnProbList)
|
||||||
@ -9,12 +10,21 @@ import qualified Data.Text as T
|
|||||||
import GEval.Common
|
import GEval.Common
|
||||||
import GEval.Probability
|
import GEval.Probability
|
||||||
|
|
||||||
|
|
||||||
data ProbList = ProbList [WordWithProb]
|
data ProbList = ProbList [WordWithProb]
|
||||||
deriving (Show)
|
deriving (Show)
|
||||||
|
|
||||||
data WordWithProb = WordWithProb T.Text Probability
|
data WordWithProb = WordWithProb T.Text Probability
|
||||||
deriving (Show)
|
deriving (Show)
|
||||||
|
|
||||||
|
instance EntityWithProbability WordWithProb where
|
||||||
|
type BareEntity WordWithProb = T.Text
|
||||||
|
getBareEntity (WordWithProb w _) = w
|
||||||
|
getProbability (WordWithProb _ p) = p
|
||||||
|
matchScore w1 (WordWithProb w2 _)
|
||||||
|
| w1 == w2 = 1.0
|
||||||
|
| otherwise = 0.0
|
||||||
|
|
||||||
parseIntoWordWithProb :: T.Text -> WordWithProb
|
parseIntoWordWithProb :: T.Text -> WordWithProb
|
||||||
parseIntoWordWithProb t =
|
parseIntoWordWithProb t =
|
||||||
-- If anything is wrong the whole word is treated as a label,
|
-- If anything is wrong the whole word is treated as a label,
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
|
{-# LANGUAGE TypeFamilies #-}
|
||||||
|
|
||||||
module GEval.Probability
|
module GEval.Probability
|
||||||
(Probability, getP, isProbability, mkProbability, probabilityOne, probabilityZero)
|
(Probability, getP, isProbability, mkProbability, probabilityOne, probabilityZero,
|
||||||
|
EntityWithProbability, getProbability, getProbabilityAsDouble, getBareEntity, matchScore, BareEntity)
|
||||||
where
|
where
|
||||||
|
|
||||||
newtype Probability = P { getP :: Double }
|
newtype Probability = P { getP :: Double }
|
||||||
@ -19,3 +21,12 @@ probabilityOne = mkProbability 1.0
|
|||||||
|
|
||||||
probabilityZero :: Probability
|
probabilityZero :: Probability
|
||||||
probabilityZero = mkProbability 0.0
|
probabilityZero = mkProbability 0.0
|
||||||
|
|
||||||
|
class EntityWithProbability e where
|
||||||
|
type BareEntity e :: *
|
||||||
|
getProbability :: e -> Probability
|
||||||
|
getProbability = mkProbability . getProbabilityAsDouble
|
||||||
|
getProbabilityAsDouble :: e -> Double
|
||||||
|
getProbabilityAsDouble = getP . getProbability
|
||||||
|
getBareEntity :: e -> BareEntity e
|
||||||
|
matchScore :: BareEntity e -> e -> Double
|
||||||
|
Loading…
Reference in New Issue
Block a user