diff --git a/src/GEval/Common.hs b/src/GEval/Common.hs index ad89c79..38828ab 100644 --- a/src/GEval/Common.hs +++ b/src/GEval/Common.hs @@ -38,3 +38,7 @@ sepByWhitespaces parser = possibleWhitespace *> parser `sepBy` whitespace <* pos possibleWhitespace = many' (satisfy isHorizontalSpace) whitespace = many1 (satisfy isHorizontalSpace) + +indicator :: Bool -> Double +indicator True = 1.0 +indicator False = 0.0 diff --git a/src/GEval/Core.hs b/src/GEval/Core.hs index 1353aab..180cc41 100644 --- a/src/GEval/Core.hs +++ b/src/GEval/Core.hs @@ -369,9 +369,18 @@ gevalCore' BLEU _ = gevalCoreWithoutInput (Right . Prelude.map Prelude.words . D | otherwise = exp (1.0 - (r /. c)) gevalCore' Accuracy _ = gevalCoreWithoutInput (Right . strip) (Right . strip) hitOrMiss averageC id - where hitOrMiss (exp,got) = if (normalizeProbForAccuracy exp got) == exp then 1.0 else 0.0 - -- if the expected value is 0 or 1 treat values between 0.0 and 1.0 as probabilities - -- for the positive outcome + where hitOrMiss (exp, got) = + -- first try to parse what we got as a probability distribution + -- (like the one used for Likelihood/LogLossHashed metric) + case parseWordSpecs got of + Right wordSpecs -> if Prelude.null pairs + then 0.0 + else indicator (exp == (snd $ Prelude.maximum pairs)) + where pairs = catMaybes $ Prelude.map wordSpecToPair wordSpecs + Left _ -> indicator ((normalizeProbForAccuracy exp got) == exp) + -- if the expected value is 0 or 1 treat values + -- between 0.0 and 1.0 as probabilities + -- for the positive outcome normalizeProbForAccuracy :: Text -> Text -> Text normalizeProbForAccuracy exp got | exp == (pack "1") = case tryReadingAsFloat got of diff --git a/src/GEval/LogLossHashed.hs b/src/GEval/LogLossHashed.hs index a4388ee..5b3dfcc 100644 --- a/src/GEval/LogLossHashed.hs +++ b/src/GEval/LogLossHashed.hs @@ -1,7 +1,7 @@ {-# LANGUAGE OverloadedStrings #-} module GEval.LogLossHashed - (HashedDistribution, parseDistribution, calculateLogLoss) + 
(HashedDistribution, parseDistribution, calculateLogLoss, parseWordSpecs, wordSpecToPair) where import qualified Data.Vector as V @@ -59,11 +59,18 @@ isAnyWord _ = False data WordSpecWithLogProb = WordSpecWithLogProb WordSpec Double +wordSpecToPair :: WordSpecWithLogProb -> Maybe (Double, T.Text) +wordSpecToPair (WordSpecWithLogProb AnyWord _) = Nothing +wordSpecToPair (WordSpecWithLogProb (SpecificWord w) lp) = Just (lp, w) + parseDistributionFromWordList :: Word32 -> Word32 -> T.Text -> Either String HashedDistribution parseDistributionFromWordList nbOfBits seed distroSpec = (parseDistributionFromWordList' nbOfBits seed) =<< normalizeLogProbs =<< lookForProbs =<< - (processEithers $ map getWordSpecWithLogProb $ T.splitOn " " distroSpec) + (parseWordSpecs distroSpec) + +parseWordSpecs :: T.Text -> Either String [WordSpecWithLogProb] +parseWordSpecs distroSpec = processEithers $ map getWordSpecWithLogProb $ T.splitOn " " distroSpec getWordSpecWithLogProb :: T.Text -> Either String WordSpecWithLogProb getWordSpecWithLogProb t = diff --git a/test/Spec.hs b/test/Spec.hs index 0474ae2..90039bb 100644 --- a/test/Spec.hs +++ b/test/Spec.hs @@ -293,7 +293,8 @@ main = hspec $ do describe "handle --alt-metric option" $ do it "accuracy instead of likelihood" $ do runGEvalTestExtraOptions ["--alt-metric", "Accuracy"] "likelihood-simple" `shouldReturnAlmost` 0.75 - + it "accuracy instead of log loss" $ do + runGEvalTestExtraOptions ["--alt-metric", "Accuracy"] "log-loss-hashed-probs" `shouldReturnAlmost` 0.4 neverMatch :: Char -> Int -> Bool neverMatch _ _ = False