finish general procedure for precision, recall and F-measure

This commit is contained in:
Filip Gralinski 2016-08-01 22:47:43 +02:00 committed by Filip Gralinski
parent b14e6a38a2
commit c3106a1ad6
3 changed files with 92 additions and 27 deletions

View File

@ -19,6 +19,7 @@ library
GEval.CreateChallenge
, GEval.OptionsParser
, GEval.BLEU
, GEval.PrecisionRecall
build-depends: base >= 4.7 && < 5
, cond
, conduit
@ -33,6 +34,7 @@ library
, split
, text
, unix
, fgl
default-language: Haskell2010
executable geval
@ -42,6 +44,7 @@ executable geval
build-depends: base
, geval
, optparse-applicative
, fgl
default-language: Haskell2010
test-suite geval-test

View File

@ -1,46 +1,86 @@
module GEval.PrecisionRecall
{-# LANGUAGE PartialTypeSignatures #-}
module GEval.PrecisionRecall(fMeasure, f1Measure, f2Measure, precision, recall,
fMeasureOnCounts, f1MeasureOnCounts, f2MeasureOnCounts,
precisionAndRecall, precisionAndRecallFromCounts)
where
import Data.Graph.Inductive
import Data.Graph.Inductive.Query.MaxFlow
import Debug.Trace
-- | F-measure with @beta@ fixed to 2.
--
-- Takes the same matching predicate, expected items and obtained items
-- as 'fMeasure'.
f2Measure :: (a -> b -> Bool) -> [a] -> [b] -> Double
f2Measure matchFun expected got = fMeasure 2.0 matchFun expected got
-- | F-measure with @beta@ fixed to 1.
--
-- Takes the same matching predicate, expected items and obtained items
-- as 'fMeasure'.
f1Measure :: (a -> b -> Bool) -> [a] -> [b] -> Double
f1Measure matchFun expected got = fMeasure 1.0 matchFun expected got
-- | General F-measure of the obtained items against the expected ones.
--
-- Computes @(1 + beta^2) * p * r / (beta^2 * p + r)@, where @p@ and @r@
-- are the precision and recall returned by 'precisionAndRecall'.
--
-- * @beta@        — weight of the F-measure
-- * @matchingFun@ — decides whether an expected item matches an obtained one
-- * @expected@    — the expected items
-- * @got@         — the obtained items
--
-- When the denominator is zero the result is 0 (see 'safeDoubleDiv').
fMeasure :: Double -> (a -> b -> Bool) -> [a] -> [b] -> Double
fMeasure beta matchingFun expected got =
  -- A backticked operator without a fixity declaration binds tighter
  -- than (*), so the division is parenthesised explicitly here.
  (1 + betaSquared) * p * (r `safeDoubleDiv` (betaSquared * p + r))
  where
    betaSquared = beta * beta
    (p, r) = precisionAndRecall matchingFun expected got
-- | 'fMeasureOnCounts' with @beta@ fixed to 2.
f2MeasureOnCounts :: (Int, Int, Int) -> Double
f2MeasureOnCounts cs = fMeasureOnCounts 2.0 cs
-- | 'fMeasureOnCounts' with @beta@ fixed to 1.
f1MeasureOnCounts :: (Int, Int, Int) -> Double
f1MeasureOnCounts cs = fMeasureOnCounts 1.0 cs
-- | General F-measure computed directly from counts.
--
-- The triple is @(tp, nbExpected, nbGot)@: the number of matched items,
-- the number of expected items and the number of obtained items. It is
-- passed unchanged to 'precisionAndRecallFromCounts'.
--
-- When the denominator is zero the result is 0 (see 'safeDoubleDiv').
fMeasureOnCounts :: Double -> (Int, Int, Int) -> Double
fMeasureOnCounts beta cs =
  weight * prec * (rec `safeDoubleDiv` denominator)
  where
    betaSq      = beta * beta
    weight      = 1 + betaSq
    denominator = betaSq * prec + rec
    (prec, rec) = precisionAndRecallFromCounts cs
-- | Precision and recall of the obtained items against the expected ones.
--
-- The number of true positives is the size of a maximum matching between
-- the two lists (see 'maxMatch'). That count, together with the lengths
-- of both lists, is handed to 'precisionAndRecallFromCounts'.
--
-- Returns @(precision, recall)@.
precisionAndRecall :: (a -> b -> Bool) -> [a] -> [b] -> (Double, Double)
precisionAndRecall matchFun expected got =
  precisionAndRecallFromCounts (tp, length expected, length got)
  where
    tp = maxMatch matchFun expected got
-- | Precision and recall from raw counts.
--
-- The argument is @(tp, nbExpected, nbGot)@: matched items, expected
-- items and obtained items. Precision is @tp / nbGot@ and recall is
-- @tp / nbExpected@.
--
-- Both divisions go through 'safeDiv', so a zero denominator yields 1.0
-- rather than NaN or infinity.
--
-- Returns @(precision, recall)@.
precisionAndRecallFromCounts :: (Int, Int, Int) -> (Double, Double)
precisionAndRecallFromCounts (tp, nbExpected, nbGot) =
  (tp `safeDiv` nbGot, tp `safeDiv` nbExpected)
-- | Integer-to-'Double' division that tolerates a zero denominator.
--
-- Any numerator divided by 0 gives 1.0; otherwise the ordinary quotient
-- of the two values converted to 'Double' is returned.
safeDiv :: Int -> Int -> Double
safeDiv numerator denominator
  | denominator == 0 = 1.0
  | otherwise        = fromIntegral numerator / fromIntegral denominator
-- | 'Double' division that tolerates a zero denominator.
--
-- Any numerator divided by zero gives 0.0; otherwise the ordinary
-- quotient is returned.
safeDoubleDiv :: Double -> Double -> Double
safeDoubleDiv numerator denominator
  | denominator == 0.0 = 0.0
  | otherwise          = numerator / denominator
-- | Precision of the obtained items against the expected ones.
--
-- This is the first component of 'precisionAndRecall'.
precision :: (a -> b -> Bool) -> [a] -> [b] -> Double
precision matchFun expected got = fst $ precisionAndRecall matchFun expected got
-- | Recall of the obtained items against the expected ones.
--
-- This is the second component of 'precisionAndRecall'.
recall :: (a -> b -> Bool) -> [a] -> [b] -> Double
recall matchFun expected got = snd $ precisionAndRecall matchFun expected got
-- NOTE(review): this definition looks unfinished. Its right-hand side is
-- the bare name @f@, which is not bound anywhere in the visible source.
-- 'precisionAndRecall' appears to obtain its counts from 'maxMatch'
-- instead, so this is presumably superseded — confirm before relying on it.
counts :: (a -> b -> Bool) -> [a] -> [b] -> (Int, Int, Int)
counts matchFun expected got = f
-- | Size of a maximum matching between expected and obtained items.
--
-- An auxiliary flow network is built by 'buildGraph', and the value of
-- the maximum flow from its begin node to its end node is returned.
maxMatch :: (a -> b -> Bool) -> [a] -> [b] -> Int
maxMatch matchFun expected got =
  maxFlow network (fst beginNode) (fst endNode)
  where
    (beginNode, endNode, network) = buildGraph matchFun expected got
-- One step of a greedy count: looks (via 'findFirst') for the first
-- remaining expected item that matches @itemGot@. On a hit the second
-- counter (@tp@) is incremented, otherwise the first one (@fp@) is. The
-- expected items left over after the search are returned with the counters.
--
-- NOTE(review): the signature declares a nested pair @([a], (Int, Int))@,
-- but the argument pattern below is a flat triple. Also, @matchingFun@ is
-- not bound in this definition as visible here. This looks like unfinished
-- code, presumably superseded by 'maxMatch' — confirm.
countStep :: ([a], (Int, Int)) -> b -> ([a], (Int, Int))
countStep (expectedLeft, fp, tp) itemGot =
(rest, case firstMatched of
Just _ -> (fp, tp+1)
Nothing -> (fp+1, tp))
where (rest, firstMatched) = findFirst ((flip matchingFun) itemGot) expectedLeft
-- | Removes the first element satisfying the condition from a list.
--
-- Returns the list without that element (the order of the remaining
-- elements is preserved) together with @Just@ the removed element, or
-- the unchanged list and @Nothing@ when no element satisfies the
-- condition.
findFirst :: (a -> Bool) -> [a] -> ([a], Maybe a)
findFirst condition items =
  case break condition items of
    (before, hit : after) -> (before ++ after, Just hit)
    (before, [])          -> (before, Nothing)
-- | Builds the auxiliary graph used by 'maxMatch'.
--
-- Returns the begin node, the end node and the graph itself. The node
-- map that 'buildGraph'' also produces is dropped.
buildGraph :: (a -> b -> Bool) -> [a] -> [b] -> (LNode Int, LNode Int, Gr Int Int)
buildGraph matchFun expected got =
  let (terminals, (_, graph)) = buildGraph' matchFun expected got
      (beginNode, endNode)    = terminals
  in (beginNode, endNode, graph)
-- | Builds the flow network for the bipartite matching problem.
--
-- Node labels:
--
-- * 0 — begin node, 1 — end node
-- * @2 .. 1 + length expected@ — one node per expected item
-- * the following @length got@ labels — one node per obtained item
--
-- Edges, all labelled 1:
--
-- * begin node to every expected-item node
-- * every obtained-item node to the end node
-- * expected-item node to obtained-item node whenever @matchFun@ holds
--   for the corresponding pair
--
-- Returns the begin and end nodes, plus the node map and graph as
-- produced by 'run'.
buildGraph' :: (a -> b -> Bool) -> [a] -> [b]
            -> ((LNode Int, LNode Int), (NodeMap Int, Gr Int Int))
buildGraph' matchFun expected got =
  run empty $ do
    beginNode <- insMapNodeM 0
    endNode <- insMapNodeM 1
    mapM_ insMapNodeM (expectedIxs ++ gotIxs)
    insMapEdgesM [(0, n, 1) | n <- expectedIxs]
    insMapEdgesM [(m, 1, 1) | m <- gotIxs]
    -- Pair every node label with its item up front instead of looking
    -- items up with (!!) for each candidate edge.
    insMapEdgesM
      [ (n, m, 1)
      | (n, expectedItem) <- zip expectedIxs expected
      , (m, gotItem) <- zip gotIxs got
      , matchFun expectedItem gotItem
      ]
    return (beginNode, endNode)
  where
    nbExpected  = length expected
    nbGot       = length got
    expectedIxs = [2 .. 1 + nbExpected]
    gotIxs      = [2 + nbExpected .. 1 + nbExpected + nbGot]

View File

@ -3,6 +3,7 @@ import Test.Hspec
import GEval.Core
import GEval.OptionsParser
import GEval.BLEU
import GEval.PrecisionRecall
import Options.Applicative
import qualified Test.HUnit as HU
@ -47,7 +48,26 @@ main = hspec $ do
runGEvalTest "unexpected-data" `shouldThrow` (== UnexpectedData "input does not start with a digit")
it "unwanted data is handled" $
runGEvalTest "unwanted-data" `shouldThrow` (== UnexpectedData "number expected")
describe "precision and recall" $ do
it "null test" $ do
precision neverMatch ['a', 'b', 'c'] [0, 1, 2, 3, 4, 5] `shouldBeAlmost` 0.0
recall neverMatch ['a', 'b', 'c'] [0, 1, 2, 3, 4, 5] `shouldBeAlmost` 0.0
f1Measure neverMatch ['a', 'b', 'c'] [0, 1, 2, 3, 4, 5] `shouldBeAlmost` 0.0
it "basic test" $ do
precision testMatchFun ['a', 'b', 'c'] [0, 1, 2, 3, 4, 5] `shouldBeAlmost` 0.3333333333333333
recall testMatchFun ['a', 'b', 'c'] [0, 1, 2, 3, 4, 5] `shouldBeAlmost` 0.66666666666666666
f1Measure testMatchFun ['a', 'b', 'c'] [0, 1, 2, 3, 4, 5] `shouldBeAlmost` 0.444444444444444
-- | Test matcher that rejects every pair.
neverMatch :: Char -> Int -> Bool
neverMatch = \_ _ -> False
-- | Test matcher: @\'a\'@ matches 1, 2 and 3; @\'b\'@ and @\'c\'@ match
-- only 1; every other pair is rejected.
testMatchFun :: Char -> Int -> Bool
testMatchFun expectedItem gotItem =
  (expectedItem, gotItem) `elem` matchingPairs
  where
    matchingPairs =
      [ ('a', 1)
      , ('a', 2)
      , ('a', 3)
      , ('b', 1)
      , ('c', 1)
      ]
extractVal :: (Either (ParserResult GEvalOptions) (Maybe MetricValue)) -> IO MetricValue
extractVal (Right (Just val)) = return val
@ -72,10 +92,12 @@ instance AEq Double where
x =~ y = abs ( x - y ) < (1.0e-4 :: Double)
-- | Approximate-equality assertion.
--
-- The left operand is the actual value and the right operand is the
-- expected one. This is the order used by the callers @shouldBeAlmost@
-- and @shouldReturnAlmost@, so the failure message labels the values
-- correctly.
(@=~?) :: (Show a, AEq a) => a -> a -> HU.Assertion
(@=~?) actual expected = expected =~ actual HU.@? assertionMsg
  where
    assertionMsg = "Expected : " ++ show expected ++
                   "\nActual : " ++ show actual
-- | Asserts that the obtained value (left) is approximately equal to
-- the expected value (right).
shouldBeAlmost :: (Show a, AEq a) => a -> a -> HU.Assertion
shouldBeAlmost = (@=~?)
-- | Runs the action and asserts that its result is approximately equal
-- to the expected value.
shouldReturnAlmost :: (AEq a, Show a, Eq a) => IO a -> a -> Expectation
shouldReturnAlmost action expected = do
  actual <- action
  actual @=~? expected