From 6f428d64968c94485fec5b70faefcf56b60e2b7a Mon Sep 17 00:00:00 2001 From: Filip Gralinski Date: Sat, 25 Mar 2017 22:11:23 +0100 Subject: [PATCH] add NMI --- src/GEval/ClusteringMetrics.hs | 33 ++++++++++++++++++++++++++++++++- src/GEval/Common.hs | 9 +++++++++ test/Spec.hs | 33 ++++++++++++++++++++++++++++----- 3 files changed, 69 insertions(+), 6 deletions(-) diff --git a/src/GEval/ClusteringMetrics.hs b/src/GEval/ClusteringMetrics.hs index 9dd2fa0..1e7b886 100644 --- a/src/GEval/ClusteringMetrics.hs +++ b/src/GEval/ClusteringMetrics.hs @@ -1,5 +1,8 @@ module GEval.ClusteringMetrics - (purity, purityFromConfusionMap, updateConfusionMap) + (purity, purityFromConfusionMap, updateConfusionMap, + normalizedMutualInformation, + normalizedMutualInformationFromConfusionMatrix, + updateConfusionMatrix) where import GEval.Common @@ -26,3 +29,31 @@ updateConfusionMap :: (Hashable a, Eq a, Hashable b, Eq b) => M.HashMap b (M.Has updateConfusionMap h (e, g) = M.insertWith updateSubHash g (unitHash e) h where unitHash k = M.singleton k 1 updateSubHash uh sh = M.unionWith (+) uh sh + + +normalizedMutualInformation :: (Hashable a, Eq a, Hashable b, Eq b) => [(a, b)] -> Double +normalizedMutualInformation pL = normalizedMutualInformationFromConfusionMatrix cM + where cM = confusionMatrix pL + + +normalizedMutualInformationFromConfusionMatrix :: (Hashable a, Eq a, Hashable b, Eq b) => M.HashMap (a, b) Int -> Double +normalizedMutualInformationFromConfusionMatrix cM = 2.0 * mutualInformation / (classEntropy + clusterEntropy) + where mutualInformation = sum $ map pairMutualInformation $ M.toList cM + pairMutualInformation ((klass, cluster), count) = + (count /. total) * (log2 ((total /. (classDistribution M.! klass)) * (count /. (clusterDistribution M.! cluster)))) + total = sum $ map snd $ M.toList cM + + classEntropy = entropyWithTotalGiven total $ map snd $ M.toList classDistribution + clusterEntropy = entropyWithTotalGiven total $ map snd $ M.toList clusterDistribution + + classDistribution = getDistribution fst cM + clusterDistribution = getDistribution snd cM + + getDistribution fun cM = M.foldlWithKey' (\m kv count -> M.insertWith (+) (fun kv) count m) M.empty cM + + +confusionMatrix :: (Hashable a, Eq a, Hashable b, Eq b) => [(a, b)] -> M.HashMap (a, b) Int +confusionMatrix = foldl' updateConfusionMatrix M.empty + +updateConfusionMatrix :: (Hashable a, Eq a, Hashable b, Eq b) => M.HashMap (a, b) Int -> (a, b) -> M.HashMap (a, b) Int +updateConfusionMatrix m p = M.insertWith (+) p 1 m diff --git a/src/GEval/Common.hs b/src/GEval/Common.hs index 6b0134a..35b2b6f 100644 --- a/src/GEval/Common.hs +++ b/src/GEval/Common.hs @@ -8,3 +8,12 @@ x /. y = (fromIntegral x) / (fromIntegral y) safeDoubleDiv :: Double -> Double -> Double safeDoubleDiv _ 0.0 = 0.0 safeDoubleDiv x y = x / y + +log2 :: Double -> Double +log2 x = (log x) / (log 2.0) + +entropyWithTotalGiven total distribution = - (sum $ map (entropyCount total) distribution) + +entropyCount :: Int -> Int -> Double +entropyCount total count = prob * (log2 prob) + where prob = count /. total diff --git a/test/Spec.hs b/test/Spec.hs index 5ac7abd..884a366 100644 --- a/test/Spec.hs +++ b/test/Spec.hs @@ -13,6 +13,24 @@ import Options.Applicative import Data.Text import qualified Test.HUnit as HU +informationRetrievalBookExample :: [(String, Int)] +informationRetrievalBookExample = [("o", 2), ("o", 2), ("d", 2), ("x", 3), ("d", 3), + ("x", 1), ("o", 1), ("x", 1), ( "x", 1), ("x", 1), ("x", 1), + ("x", 2), ("o", 2), ("o", 2), + ("x", 3), ("d", 3), ("d", 3)] + +perfectClustering :: [(Int, Char)] +perfectClustering = [(0, 'a'), (2, 'b'), (3, 'c'), (2, 'b'), (2, 'b'), (1, 'd'), (0, 'a')] + +stupidClusteringOneBigCluster :: [(Int, Int)] +stupidClusteringOneBigCluster = [(0, 2), (2, 2), (1, 2), (2, 2), (0, 2), (0, 2), (0, 2), (0, 2), (1, 2), (1, 2)] + +stupidClusteringManySmallClusters :: [(Int, Int)] +stupidClusteringManySmallClusters = [(0, 0), (2, 1), (1, 2), (2, 3), (0, 4), (0, 5), (0, 6), (0, 7), (1, 8), (1, 9)] + + + + main :: IO () main = hspec $ do describe "root mean square error" $ do @@ -53,11 +71,16 @@ main = hspec $ do precisionCount [["foo", "baz"], ["bar"], ["baz", "xyz"]] ["foo", "bar", "foo"] `shouldBe` 2 describe "purity (in flat clustering)" $ do it "the example from Information Retrieval Book" $ do - purity [("o", 2) :: (String, Int), ("o", 2), ("d", 2), ("x", 3), ("d", 3), - ("x", 1), ("o", 1), ("x", 1), ( "x", 1), ("x", 1), ("x", 1), - ("x", 2), ("o", 2), ("o", 2), - ("x", 3), ("d", 3), ("d", 3)] `shouldBeAlmost` 0.70588 - + purity informationRetrievalBookExample `shouldBeAlmost` 0.70588 + describe "NMI (in flat clustering)" $ do + it "the example from Information Retrieval Book" $ do + normalizedMutualInformation informationRetrievalBookExample `shouldBeAlmost` 0.36456 + it "perfect clustering" $ do + normalizedMutualInformation perfectClustering `shouldBeAlmost` 1.0 + it "stupid clustering with one big cluster" $ do + normalizedMutualInformation stupidClusteringOneBigCluster `shouldBeAlmost` 0.0 + it "stupid clustering with many small clusters" $ do + normalizedMutualInformation stupidClusteringManySmallClusters `shouldBeAlmost` 0.61799 describe "reading options" $ do it "can get the metric" $ do extractMetric "bleu-complex" `shouldReturn` (Just BLEU)