finish general procedure for precision, recall and F-measure
This commit is contained in:
parent
b14e6a38a2
commit
c3106a1ad6
@ -19,6 +19,7 @@ library
|
|||||||
GEval.CreateChallenge
|
GEval.CreateChallenge
|
||||||
, GEval.OptionsParser
|
, GEval.OptionsParser
|
||||||
, GEval.BLEU
|
, GEval.BLEU
|
||||||
|
, GEval.PrecisionRecall
|
||||||
build-depends: base >= 4.7 && < 5
|
build-depends: base >= 4.7 && < 5
|
||||||
, cond
|
, cond
|
||||||
, conduit
|
, conduit
|
||||||
@ -33,6 +34,7 @@ library
|
|||||||
, split
|
, split
|
||||||
, text
|
, text
|
||||||
, unix
|
, unix
|
||||||
|
, fgl
|
||||||
default-language: Haskell2010
|
default-language: Haskell2010
|
||||||
|
|
||||||
executable geval
|
executable geval
|
||||||
@ -42,6 +44,7 @@ executable geval
|
|||||||
build-depends: base
|
build-depends: base
|
||||||
, geval
|
, geval
|
||||||
, optparse-applicative
|
, optparse-applicative
|
||||||
|
, fgl
|
||||||
default-language: Haskell2010
|
default-language: Haskell2010
|
||||||
|
|
||||||
test-suite geval-test
|
test-suite geval-test
|
||||||
|
@ -1,46 +1,86 @@
|
|||||||
module GEval.PrecisionRecall
|
{-# LANGUAGE PartialTypeSignatures #-}
|
||||||
|
|
||||||
|
module GEval.PrecisionRecall(fMeasure, f1Measure, f2Measure, precision, recall,
|
||||||
|
fMeasureOnCounts, f1MeasureOnCounts, f2MeasureOnCounts,
|
||||||
|
precisionAndRecall, precisionAndRecallFromCounts)
|
||||||
where
|
where
|
||||||
|
|
||||||
|
import Data.Graph.Inductive
|
||||||
|
import Data.Graph.Inductive.Query.MaxFlow
|
||||||
|
|
||||||
|
import Debug.Trace
|
||||||
|
|
||||||
f2Measure :: (a -> b -> Bool) -> [a] -> [b] -> Double
|
f2Measure :: (a -> b -> Bool) -> [a] -> [b] -> Double
|
||||||
f2Measure = fMeasure 2.0
|
f2Measure = fMeasure 2.0
|
||||||
|
|
||||||
|
f1Measure :: (a -> b -> Bool) -> [a] -> [b] -> Double
|
||||||
|
f1Measure = fMeasure 1.0
|
||||||
|
|
||||||
fMeasure :: Double -> (a -> b -> Bool) -> [a] -> [b] -> Double
|
fMeasure :: Double -> (a -> b -> Bool) -> [a] -> [b] -> Double
|
||||||
fMeasure beta matchingFun expected got =
|
fMeasure beta matchingFun expected got =
|
||||||
(1 + betaSquared) * (p + r) / (betaSquared * p + 2)
|
(1 + betaSquared) * p * r `safeDoubleDiv` (betaSquared * p + r)
|
||||||
where betaSquared = beta ^ 2
|
where betaSquared = beta ^ 2
|
||||||
(p, r) = precisionAndRecall matchingFun expected got
|
(p, r) = precisionAndRecall matchingFun expected got
|
||||||
|
|
||||||
|
f2MeasureOnCounts :: (Int, Int, Int) -> Double
|
||||||
|
f2MeasureOnCounts = fMeasureOnCounts 2.0
|
||||||
|
|
||||||
|
f1MeasureOnCounts :: (Int, Int, Int) -> Double
|
||||||
|
f1MeasureOnCounts = fMeasureOnCounts 1.0
|
||||||
|
|
||||||
|
fMeasureOnCounts :: Double -> (Int, Int, Int) -> Double
|
||||||
|
fMeasureOnCounts beta (tp, nbExpected, nbGot) =
|
||||||
|
(1 + betaSquared) * p * r `safeDoubleDiv` (betaSquared * p + r)
|
||||||
|
where betaSquared = beta ^ 2
|
||||||
|
(p, r) = precisionAndRecallFromCounts (tp, nbExpected, nbGot)
|
||||||
|
|
||||||
precisionAndRecall :: (a -> b -> Bool) -> [a] -> [b] -> (Double, Double)
|
precisionAndRecall :: (a -> b -> Bool) -> [a] -> [b] -> (Double, Double)
|
||||||
precisionAndRecall matchFun expected got = precisionAndRecallFromCounts cs
|
precisionAndRecall matchFun expected got
|
||||||
where cs = counts matchFun expected got
|
= precisionAndRecallFromCounts (tp, length expected, length got)
|
||||||
|
where tp = maxMatch matchFun expected got
|
||||||
|
|
||||||
precisionAndRecallFromCounts :: (Int, Int, Int) -> (Double, Double)
|
precisionAndRecallFromCounts :: (Int, Int, Int) -> (Double, Double)
|
||||||
precisionAndRecallFromCounts (fp, fn, tp) = (tp / (tp+fp+epsilon), tp / (tp+fn+epsilon))
|
precisionAndRecallFromCounts (tp, nbExpected, nbGot) =
|
||||||
where epsilon = 0.000001
|
(tp `safeDiv` nbGot, tp `safeDiv` nbExpected)
|
||||||
|
|
||||||
|
-- division with possible zero
|
||||||
|
safeDiv :: Int -> Int -> Double
|
||||||
|
safeDiv _ 0 = 1.0
|
||||||
|
safeDiv x y = (fromIntegral x) / (fromIntegral y)
|
||||||
|
|
||||||
|
safeDoubleDiv :: Double -> Double -> Double
|
||||||
|
safeDoubleDiv _ 0.0 = 0.0
|
||||||
|
safeDoubleDiv x y = x / y
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
precision :: (a -> b -> Bool) -> [a] -> [b] -> Double
|
precision :: (a -> b -> Bool) -> [a] -> [b] -> Double
|
||||||
precision = first . precisionAndRecall
|
precision matchFun expected got = fst $ precisionAndRecall matchFun expected got
|
||||||
|
|
||||||
recall :: (a -> b -> Bool) -> [a] -> [b] -> Double
|
recall :: (a -> b -> Bool) -> [a] -> [b] -> Double
|
||||||
recall = second . precisionAndRecall
|
recall matchFun expected got = snd $ precisionAndRecall matchFun expected got
|
||||||
|
|
||||||
counts :: (a -> b -> Bool) -> [a] -> [b] -> (Int, Int, Int)
|
-- counting maximum match with maximum bipartite matching
|
||||||
counts matchFun expected got = f
|
-- (we build an auxiliary graph and do a max-flow on this)
|
||||||
|
maxMatch :: (a -> b -> Bool) -> [a] -> [b] -> Int
|
||||||
|
maxMatch matchFun expected got = mf
|
||||||
|
where (b, e, g) = buildGraph matchFun expected got
|
||||||
|
mf = maxFlow g (fst b) (fst e)
|
||||||
|
|
||||||
|
|
||||||
countStep :: ([a], (Int, Int)) -> b -> ([a], (Int, Int))
|
buildGraph :: (a -> b -> Bool) -> [a] -> [b] -> (LNode Int, LNode Int, Gr Int Int)
|
||||||
countStep (expectedLeft, fp, tp) itemGot =
|
buildGraph matchFun expected got = (b, e, g)
|
||||||
(rest, case firstMatched of
|
where ((b, e), (_, g)) = buildGraph' matchFun expected got
|
||||||
Just _ -> (fp, tp+1)
|
buildGraph' matchFun expected got =
|
||||||
Nothing -> (fp+1, tp))
|
run empty $
|
||||||
where (rest, firstMatched) = findFirst ((flip matchingFun) itemGot) expectedLeft
|
do b <- insMapNodeM 0
|
||||||
|
e <- insMapNodeM 1
|
||||||
findFirst :: (a -> Bool) -> [a] -> ([a], Maybe a)
|
mapM insMapNodeM [2..1+(length expected)+(length got)]
|
||||||
findFirst _ [] -> ([], Nothing)
|
insMapEdgesM $ map (\n -> (0, n, 1)) expectedIxs
|
||||||
findFirst condition (x:xs) = if (condition x)
|
insMapEdgesM $ map (\m -> (m, 1, 1)) gotIxs
|
||||||
then
|
insMapEdgesM $ map (\(n,m) -> (n, m, 1))
|
||||||
(xs, Just x)
|
$ filter (\(n, m) -> matchFun (expected !! (n-2)) (got !! (m-2-(length expected))))
|
||||||
else
|
[(x,y) | x <- expectedIxs, y <- gotIxs]
|
||||||
(x:xs', m)
|
return (b,e)
|
||||||
where
|
where expectedIxs = [2..1+(length expected)]
|
||||||
(xs',m) = findFirst condition xs
|
gotIxs = [2+(length expected)..1+(length expected)+(length got)]
|
||||||
|
24
test/Spec.hs
24
test/Spec.hs
@ -3,6 +3,7 @@ import Test.Hspec
|
|||||||
import GEval.Core
|
import GEval.Core
|
||||||
import GEval.OptionsParser
|
import GEval.OptionsParser
|
||||||
import GEval.BLEU
|
import GEval.BLEU
|
||||||
|
import GEval.PrecisionRecall
|
||||||
import Options.Applicative
|
import Options.Applicative
|
||||||
import qualified Test.HUnit as HU
|
import qualified Test.HUnit as HU
|
||||||
|
|
||||||
@ -47,7 +48,26 @@ main = hspec $ do
|
|||||||
runGEvalTest "unexpected-data" `shouldThrow` (== UnexpectedData "input does not start with a digit")
|
runGEvalTest "unexpected-data" `shouldThrow` (== UnexpectedData "input does not start with a digit")
|
||||||
it "unwanted data is handled" $
|
it "unwanted data is handled" $
|
||||||
runGEvalTest "unwanted-data" `shouldThrow` (== UnexpectedData "number expected")
|
runGEvalTest "unwanted-data" `shouldThrow` (== UnexpectedData "number expected")
|
||||||
|
describe "precision and recall" $ do
|
||||||
|
it "null test" $ do
|
||||||
|
precision neverMatch ['a', 'b', 'c'] [0, 1, 2, 3, 4, 5] `shouldBeAlmost` 0.0
|
||||||
|
recall neverMatch ['a', 'b', 'c'] [0, 1, 2, 3, 4, 5] `shouldBeAlmost` 0.0
|
||||||
|
f1Measure neverMatch ['a', 'b', 'c'] [0, 1, 2, 3, 4, 5] `shouldBeAlmost` 0.0
|
||||||
|
it "basic test" $ do
|
||||||
|
precision testMatchFun ['a', 'b', 'c'] [0, 1, 2, 3, 4, 5] `shouldBeAlmost` 0.3333333333333333
|
||||||
|
recall testMatchFun ['a', 'b', 'c'] [0, 1, 2, 3, 4, 5] `shouldBeAlmost` 0.66666666666666666
|
||||||
|
f1Measure testMatchFun ['a', 'b', 'c'] [0, 1, 2, 3, 4, 5] `shouldBeAlmost` 0.444444444444444
|
||||||
|
|
||||||
|
neverMatch :: Char -> Int -> Bool
|
||||||
|
neverMatch _ _ = False
|
||||||
|
|
||||||
|
testMatchFun :: Char -> Int -> Bool
|
||||||
|
testMatchFun 'a' 1 = True
|
||||||
|
testMatchFun 'a' 2 = True
|
||||||
|
testMatchFun 'a' 3 = True
|
||||||
|
testMatchFun 'b' 1 = True
|
||||||
|
testMatchFun 'c' 1 = True
|
||||||
|
testMatchFun _ _ = False
|
||||||
|
|
||||||
extractVal :: (Either (ParserResult GEvalOptions) (Maybe MetricValue)) -> IO MetricValue
|
extractVal :: (Either (ParserResult GEvalOptions) (Maybe MetricValue)) -> IO MetricValue
|
||||||
extractVal (Right (Just val)) = return val
|
extractVal (Right (Just val)) = return val
|
||||||
@ -72,10 +92,12 @@ instance AEq Double where
|
|||||||
x =~ y = abs ( x - y ) < (1.0e-4 :: Double)
|
x =~ y = abs ( x - y ) < (1.0e-4 :: Double)
|
||||||
|
|
||||||
(@=~?) :: (Show a, AEq a) => a -> a -> HU.Assertion
|
(@=~?) :: (Show a, AEq a) => a -> a -> HU.Assertion
|
||||||
(@=~?) expected actual = expected =~ actual HU.@? assertionMsg
|
(@=~?) actual expected = expected =~ actual HU.@? assertionMsg
|
||||||
where
|
where
|
||||||
assertionMsg = "Expected : " ++ show expected ++
|
assertionMsg = "Expected : " ++ show expected ++
|
||||||
"\nActual : " ++ show actual
|
"\nActual : " ++ show actual
|
||||||
|
|
||||||
|
shouldBeAlmost got expected = got @=~? expected
|
||||||
|
|
||||||
shouldReturnAlmost :: (AEq a, Show a, Eq a) => IO a -> a -> Expectation
|
shouldReturnAlmost :: (AEq a, Show a, Eq a) => IO a -> a -> Expectation
|
||||||
shouldReturnAlmost action expected = action >>= (@=~? expected)
|
shouldReturnAlmost action expected = action >>= (@=~? expected)
|
||||||
|
Loading…
Reference in New Issue
Block a user