diff --git a/src/GEval/CreateChallenge.hs b/src/GEval/CreateChallenge.hs index 09a85e4..3e00993 100644 --- a/src/GEval/CreateChallenge.hs +++ b/src/GEval/CreateChallenge.hs @@ -125,8 +125,12 @@ must be given: where *logprobi* is the logarithm of the probability for *wordi* and *logprob0* is the logarithm of the probability mass for all the other words (it will be spread between all 1024 fingerprint values). If the -respective probabilities do not sum up to 1, they will be normalised with -softmax. +respective probabilities do not sum up to 1: + + * if the sum is larger than 0.0 and smaller than 1.0, and no logprob0 + is given, the log of the remaining probability mass will be assigned to logprob0, + * otherwise they will be normalised with +softmax. Note: the separator here is space, not TAB! diff --git a/src/GEval/LogLossHashed.hs b/src/GEval/LogLossHashed.hs index bed3c7b..a4388ee 100644 --- a/src/GEval/LogLossHashed.hs +++ b/src/GEval/LogLossHashed.hs @@ -61,6 +61,7 @@ data WordSpecWithLogProb = WordSpecWithLogProb WordSpec Double parseDistributionFromWordList :: Word32 -> Word32 -> T.Text -> Either String HashedDistribution parseDistributionFromWordList nbOfBits seed distroSpec = (parseDistributionFromWordList' nbOfBits seed) =<< + normalizeLogProbs =<< lookForProbs =<< (processEithers $ map getWordSpecWithLogProb $ T.splitOn " " distroSpec) @@ -99,6 +100,17 @@ areProbs specs = all isProb specs && any isPositiveProb specs toLogProbs :: [WordSpecWithLogProb] -> [WordSpecWithLogProb] toLogProbs = map (\(WordSpecWithLogProb w p) -> (WordSpecWithLogProb w (log p))) +normalizeLogProbs :: [WordSpecWithLogProb] -> Either String [WordSpecWithLogProb] +normalizeLogProbs specs = if isProbTotalIncorrect probTotal + && probTotal < 1.0 && probTotal > 0.0 + && not (any (\(WordSpecWithLogProb w _) -> isAnyWord w) specs) + && all (\(WordSpecWithLogProb _ lp) -> lp <= 0) specs + then + Right ((WordSpecWithLogProb AnyWord (log (1-probTotal))):specs) + else + Right specs + where 
probTotal = sum $ map (\(WordSpecWithLogProb _ logp) -> exp logp) specs + normalizeProbs :: [WordSpecWithLogProb] -> [WordSpecWithLogProb] normalizeProbs specs = if isProbTotalIncorrect probTotal then diff --git a/test/Spec.hs b/test/Spec.hs index 8cd27bf..831d6b5 100644 --- a/test/Spec.hs +++ b/test/Spec.hs @@ -100,6 +100,9 @@ main = hspec $ do runGEvalTest "log-loss-hashed-probs" `shouldReturnAlmost` 4.11631293099392 it "with probs instead of log probs (with normalization)" $ do runGEvalTest "log-loss-hashed-probs-normalized" `shouldReturnAlmost` 1.55537749098853 + it "with log probs whose probs are summing up to less than 1.0" $ do + runGEvalTest "log-loss-hashed-normalization" `shouldReturnAlmost` 5.16395069238851 + describe "reading options" $ do it "can get the metric" $ do extractMetric "bleu-complex" `shouldReturn` (Just BLEU) diff --git a/test/log-loss-hashed-normalization/log-loss-hashed-normalization-solution/test-A/out.tsv b/test/log-loss-hashed-normalization/log-loss-hashed-normalization-solution/test-A/out.tsv new file mode 100644 index 0000000..22e39a7 --- /dev/null +++ b/test/log-loss-hashed-normalization/log-loss-hashed-normalization-solution/test-A/out.tsv @@ -0,0 +1,3 @@ +B:-1.20397280432594 A:-0.916290731874155 +A:-2.3025850929940 C:-1.6094379124341 +A:-2.3025850929940 C:-1.6094379124341 :-0.356674943938732 diff --git a/test/log-loss-hashed-normalization/log-loss-hashed-normalization/config.txt b/test/log-loss-hashed-normalization/log-loss-hashed-normalization/config.txt new file mode 100644 index 0000000..8121298 --- /dev/null +++ b/test/log-loss-hashed-normalization/log-loss-hashed-normalization/config.txt @@ -0,0 +1 @@ +--metric LogLossHashed10 diff --git a/test/log-loss-hashed-normalization/log-loss-hashed-normalization/test-A/expected.tsv b/test/log-loss-hashed-normalization/log-loss-hashed-normalization/test-A/expected.tsv new file mode 100644 index 0000000..3042490 --- /dev/null +++ 
b/test/log-loss-hashed-normalization/log-loss-hashed-normalization/test-A/expected.tsv @@ -0,0 +1,3 @@ +A +B +B