diff --git a/geval.cabal b/geval.cabal
index 1007b68..5176f1a 100644
--- a/geval.cabal
+++ b/geval.cabal
@@ -19,6 +19,7 @@ library
                        GEval.CreateChallenge
                      , GEval.OptionsParser
                      , GEval.BLEU
+                     , GEval.PrecisionRecall
   build-depends:       base >= 4.7 && < 5
                      , cond
                      , conduit
@@ -33,6 +34,7 @@ library
                      , split
                      , text
                      , unix
+                     , fgl
   default-language:    Haskell2010
 
 executable geval
@@ -42,6 +44,7 @@ executable geval
   build-depends:       base
                      , geval
                      , optparse-applicative
+                     , fgl
   default-language:    Haskell2010
 
 test-suite geval-test
diff --git a/src/GEval/PrecisionRecall.hs b/src/GEval/PrecisionRecall.hs
index 2872c75..2302220 100644
--- a/src/GEval/PrecisionRecall.hs
+++ b/src/GEval/PrecisionRecall.hs
@@ -1,46 +1,86 @@
-module GEval.PrecisionRecall
+{-# LANGUAGE PartialTypeSignatures #-}
+
+module GEval.PrecisionRecall(fMeasure, f1Measure, f2Measure, precision, recall,
+                             fMeasureOnCounts, f1MeasureOnCounts, f2MeasureOnCounts,
+                             precisionAndRecall, precisionAndRecallFromCounts)
        where
 
+import Data.Graph.Inductive
+import Data.Graph.Inductive.Query.MaxFlow
+
+import Debug.Trace
+
 f2Measure :: (a -> b -> Bool) -> [a] -> [b] -> Double
 f2Measure = fMeasure 2.0
 
+f1Measure :: (a -> b -> Bool) -> [a] -> [b] -> Double
+f1Measure = fMeasure 1.0
+
 fMeasure :: Double -> (a -> b -> Bool) -> [a] -> [b] -> Double
 fMeasure beta matchingFun expected got =
-  (1 + betaSquared) * (p + r) / (betaSquared * p + 2)
+  (1 + betaSquared) * p * r `safeDoubleDiv` (betaSquared * p + r)
   where betaSquared = beta ^ 2
         (p, r) = precisionAndRecall matchingFun expected got
 
+f2MeasureOnCounts :: (Int, Int, Int) -> Double
+f2MeasureOnCounts = fMeasureOnCounts 2.0
+
+f1MeasureOnCounts :: (Int, Int, Int) -> Double
+f1MeasureOnCounts = fMeasureOnCounts 1.0
+
+fMeasureOnCounts :: Double -> (Int, Int, Int) -> Double
+fMeasureOnCounts beta (tp, nbExpected, nbGot) =
+  (1 + betaSquared) * p * r `safeDoubleDiv` (betaSquared * p + r)
+  where betaSquared = beta ^ 2
+        (p, r) = precisionAndRecallFromCounts (tp, nbExpected, nbGot)
+
 precisionAndRecall :: (a -> b -> Bool) -> [a] -> [b] -> (Double, Double)
-precisionAndRecall matchFun expected got = precisionAndRecallFromCounts cs
-  where cs = counts matchFun expected got
+precisionAndRecall matchFun expected got
+  = precisionAndRecallFromCounts (tp, length expected, length got)
+  where tp = maxMatch matchFun expected got
 
 precisionAndRecallFromCounts :: (Int, Int, Int) -> (Double, Double)
-precisionAndRecallFromCounts (fp, fn, tp) = (tp / (tp+fp+epsilon), tp / (tp+fn+epsilon))
-  where epsilon = 0.000001
+precisionAndRecallFromCounts (tp, nbExpected, nbGot) =
+  (tp `safeDiv` nbGot, tp `safeDiv` nbExpected)
+
+-- division with possible zero
+safeDiv :: Int -> Int -> Double
+safeDiv _ 0 = 1.0
+safeDiv x y = (fromIntegral x) / (fromIntegral y)
+
+safeDoubleDiv :: Double -> Double -> Double
+safeDoubleDiv _ 0.0 = 0.0
+safeDoubleDiv x y = x / y
+
+
 
 precision :: (a -> b -> Bool) -> [a] -> [b] -> Double
-precision = first . precisionAndRecall
+precision matchFun expected got = fst $ precisionAndRecall matchFun expected got
 
 recall :: (a -> b -> Bool) -> [a] -> [b] -> Double
-recall = second . precisionAndRecall
+recall matchFun expected got = snd $ precisionAndRecall matchFun expected got
 
-counts :: (a -> b -> Bool) -> [a] -> [b] -> (Int, Int, Int)
-counts matchFun expected got = f
+-- counting maximum match with maximum bipartite matching
+-- (we build an auxiliary graph and do a max-flow on this)
+maxMatch :: (a -> b -> Bool) -> [a] -> [b] -> Int
+maxMatch matchFun expected got = mf
+  where (b, e, g) = buildGraph matchFun expected got
+        mf = maxFlow g (fst b) (fst e)
 
 
-countStep :: ([a], (Int, Int)) -> b -> ([a], (Int, Int))
-countStep (expectedLeft, fp, tp) itemGot =
-  (rest, case firstMatched of
-           Just _ -> (fp, tp+1)
-           Nothing -> (fp+1, tp))
-  where (rest, firstMatched) = findFirst ((flip matchingFun) itemGot) expectedLeft
-
-findFirst :: (a -> Bool) -> [a] -> ([a], Maybe a)
-findFirst _ [] -> ([], Nothing)
-findFirst condition (x:xs) = if (condition x)
-                               then
-                                 (xs, Just x)
-                               else
-                                 (x:xs', m)
-                             where
-                               (xs',m) = findFirst condition xs
+buildGraph :: (a -> b -> Bool) -> [a] -> [b] -> (LNode Int, LNode Int, Gr Int Int)
+buildGraph matchFun expected got = (b, e, g)
+  where ((b, e), (_, g)) = buildGraph' matchFun expected got
+        buildGraph' matchFun expected got =
+          run empty $
+            do b <- insMapNodeM 0
+               e <- insMapNodeM 1
+               mapM insMapNodeM [2..1+(length expected)+(length got)]
+               insMapEdgesM $ map (\n -> (0, n, 1)) expectedIxs
+               insMapEdgesM $ map (\m -> (m, 1, 1)) gotIxs
+               insMapEdgesM $ map (\(n,m) -> (n, m, 1))
+                            $ filter (\(n, m) -> matchFun (expected !! (n-2)) (got !! (m-2-(length expected))))
+                                     [(x,y) | x <- expectedIxs, y <- gotIxs]
+               return (b,e)
+          where expectedIxs = [2..1+(length expected)]
+                gotIxs = [2+(length expected)..1+(length expected)+(length got)]
diff --git a/test/Spec.hs b/test/Spec.hs
index b549a05..a1324ca 100644
--- a/test/Spec.hs
+++ b/test/Spec.hs
@@ -3,6 +3,7 @@ import Test.Hspec
 import GEval.Core
 import GEval.OptionsParser
 import GEval.BLEU
+import GEval.PrecisionRecall
 import Options.Applicative
 import qualified Test.HUnit as HU
 
@@ -47,7 +48,26 @@ main = hspec $ do
       runGEvalTest "unexpected-data" `shouldThrow` (== UnexpectedData "input does not start with a digit")
     it "unwanted data is handled" $
       runGEvalTest "unwanted-data" `shouldThrow` (== UnexpectedData "number expected")
+  describe "precision and recall" $ do
+    it "null test" $ do
+      precision neverMatch ['a', 'b', 'c'] [0, 1, 2, 3, 4, 5] `shouldBeAlmost` 0.0
+      recall neverMatch ['a', 'b', 'c'] [0, 1, 2, 3, 4, 5] `shouldBeAlmost` 0.0
+      f1Measure neverMatch ['a', 'b', 'c'] [0, 1, 2, 3, 4, 5] `shouldBeAlmost` 0.0
+    it "basic test" $ do
+      precision testMatchFun ['a', 'b', 'c'] [0, 1, 2, 3, 4, 5] `shouldBeAlmost` 0.3333333333333333
+      recall testMatchFun ['a', 'b', 'c'] [0, 1, 2, 3, 4, 5] `shouldBeAlmost` 0.66666666666666666
+      f1Measure testMatchFun ['a', 'b', 'c'] [0, 1, 2, 3, 4, 5] `shouldBeAlmost` 0.444444444444444
 
+neverMatch :: Char -> Int -> Bool
+neverMatch _ _ = False
+
+testMatchFun :: Char -> Int -> Bool
+testMatchFun 'a' 1 = True
+testMatchFun 'a' 2 = True
+testMatchFun 'a' 3 = True
+testMatchFun 'b' 1 = True
+testMatchFun 'c' 1 = True
+testMatchFun _ _ = False
 
 extractVal :: (Either (ParserResult GEvalOptions) (Maybe MetricValue)) -> IO MetricValue
 extractVal (Right (Just val)) = return val
@@ -72,10 +92,12 @@ instance AEq Double where
   x =~ y = abs ( x - y ) < (1.0e-4 :: Double)
 
 (@=~?) :: (Show a, AEq a) => a -> a -> HU.Assertion
-(@=~?) expected actual = expected =~ actual HU.@? assertionMsg
+(@=~?) actual expected = expected =~ actual HU.@? assertionMsg
     where
       assertionMsg = "Expected : " ++ show expected ++
                      "\nActual : " ++ show actual
 
+shouldBeAlmost got expected = got @=~? expected
+
 shouldReturnAlmost :: (AEq a, Show a, Eq a) => IO a -> a -> Expectation
 shouldReturnAlmost action expected = action >>= (@=~? expected)