log probs
This commit is contained in:
parent
82e794ae3c
commit
b01f9439b7
@ -125,8 +125,12 @@ must be given:
|
|||||||
where *logprobi* is the logarithm of the probability for *wordi* and
|
where *logprobi* is the logarithm of the probability for *wordi* and
|
||||||
*logprob0* is the logarithm of the probability mass for all the other
|
*logprob0* is the logarithm of the probability mass for all the other
|
||||||
words (it will be spread between all 1024 fingerprint values). If the
|
words (it will be spread between all 1024 fingerprint values). If the
|
||||||
respective probabilities do not sum up to 1, they will be normalised with
|
respective probabilities do not sum up to 1:
|
||||||
softmax.
|
|
||||||
|
* if the sum is larger than 0.0 and smaller than 1.0, and no logprob0
|
||||||
|
is given, the log of the remaining probability mass will be assigned to logprob0,
|
||||||
|
* otherwise they will be normalised with
|
||||||
|
softmax.
|
||||||
|
|
||||||
Note: the separator here is space, not TAB!
|
Note: the separator here is space, not TAB!
|
||||||
|
|
||||||
|
@ -61,6 +61,7 @@ data WordSpecWithLogProb = WordSpecWithLogProb WordSpec Double
|
|||||||
|
|
||||||
parseDistributionFromWordList :: Word32 -> Word32 -> T.Text -> Either String HashedDistribution
|
parseDistributionFromWordList :: Word32 -> Word32 -> T.Text -> Either String HashedDistribution
|
||||||
parseDistributionFromWordList nbOfBits seed distroSpec = (parseDistributionFromWordList' nbOfBits seed) =<<
|
parseDistributionFromWordList nbOfBits seed distroSpec = (parseDistributionFromWordList' nbOfBits seed) =<<
|
||||||
|
normalizeLogProbs =<<
|
||||||
lookForProbs =<<
|
lookForProbs =<<
|
||||||
(processEithers $ map getWordSpecWithLogProb $ T.splitOn " " distroSpec)
|
(processEithers $ map getWordSpecWithLogProb $ T.splitOn " " distroSpec)
|
||||||
|
|
||||||
@ -99,6 +100,17 @@ areProbs specs = all isProb specs && any isPositiveProb specs
|
|||||||
toLogProbs :: [WordSpecWithLogProb] -> [WordSpecWithLogProb]
|
toLogProbs :: [WordSpecWithLogProb] -> [WordSpecWithLogProb]
|
||||||
toLogProbs = map (\(WordSpecWithLogProb w p) -> (WordSpecWithLogProb w (log p)))
|
toLogProbs = map (\(WordSpecWithLogProb w p) -> (WordSpecWithLogProb w (log p)))
|
||||||
|
|
||||||
|
-- | Fill in the missing "any other word" entry for a list of log-prob specs.
--
-- @probTotal@ is the sum of @exp logp@ over all specs. When that total is
-- flagged by 'isProbTotalIncorrect', lies strictly between 0.0 and 1.0,
-- no spec satisfies 'isAnyWord', and every log prob is @<= 0@, an 'AnyWord'
-- spec carrying @log (1 - probTotal)@ is prepended to the list.
-- In every other case the specs are returned unchanged.
--
-- As written, the result is always a 'Right'; no 'Left' is produced here.
normalizeLogProbs :: [WordSpecWithLogProb] -> Either String [WordSpecWithLogProb]
|
||||||
|
normalizeLogProbs specs = if isProbTotalIncorrect probTotal
|
||||||
|
&& probTotal < 1.0 && probTotal > 0.0
|
||||||
|
&& not (any (\(WordSpecWithLogProb w _) -> isAnyWord w) specs)
|
||||||
|
&& all (\(WordSpecWithLogProb _ lp) -> lp <= 0) specs
|
||||||
|
then
|
||||||
|
Right ((WordSpecWithLogProb AnyWord (log (1-probTotal))):specs)
|
||||||
|
else
|
||||||
|
Right specs
|
||||||
|
where probTotal = sum $ map (\(WordSpecWithLogProb _ logp) -> exp logp) specs
|
||||||
|
|
||||||
normalizeProbs :: [WordSpecWithLogProb] -> [WordSpecWithLogProb]
|
normalizeProbs :: [WordSpecWithLogProb] -> [WordSpecWithLogProb]
|
||||||
normalizeProbs specs = if isProbTotalIncorrect probTotal
|
normalizeProbs specs = if isProbTotalIncorrect probTotal
|
||||||
then
|
then
|
||||||
|
@ -100,6 +100,9 @@ main = hspec $ do
|
|||||||
runGEvalTest "log-loss-hashed-probs" `shouldReturnAlmost` 4.11631293099392
|
runGEvalTest "log-loss-hashed-probs" `shouldReturnAlmost` 4.11631293099392
|
||||||
it "with probs instead of log probs (with normalization)" $ do
|
it "with probs instead of log probs (with normalization)" $ do
|
||||||
runGEvalTest "log-loss-hashed-probs-normalized" `shouldReturnAlmost` 1.55537749098853
|
runGEvalTest "log-loss-hashed-probs-normalized" `shouldReturnAlmost` 1.55537749098853
|
||||||
|
it "with log probs whose probs are summing up to less than 1.0" $ do
|
||||||
|
runGEvalTest "log-loss-hashed-normalization" `shouldReturnAlmost` 5.16395069238851
|
||||||
|
|
||||||
describe "reading options" $ do
|
describe "reading options" $ do
|
||||||
it "can get the metric" $ do
|
it "can get the metric" $ do
|
||||||
extractMetric "bleu-complex" `shouldReturn` (Just BLEU)
|
extractMetric "bleu-complex" `shouldReturn` (Just BLEU)
|
||||||
|
@ -0,0 +1,3 @@
|
|||||||
|
B:-1.20397280432594 A:-0.916290731874155
|
||||||
|
A:-2.3025850929940 C:-1.6094379124341
|
||||||
|
A:-2.3025850929940 C:-1.6094379124341 :-0.356674943938732
|
|
@ -0,0 +1 @@
|
|||||||
|
--metric LogLossHashed10
|
@ -0,0 +1,3 @@
|
|||||||
|
A
|
||||||
|
B
|
||||||
|
B
|
|
Loading…
Reference in New Issue
Block a user