# Patch summary (leading commentary; everything from the first "diff --git" header onward is unchanged).
#
# Purpose: move the probabilistic soft-count computation out of GEval.Annotation and make it
# generic over a new type class, EntityWithProbability, declared in GEval.Probability.
#
# src/GEval/Annotation.hs
#   * enables TypeFamilies and imports GEval.Probability; drops the `Data.List (find)` import.
#   * removes getProbabilisticSoftCounts from the export list (matchScore stays exported).
#   * replaces the getProb helper with an `EntityWithProbability ObtainedAnnotation` instance whose
#     BareEntity is Annotation; matchScore becomes a method of that instance with the same body as
#     the removed top-level function: intersection size divided by union size (via `/.`) when the
#     labels are equal, 0.0 otherwise.
#   * deletes the top-level matchScore and getProbabilisticSoftCounts definitions.
#
# src/GEval/Core.hs
#   * the ProbabilisticSoftFMeasure case now passes getProbabilisticCounts instead of
#     getProbabilisticSoftCounts.
#
# src/GEval/PrecisionRecall.hs
#   * imports GEval.Probability, adds `find` to the Data.List import, exports getProbabilisticCounts.
#   * adds getProbabilisticCounts :: EntityWithProbability e => ([BareEntity e], [e]) -> ([Double], [Double], Double, Int).
#     The tuple holds: one match score per obtained item (0.0 when the item is not in the matching),
#     the obtained items' probabilities, the sum of matchScore * probability over the matched pairs,
#     and `length expected`. The matching comes from `weightedMaxMatching matchScore expected got`
#     and its indices are used 1-based (`!! (i - 1)`, `!! (j - 1)`).
#
# src/GEval/ProbList.hs
#   * enables TypeFamilies and adds an `EntityWithProbability WordWithProb` instance with
#     BareEntity = T.Text; matchScore is 1.0 when the two words are equal and 0.0 otherwise.
#
# src/GEval/Probability.hs
#   * enables TypeFamilies; declares and exports the EntityWithProbability class with the
#     associated type BareEntity and the methods getProbability, getProbabilityAsDouble,
#     getBareEntity and matchScore.
#
# NOTE(review): getProbability and getProbabilityAsDouble each default to the other
#   (`mkProbability . getProbabilityAsDouble` / `getP . getProbability`), so an instance has to
#   define at least one of them. No MINIMAL pragma is visible in this patch -- consider adding one.
# NOTE(review): the Annotation matchScore divides by the size of the union; what `/.` returns when
#   both index sets are empty is not visible here -- TODO confirm.
# NOTE(review): GEval.Annotation still imports weightedMaxMatching in a context line although the
#   only use visible in this patch is removed -- confirm whether the import is still needed.
# NOTE(review): in this copy the patch's line breaks and indentation appear to have been collapsed;
#   it presumably has to be regenerated from the repository before it can be applied.
diff --git a/src/GEval/Annotation.hs b/src/GEval/Annotation.hs index 2115833..abc8b59 100644 --- a/src/GEval/Annotation.hs +++ b/src/GEval/Annotation.hs @@ -1,9 +1,10 @@ {-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE TypeFamilies #-} module GEval.Annotation (parseAnnotations, Annotation(..), parseObtainedAnnotations, ObtainedAnnotation(..), - matchScore, getProbabilisticSoftCounts, intSetParser) + matchScore, intSetParser) where import qualified Data.IntSet as IS @@ -13,8 +14,8 @@ import Data.Attoparsec.Text import Data.Attoparsec.Combinator import Control.Applicative import GEval.Common (sepByWhitespaces, (/.)) +import GEval.Probability import Data.Char -import Data.List (find) import Data.Maybe (fromMaybe) import GEval.PrecisionRecall(weightedMaxMatching) @@ -25,8 +26,16 @@ data Annotation = Annotation T.Text IS.IntSet data ObtainedAnnotation = ObtainedAnnotation Annotation Double deriving (Eq, Show) -getProb :: ObtainedAnnotation -> Double -getProb (ObtainedAnnotation _ p) = p +instance EntityWithProbability ObtainedAnnotation where + type BareEntity ObtainedAnnotation = Annotation + getBareEntity (ObtainedAnnotation annotation _) = annotation + getProbabilityAsDouble (ObtainedAnnotation _ p) = p + matchScore (Annotation labelA intSetA) (ObtainedAnnotation (Annotation labelB intSetB) _) + | labelA == labelB = (intSetLength intersect) /. (intSetLength $ intSetA `IS.union` intSetB) + | otherwise = 0.0 + where intSetLength = Prelude.length . 
IS.toList + intersect = intSetA `IS.intersection` intSetB + parseObtainedAnnotations :: T.Text -> Either String [ObtainedAnnotation] parseObtainedAnnotations t = parseOnly (obtainedAnnotationsParser <* endOfInput) t @@ -61,21 +70,3 @@ intervalParser = do startIx <- decimal endIx <- (string "-" *> decimal <|> pure startIx) pure $ IS.fromList [startIx..endIx] - -matchScore :: Annotation -> ObtainedAnnotation -> Double -matchScore (Annotation labelA intSetA) (ObtainedAnnotation (Annotation labelB intSetB) _) - | labelA == labelB = (intSetLength intersect) /. (intSetLength $ intSetA `IS.union` intSetB) - | otherwise = 0.0 - where intSetLength = Prelude.length . IS.toList - intersect = intSetA `IS.intersection` intSetB - -getProbabilisticSoftCounts :: ([Annotation], [ObtainedAnnotation]) -> ([Double], [Double], Double, Int) -getProbabilisticSoftCounts (expected, got) = (results, (map getProb got), - gotMass, - length expected) - where gotMass = sum $ map (\(i, j) -> (matchScore (expected !! (i - 1)) (got !! (j - 1))) * (getProb (got !! (j - 1)))) matching - results = map findResult [1..(length got)] - findResult j = case find (\(i, j') -> j' == j) $ matching of - Just (i, _) -> matchScore (expected !! (i - 1)) (got !! 
(j - 1)) - Nothing -> 0.0 - (matching, _) = weightedMaxMatching matchScore expected got diff --git a/src/GEval/Core.hs b/src/GEval/Core.hs index 526129f..f320243 100644 --- a/src/GEval/Core.hs +++ b/src/GEval/Core.hs @@ -631,7 +631,7 @@ gevalCore' (SoftFMeasure beta) _ = gevalCoreWithoutInput parseAnnotations gevalCore' (ProbabilisticSoftFMeasure beta) _ = gevalCoreWithoutInput parseAnnotations parseObtainedAnnotations - getProbabilisticSoftCounts + getProbabilisticCounts probabilisticSoftAgg (fMeasureOnProbabilisticCounts beta) loessGraph diff --git a/src/GEval/PrecisionRecall.hs b/src/GEval/PrecisionRecall.hs index 1ebc0f7..b754ea8 100644 --- a/src/GEval/PrecisionRecall.hs +++ b/src/GEval/PrecisionRecall.hs @@ -5,15 +5,17 @@ module GEval.PrecisionRecall(calculateMAPForOneResult, weightedHarmonicMean, fMeasure, f1Measure, f2Measure, precision, recall, fMeasureOnCounts, f1MeasureOnCounts, f2MeasureOnCounts, countFolder, precisionAndRecall, precisionAndRecallFromCounts, - maxMatch, maxMatchOnOrdered, getCounts, weightedMaxMatch, weightedMaxMatching) + maxMatch, maxMatchOnOrdered, getCounts, weightedMaxMatch, weightedMaxMatching, + getProbabilisticCounts) where import GEval.Common +import GEval.Probability import Data.Graph.Inductive import Data.Graph.Inductive.Query.MaxFlow -import Data.List (nub, foldl') +import Data.List (find, foldl', nub) import Data.Algorithm.Munkres import qualified Data.Array.IArray as DAI @@ -140,3 +142,16 @@ weightedMaxMatching matchFun expected got = hungarianMethodDouble complementWeig n = length got weightList = [((i, j), 1.0 - (matchFun x y)) | (i, x) <- zip [1..m] expected, (j, y) <- zip [1..n] got] + + + +getProbabilisticCounts :: EntityWithProbability e => ([BareEntity e], [e]) -> ([Double], [Double], Double, Int) +getProbabilisticCounts (expected, got) = (results, (map getProbabilityAsDouble got), + gotMass, + length expected) + where gotMass = sum $ map (\(i, j) -> (matchScore (expected !! (i - 1)) (got !! 
(j - 1))) * (getProbabilityAsDouble (got !! (j - 1)))) matching + results = map findResult [1..(length got)] + findResult j = case find (\(i, j') -> j' == j) $ matching of + Just (i, _) -> matchScore (expected !! (i - 1)) (got !! (j - 1)) + Nothing -> 0.0 + (matching, _) = weightedMaxMatching matchScore expected got diff --git a/src/GEval/ProbList.hs b/src/GEval/ProbList.hs index af07138..cb66c1a 100644 --- a/src/GEval/ProbList.hs +++ b/src/GEval/ProbList.hs @@ -1,4 +1,5 @@ {-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE TypeFamilies #-} module GEval.ProbList (parseIntoProbList, selectByStandardThreshold, countLogLossOnProbList) @@ -9,12 +10,21 @@ import qualified Data.Text as T import GEval.Common import GEval.Probability + data ProbList = ProbList [WordWithProb] deriving (Show) data WordWithProb = WordWithProb T.Text Probability deriving (Show) +instance EntityWithProbability WordWithProb where + type BareEntity WordWithProb = T.Text + getBareEntity (WordWithProb w _) = w + getProbability (WordWithProb _ p) = p + matchScore w1 (WordWithProb w2 _) + | w1 == w2 = 1.0 + | otherwise = 0.0 + parseIntoWordWithProb :: T.Text -> WordWithProb parseIntoWordWithProb t = -- If anything is wrong the whole word is treated as a label, diff --git a/src/GEval/Probability.hs b/src/GEval/Probability.hs index fac462e..c938953 100644 --- a/src/GEval/Probability.hs +++ b/src/GEval/Probability.hs @@ -1,6 +1,8 @@ +{-# LANGUAGE TypeFamilies #-} module GEval.Probability - (Probability, getP, isProbability, mkProbability, probabilityOne, probabilityZero) + (Probability, getP, isProbability, mkProbability, probabilityOne, probabilityZero, + EntityWithProbability, getProbability, getProbabilityAsDouble, getBareEntity, matchScore, BareEntity) where newtype Probability = P { getP :: Double } @@ -19,3 +21,12 @@ probabilityOne = mkProbability 1.0 probabilityZero :: Probability probabilityZero = mkProbability 0.0 + +class EntityWithProbability e where + type BareEntity e :: * + getProbability :: 
e -> Probability + getProbability = mkProbability . getProbabilityAsDouble + getProbabilityAsDouble :: e -> Double + getProbabilityAsDouble = getP . getProbability + getBareEntity :: e -> BareEntity e + matchScore :: BareEntity e -> e -> Double