accuracy can work on probablity distributions now
This commit is contained in:
parent
d370e375a0
commit
f9dfbc1466
@ -38,3 +38,7 @@ sepByWhitespaces parser = possibleWhitespace *> parser `sepBy` whitespace <* pos
|
|||||||
possibleWhitespace = many' (satisfy isHorizontalSpace)
|
possibleWhitespace = many' (satisfy isHorizontalSpace)
|
||||||
|
|
||||||
whitespace = many1 (satisfy isHorizontalSpace)
|
whitespace = many1 (satisfy isHorizontalSpace)
|
||||||
|
|
||||||
|
indicator :: Bool -> Double
|
||||||
|
indicator True = 1.0
|
||||||
|
indicator False = 0.0
|
||||||
|
@ -369,9 +369,18 @@ gevalCore' BLEU _ = gevalCoreWithoutInput (Right . Prelude.map Prelude.words . D
|
|||||||
| otherwise = exp (1.0 - (r /. c))
|
| otherwise = exp (1.0 - (r /. c))
|
||||||
|
|
||||||
gevalCore' Accuracy _ = gevalCoreWithoutInput (Right . strip) (Right . strip) hitOrMiss averageC id
|
gevalCore' Accuracy _ = gevalCoreWithoutInput (Right . strip) (Right . strip) hitOrMiss averageC id
|
||||||
where hitOrMiss (exp,got) = if (normalizeProbForAccuracy exp got) == exp then 1.0 else 0.0
|
where hitOrMiss (exp, got) =
|
||||||
-- if the expected value is 0 or 1 treat values between 0.0 and 1.0 as probabilities
|
-- first try to parse what we got as a probability distribution
|
||||||
-- for the positive outcome
|
-- (like the one used for Likelikehood/LogLossHashed metric)
|
||||||
|
case parseWordSpecs got of
|
||||||
|
Right wordSpecs -> if Prelude.null pairs
|
||||||
|
then 0.0
|
||||||
|
else indicator (exp == (snd $ Prelude.maximum pairs))
|
||||||
|
where pairs = catMaybes $ Prelude.map wordSpecToPair wordSpecs
|
||||||
|
Left _ -> indicator ((normalizeProbForAccuracy exp got) == exp)
|
||||||
|
-- if the expected value is 0 or 1 treat values
|
||||||
|
-- between 0.0 and 1.0 as probabilities
|
||||||
|
-- for the positive outcome
|
||||||
normalizeProbForAccuracy :: Text -> Text -> Text
|
normalizeProbForAccuracy :: Text -> Text -> Text
|
||||||
normalizeProbForAccuracy exp got
|
normalizeProbForAccuracy exp got
|
||||||
| exp == (pack "1") = case tryReadingAsFloat got of
|
| exp == (pack "1") = case tryReadingAsFloat got of
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
{-# LANGUAGE OverloadedStrings #-}
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
|
|
||||||
module GEval.LogLossHashed
|
module GEval.LogLossHashed
|
||||||
(HashedDistribution, parseDistribution, calculateLogLoss)
|
(HashedDistribution, parseDistribution, calculateLogLoss, parseWordSpecs, wordSpecToPair)
|
||||||
where
|
where
|
||||||
|
|
||||||
import qualified Data.Vector as V
|
import qualified Data.Vector as V
|
||||||
@ -59,11 +59,18 @@ isAnyWord _ = False
|
|||||||
|
|
||||||
data WordSpecWithLogProb = WordSpecWithLogProb WordSpec Double
|
data WordSpecWithLogProb = WordSpecWithLogProb WordSpec Double
|
||||||
|
|
||||||
|
wordSpecToPair :: WordSpecWithLogProb -> Maybe (Double, T.Text)
|
||||||
|
wordSpecToPair (WordSpecWithLogProb AnyWord _) = Nothing
|
||||||
|
wordSpecToPair (WordSpecWithLogProb (SpecificWord w) lp) = Just (lp, w)
|
||||||
|
|
||||||
parseDistributionFromWordList :: Word32 -> Word32 -> T.Text -> Either String HashedDistribution
|
parseDistributionFromWordList :: Word32 -> Word32 -> T.Text -> Either String HashedDistribution
|
||||||
parseDistributionFromWordList nbOfBits seed distroSpec = (parseDistributionFromWordList' nbOfBits seed) =<<
|
parseDistributionFromWordList nbOfBits seed distroSpec = (parseDistributionFromWordList' nbOfBits seed) =<<
|
||||||
normalizeLogProbs =<<
|
normalizeLogProbs =<<
|
||||||
lookForProbs =<<
|
lookForProbs =<<
|
||||||
(processEithers $ map getWordSpecWithLogProb $ T.splitOn " " distroSpec)
|
(parseWordSpecs distroSpec)
|
||||||
|
|
||||||
|
parseWordSpecs :: T.Text -> Either String [WordSpecWithLogProb]
|
||||||
|
parseWordSpecs distroSpec = processEithers $ map getWordSpecWithLogProb $ T.splitOn " " distroSpec
|
||||||
|
|
||||||
getWordSpecWithLogProb :: T.Text -> Either String WordSpecWithLogProb
|
getWordSpecWithLogProb :: T.Text -> Either String WordSpecWithLogProb
|
||||||
getWordSpecWithLogProb t =
|
getWordSpecWithLogProb t =
|
||||||
|
@ -293,7 +293,8 @@ main = hspec $ do
|
|||||||
describe "handle --alt-metric option" $ do
|
describe "handle --alt-metric option" $ do
|
||||||
it "accuracy instead of likelihood" $ do
|
it "accuracy instead of likelihood" $ do
|
||||||
runGEvalTestExtraOptions ["--alt-metric", "Accuracy"] "likelihood-simple" `shouldReturnAlmost` 0.75
|
runGEvalTestExtraOptions ["--alt-metric", "Accuracy"] "likelihood-simple" `shouldReturnAlmost` 0.75
|
||||||
|
it "accuracy instead of log loss" $ do
|
||||||
|
runGEvalTestExtraOptions ["--alt-metric", "Accuracy"] "log-loss-hashed-probs" `shouldReturnAlmost` 0.4
|
||||||
|
|
||||||
neverMatch :: Char -> Int -> Bool
|
neverMatch :: Char -> Int -> Bool
|
||||||
neverMatch _ _ = False
|
neverMatch _ _ = False
|
||||||
|
Loading…
Reference in New Issue
Block a user