finish general procedure for precision, recall and F-measure
This commit is contained in:
parent
b14e6a38a2
commit
c3106a1ad6
@ -19,6 +19,7 @@ library
|
||||
GEval.CreateChallenge
|
||||
, GEval.OptionsParser
|
||||
, GEval.BLEU
|
||||
, GEval.PrecisionRecall
|
||||
build-depends: base >= 4.7 && < 5
|
||||
, cond
|
||||
, conduit
|
||||
@ -33,6 +34,7 @@ library
|
||||
, split
|
||||
, text
|
||||
, unix
|
||||
, fgl
|
||||
default-language: Haskell2010
|
||||
|
||||
executable geval
|
||||
@ -42,6 +44,7 @@ executable geval
|
||||
build-depends: base
|
||||
, geval
|
||||
, optparse-applicative
|
||||
, fgl
|
||||
default-language: Haskell2010
|
||||
|
||||
test-suite geval-test
|
||||
|
@ -1,46 +1,86 @@
|
||||
module GEval.PrecisionRecall
|
||||
{-# LANGUAGE PartialTypeSignatures #-}
|
||||
|
||||
module GEval.PrecisionRecall(fMeasure, f1Measure, f2Measure, precision, recall,
|
||||
fMeasureOnCounts, f1MeasureOnCounts, f2MeasureOnCounts,
|
||||
precisionAndRecall, precisionAndRecallFromCounts)
|
||||
where
|
||||
|
||||
import Data.Graph.Inductive
|
||||
import Data.Graph.Inductive.Query.MaxFlow
|
||||
|
||||
import Debug.Trace
|
||||
|
||||
f2Measure :: (a -> b -> Bool) -> [a] -> [b] -> Double
|
||||
f2Measure = fMeasure 2.0
|
||||
|
||||
f1Measure :: (a -> b -> Bool) -> [a] -> [b] -> Double
|
||||
f1Measure = fMeasure 1.0
|
||||
|
||||
fMeasure :: Double -> (a -> b -> Bool) -> [a] -> [b] -> Double
|
||||
fMeasure beta matchingFun expected got =
|
||||
(1 + betaSquared) * (p + r) / (betaSquared * p + 2)
|
||||
(1 + betaSquared) * p * r `safeDoubleDiv` (betaSquared * p + r)
|
||||
where betaSquared = beta ^ 2
|
||||
(p, r) = precisionAndRecall matchingFun expected got
|
||||
|
||||
f2MeasureOnCounts :: (Int, Int, Int) -> Double
|
||||
f2MeasureOnCounts = fMeasureOnCounts 2.0
|
||||
|
||||
f1MeasureOnCounts :: (Int, Int, Int) -> Double
|
||||
f1MeasureOnCounts = fMeasureOnCounts 1.0
|
||||
|
||||
fMeasureOnCounts :: Double -> (Int, Int, Int) -> Double
|
||||
fMeasureOnCounts beta (tp, nbExpected, nbGot) =
|
||||
(1 + betaSquared) * p * r `safeDoubleDiv` (betaSquared * p + r)
|
||||
where betaSquared = beta ^ 2
|
||||
(p, r) = precisionAndRecallFromCounts (tp, nbExpected, nbGot)
|
||||
|
||||
precisionAndRecall :: (a -> b -> Bool) -> [a] -> [b] -> (Double, Double)
|
||||
precisionAndRecall matchFun expected got = precisionAndRecallFromCounts cs
|
||||
where cs = counts matchFun expected got
|
||||
precisionAndRecall matchFun expected got
|
||||
= precisionAndRecallFromCounts (tp, length expected, length got)
|
||||
where tp = maxMatch matchFun expected got
|
||||
|
||||
precisionAndRecallFromCounts :: (Int, Int, Int) -> (Double, Double)
|
||||
precisionAndRecallFromCounts (fp, fn, tp) = (tp / (tp+fp+epsilon), tp / (tp+fn+epsilon))
|
||||
where epsilon = 0.000001
|
||||
precisionAndRecallFromCounts (tp, nbExpected, nbGot) =
|
||||
(tp `safeDiv` nbGot, tp `safeDiv` nbExpected)
|
||||
|
||||
-- division with possible zero
|
||||
safeDiv :: Int -> Int -> Double
|
||||
safeDiv _ 0 = 1.0
|
||||
safeDiv x y = (fromIntegral x) / (fromIntegral y)
|
||||
|
||||
safeDoubleDiv :: Double -> Double -> Double
|
||||
safeDoubleDiv _ 0.0 = 0.0
|
||||
safeDoubleDiv x y = x / y
|
||||
|
||||
|
||||
|
||||
precision :: (a -> b -> Bool) -> [a] -> [b] -> Double
|
||||
precision = first . precisionAndRecall
|
||||
precision matchFun expected got = fst $ precisionAndRecall matchFun expected got
|
||||
|
||||
recall :: (a -> b -> Bool) -> [a] -> [b] -> Double
|
||||
recall = second . precisionAndRecall
|
||||
recall matchFun expected got = snd $ precisionAndRecall matchFun expected got
|
||||
|
||||
counts :: (a -> b -> Bool) -> [a] -> [b] -> (Int, Int, Int)
|
||||
counts matchFun expected got = f
|
||||
-- counting maximum match with maximum bipartite matching
|
||||
-- (we build an auxiliary graph and do a max-flow on this)
|
||||
maxMatch :: (a -> b -> Bool) -> [a] -> [b] -> Int
|
||||
maxMatch matchFun expected got = mf
|
||||
where (b, e, g) = buildGraph matchFun expected got
|
||||
mf = maxFlow g (fst b) (fst e)
|
||||
|
||||
|
||||
countStep :: ([a], (Int, Int)) -> b -> ([a], (Int, Int))
|
||||
countStep (expectedLeft, fp, tp) itemGot =
|
||||
(rest, case firstMatched of
|
||||
Just _ -> (fp, tp+1)
|
||||
Nothing -> (fp+1, tp))
|
||||
where (rest, firstMatched) = findFirst ((flip matchingFun) itemGot) expectedLeft
|
||||
|
||||
findFirst :: (a -> Bool) -> [a] -> ([a], Maybe a)
|
||||
findFirst _ [] -> ([], Nothing)
|
||||
findFirst condition (x:xs) = if (condition x)
|
||||
then
|
||||
(xs, Just x)
|
||||
else
|
||||
(x:xs', m)
|
||||
where
|
||||
(xs',m) = findFirst condition xs
|
||||
buildGraph :: (a -> b -> Bool) -> [a] -> [b] -> (LNode Int, LNode Int, Gr Int Int)
|
||||
buildGraph matchFun expected got = (b, e, g)
|
||||
where ((b, e), (_, g)) = buildGraph' matchFun expected got
|
||||
buildGraph' matchFun expected got =
|
||||
run empty $
|
||||
do b <- insMapNodeM 0
|
||||
e <- insMapNodeM 1
|
||||
mapM insMapNodeM [2..1+(length expected)+(length got)]
|
||||
insMapEdgesM $ map (\n -> (0, n, 1)) expectedIxs
|
||||
insMapEdgesM $ map (\m -> (m, 1, 1)) gotIxs
|
||||
insMapEdgesM $ map (\(n,m) -> (n, m, 1))
|
||||
$ filter (\(n, m) -> matchFun (expected !! (n-2)) (got !! (m-2-(length expected))))
|
||||
[(x,y) | x <- expectedIxs, y <- gotIxs]
|
||||
return (b,e)
|
||||
where expectedIxs = [2..1+(length expected)]
|
||||
gotIxs = [2+(length expected)..1+(length expected)+(length got)]
|
||||
|
24
test/Spec.hs
24
test/Spec.hs
@ -3,6 +3,7 @@ import Test.Hspec
|
||||
import GEval.Core
|
||||
import GEval.OptionsParser
|
||||
import GEval.BLEU
|
||||
import GEval.PrecisionRecall
|
||||
import Options.Applicative
|
||||
import qualified Test.HUnit as HU
|
||||
|
||||
@ -47,7 +48,26 @@ main = hspec $ do
|
||||
runGEvalTest "unexpected-data" `shouldThrow` (== UnexpectedData "input does not start with a digit")
|
||||
it "unwanted data is handled" $
|
||||
runGEvalTest "unwanted-data" `shouldThrow` (== UnexpectedData "number expected")
|
||||
describe "precision and recall" $ do
|
||||
it "null test" $ do
|
||||
precision neverMatch ['a', 'b', 'c'] [0, 1, 2, 3, 4, 5] `shouldBeAlmost` 0.0
|
||||
recall neverMatch ['a', 'b', 'c'] [0, 1, 2, 3, 4, 5] `shouldBeAlmost` 0.0
|
||||
f1Measure neverMatch ['a', 'b', 'c'] [0, 1, 2, 3, 4, 5] `shouldBeAlmost` 0.0
|
||||
it "basic test" $ do
|
||||
precision testMatchFun ['a', 'b', 'c'] [0, 1, 2, 3, 4, 5] `shouldBeAlmost` 0.3333333333333333
|
||||
recall testMatchFun ['a', 'b', 'c'] [0, 1, 2, 3, 4, 5] `shouldBeAlmost` 0.66666666666666666
|
||||
f1Measure testMatchFun ['a', 'b', 'c'] [0, 1, 2, 3, 4, 5] `shouldBeAlmost` 0.444444444444444
|
||||
|
||||
neverMatch :: Char -> Int -> Bool
|
||||
neverMatch _ _ = False
|
||||
|
||||
testMatchFun :: Char -> Int -> Bool
|
||||
testMatchFun 'a' 1 = True
|
||||
testMatchFun 'a' 2 = True
|
||||
testMatchFun 'a' 3 = True
|
||||
testMatchFun 'b' 1 = True
|
||||
testMatchFun 'c' 1 = True
|
||||
testMatchFun _ _ = False
|
||||
|
||||
extractVal :: (Either (ParserResult GEvalOptions) (Maybe MetricValue)) -> IO MetricValue
|
||||
extractVal (Right (Just val)) = return val
|
||||
@ -72,10 +92,12 @@ instance AEq Double where
|
||||
x =~ y = abs ( x - y ) < (1.0e-4 :: Double)
|
||||
|
||||
(@=~?) :: (Show a, AEq a) => a -> a -> HU.Assertion
|
||||
(@=~?) expected actual = expected =~ actual HU.@? assertionMsg
|
||||
(@=~?) actual expected = expected =~ actual HU.@? assertionMsg
|
||||
where
|
||||
assertionMsg = "Expected : " ++ show expected ++
|
||||
"\nActual : " ++ show actual
|
||||
|
||||
shouldBeAlmost got expected = got @=~? expected
|
||||
|
||||
shouldReturnAlmost :: (AEq a, Show a, Eq a) => IO a -> a -> Expectation
|
||||
shouldReturnAlmost action expected = action >>= (@=~? expected)
|
||||
|
Loading…
Reference in New Issue
Block a user