From 06fd09334965a4ae2d6c05817265fa8f93535299 Mon Sep 17 00:00:00 2001 From: Filip Gralinski Date: Tue, 15 May 2018 08:07:47 +0200 Subject: [PATCH] probs can be given for LogLossHashed --- src/GEval/LogLossHashed.hs | 42 +++++++++++++++++-- test/Spec.hs | 4 ++ .../test-A/out.tsv | 3 ++ .../config.txt | 1 + .../test-A/expected.tsv | 3 ++ .../test-A/out.tsv | 5 +++ .../log-loss-hashed-probs/config.txt | 1 + .../log-loss-hashed-probs/test-A/expected.tsv | 5 +++ 8 files changed, 60 insertions(+), 4 deletions(-) create mode 100644 test/log-loss-hashed-probs-normalized/log-loss-hashed-probs-normalized-solution/test-A/out.tsv create mode 100644 test/log-loss-hashed-probs-normalized/log-loss-hashed-probs-normalized/config.txt create mode 100644 test/log-loss-hashed-probs-normalized/log-loss-hashed-probs-normalized/test-A/expected.tsv create mode 100644 test/log-loss-hashed-probs/log-loss-hashed-probs-solution/test-A/out.tsv create mode 100644 test/log-loss-hashed-probs/log-loss-hashed-probs/config.txt create mode 100644 test/log-loss-hashed-probs/log-loss-hashed-probs/test-A/expected.tsv diff --git a/src/GEval/LogLossHashed.hs b/src/GEval/LogLossHashed.hs index 3e092ab..bed3c7b 100644 --- a/src/GEval/LogLossHashed.hs +++ b/src/GEval/LogLossHashed.hs @@ -36,24 +36,33 @@ parseDistribution nbOfBits seed distroSpec = -- a direct list of 2^nbOfBits log probs else parseDistributionFromLogProbList nbOfBits distroSpec +isProbTotalIncorrect :: Double -> Bool +isProbTotalIncorrect probTotal = probTotal > 1.0 || probTotal < (1.0 - epsilon) + where epsilon = 0.00000001 + normalizeDistribution :: HashedDistribution -> HashedDistribution normalizeDistribution distro = -- we do softmax (if needed) - if probSum > 1.0 || probSum < (1.0 - epsilon) + if isProbTotalIncorrect probSum then normalized else distro where probSum = V.foldl' (\s l -> (s + exp l)) 0.0 distro normalized = V.map (\v -> log ((exp v) / probSum)) distro - epsilon = 0.00000001 type DistroMonad s = ReaderT (VM.MVector s Double) (ST s) data WordSpec = AnyWord | SpecificWord T.Text + deriving (Eq, Show) + +isAnyWord AnyWord = True +isAnyWord _ = False + data WordSpecWithLogProb = WordSpecWithLogProb WordSpec Double parseDistributionFromWordList :: Word32 -> Word32 -> T.Text -> Either String HashedDistribution -parseDistributionFromWordList nbOfBits seed distroSpec = (parseDistributionFromWordList' nbOfBits seed) =<< ( - processEithers $ map getWordSpecWithLogProb $ T.splitOn " " distroSpec) +parseDistributionFromWordList nbOfBits seed distroSpec = (parseDistributionFromWordList' nbOfBits seed) =<< + lookForProbs =<< + (processEithers $ map getWordSpecWithLogProb $ T.splitOn " " distroSpec) getWordSpecWithLogProb :: T.Text -> Either String WordSpecWithLogProb getWordSpecWithLogProb t = @@ -77,6 +86,31 @@ parseDistributionFromWordList' nbOfBits seed specs = runST $ do frozen <- V.freeze emp return $ Right frozen +lookForProbs :: [WordSpecWithLogProb] -> Either String [WordSpecWithLogProb] +lookForProbs specs + | areProbs specs = Right $ toLogProbs $ normalizeProbs specs + | otherwise = Right $ specs + +areProbs :: [WordSpecWithLogProb] -> Bool +areProbs specs = all isProb specs && any isPositiveProb specs + where isProb (WordSpecWithLogProb _ v) = v >= 0.0 && v <= 1.0 + isPositiveProb (WordSpecWithLogProb _ p) = p > 0.0 && p <= 1.0 + +toLogProbs :: [WordSpecWithLogProb] -> [WordSpecWithLogProb] +toLogProbs = map (\(WordSpecWithLogProb w p) -> (WordSpecWithLogProb w (log p))) + +normalizeProbs :: [WordSpecWithLogProb] -> [WordSpecWithLogProb] +normalizeProbs specs = if isProbTotalIncorrect probTotal + then + if probTotal > 1.0 || any (\(WordSpecWithLogProb w _) -> isAnyWord w) specs + then + map (\(WordSpecWithLogProb w p) -> WordSpecWithLogProb w (p / probTotal)) specs + else + ((WordSpecWithLogProb AnyWord (1-probTotal)):specs) + else + specs + where probTotal = sum $ map (\(WordSpecWithLogProb _ p) -> p) specs + addSpecs :: Word32 -> Word32 -> [WordSpecWithLogProb] -> DistroMonad s () addSpecs nbOfBits seed = mapM_ (updateDistro nbOfBits seed) diff --git a/test/Spec.hs b/test/Spec.hs index 9be9045..27b6653 100644 --- a/test/Spec.hs +++ b/test/Spec.hs @@ -95,6 +95,10 @@ main = hspec $ do runGEvalTest "log-loss-hashed-simple" `shouldReturnAlmost` 2.398479083333333 it "example with unnormalized values" $ do runGEvalTest "log-loss-hashed-not-normalized" `shouldReturnAlmost` 1.0468455186722887 + it "with probs instead of log probs" $ do + runGEvalTest "log-loss-hashed-probs" `shouldReturnAlmost` 4.11631293099392 + it "with probs instead of log probs (with normalization)" $ do + runGEvalTest "log-loss-hashed-probs-normalized" `shouldReturnAlmost` 1.55537749098853 describe "reading options" $ do it "can get the metric" $ do extractMetric "bleu-complex" `shouldReturn` (Just BLEU) diff --git a/test/log-loss-hashed-probs-normalized/log-loss-hashed-probs-normalized-solution/test-A/out.tsv b/test/log-loss-hashed-probs-normalized/log-loss-hashed-probs-normalized-solution/test-A/out.tsv new file mode 100644 index 0000000..d1dbc6f --- /dev/null +++ b/test/log-loss-hashed-probs-normalized/log-loss-hashed-probs-normalized-solution/test-A/out.tsv @@ -0,0 +1,3 @@ +1:0.5 2:0.6 3:1.0 4:1.0 5:0.9 +1:0.3 2:0.2 3:0.3 +1:0.2 :0.6 diff --git a/test/log-loss-hashed-probs-normalized/log-loss-hashed-probs-normalized/config.txt b/test/log-loss-hashed-probs-normalized/log-loss-hashed-probs-normalized/config.txt new file mode 100644 index 0000000..8121298 --- /dev/null +++ b/test/log-loss-hashed-probs-normalized/log-loss-hashed-probs-normalized/config.txt @@ -0,0 +1 @@ +--metric LogLossHashed10 diff --git a/test/log-loss-hashed-probs-normalized/log-loss-hashed-probs-normalized/test-A/expected.tsv b/test/log-loss-hashed-probs-normalized/log-loss-hashed-probs-normalized/test-A/expected.tsv new file mode 100644 index 0000000..39a8593 --- /dev/null +++ b/test/log-loss-hashed-probs-normalized/log-loss-hashed-probs-normalized/test-A/expected.tsv @@ -0,0 +1,3 @@ +1 +3 +1 diff --git a/test/log-loss-hashed-probs/log-loss-hashed-probs-solution/test-A/out.tsv b/test/log-loss-hashed-probs/log-loss-hashed-probs-solution/test-A/out.tsv new file mode 100644 index 0000000..2658337 --- /dev/null +++ b/test/log-loss-hashed-probs/log-loss-hashed-probs-solution/test-A/out.tsv @@ -0,0 +1,5 @@ +A:0.6 B:0.4 +C:0.2 A:0.1 +A:0.4 C:0.4 +D:1.0 +C:0.4 B:0.5 :0.1 diff --git a/test/log-loss-hashed-probs/log-loss-hashed-probs/config.txt b/test/log-loss-hashed-probs/log-loss-hashed-probs/config.txt new file mode 100644 index 0000000..8121298 --- /dev/null +++ b/test/log-loss-hashed-probs/log-loss-hashed-probs/config.txt @@ -0,0 +1 @@ +--metric LogLossHashed10 diff --git a/test/log-loss-hashed-probs/log-loss-hashed-probs/test-A/expected.tsv b/test/log-loss-hashed-probs/log-loss-hashed-probs/test-A/expected.tsv new file mode 100644 index 0000000..b8ccfd3 --- /dev/null +++ b/test/log-loss-hashed-probs/log-loss-hashed-probs/test-A/expected.tsv @@ -0,0 +1,5 @@ +A +A +B +D +A