probs can be given for LogLossHashed

This commit is contained in:
Filip Gralinski 2018-05-15 08:07:47 +02:00
parent cea084c789
commit 06fd093349
8 changed files with 60 additions and 4 deletions

View File

@ -36,24 +36,33 @@ parseDistribution nbOfBits seed distroSpec =
-- a direct list of 2^nbOfBits log probs -- a direct list of 2^nbOfBits log probs
else parseDistributionFromLogProbList nbOfBits distroSpec else parseDistributionFromLogProbList nbOfBits distroSpec
-- | Decide whether a total probability mass needs correcting.
--
-- A total is rejected when it exceeds 1.0 at all, or when it falls short
-- of 1.0 by more than a small tolerance (1e-8). The check is deliberately
-- one-sided: there is no tolerance above 1.0.
isProbTotalIncorrect :: Double -> Bool
isProbTotalIncorrect probTotal
  | probTotal > 1.0 = True
  | otherwise = probTotal < lowerBound
  where
    -- smallest total that is still accepted as "sums to one"
    lowerBound = 1.0 - 0.00000001
-- | Renormalise a distribution of log probabilities when its total mass is
-- off.
--
-- The total is the sum of @exp@ of every entry. If 'isProbTotalIncorrect'
-- accepts that total, the distribution is returned untouched; otherwise a
-- softmax is applied, i.e. each entry is replaced by the log of its
-- share of the total.
normalizeDistribution :: HashedDistribution -> HashedDistribution
normalizeDistribution distro
  | isProbTotalIncorrect total = V.map renormalise distro
  | otherwise = distro
  where
    -- total probability mass, accumulated strictly
    total = V.foldl' (\acc logProb -> acc + exp logProb) 0.0 distro
    renormalise logProb = log (exp logProb / total)
-- | Monad used while a distribution is being built: an 'ST' computation
-- with 'ReaderT' access to a mutable vector of 'Double's.
type DistroMonad s =
  ReaderT (VM.MVector s Double) (ST s)
-- | The word an entry of a distribution spec refers to: either one
-- specific word, or the catch-all 'AnyWord' entry (which 'normalizeProbs'
-- uses to hold left-over probability mass).
data WordSpec
  = AnyWord
  | SpecificWord T.Text
  deriving (Eq, Show)
-- | 'True' exactly when the spec is the catch-all 'AnyWord' constructor;
-- 'False' for every other 'WordSpec'.
isAnyWord :: WordSpec -> Bool
isAnyWord AnyWord = True
isAnyWord _ = False
-- | A 'WordSpec' paired with its score.
--
-- The 'Double' is a log probability once parsing is finished. Straight
-- after reading the spec it may still be a plain probability, which
-- 'lookForProbs' detects and converts.
data WordSpecWithLogProb
  = WordSpecWithLogProb WordSpec Double
-- | Parse a distribution given as a space-separated list of word specs.
--
-- Each space-separated item is parsed with 'getWordSpecWithLogProb'; the
-- results are collected with 'processEithers', passed through
-- 'lookForProbs' (so plain probabilities are accepted as well as log
-- probabilities) and finally handed to 'parseDistributionFromWordList''.
-- A 'Left' from any stage is propagated unchanged.
parseDistributionFromWordList :: Word32 -> Word32 -> T.Text -> Either String HashedDistribution
parseDistributionFromWordList nbOfBits seed distroSpec =
  parsedSpecs >>= lookForProbs >>= parseDistributionFromWordList' nbOfBits seed
  where
    parsedSpecs = processEithers $ map getWordSpecWithLogProb $ T.splitOn " " distroSpec
getWordSpecWithLogProb :: T.Text -> Either String WordSpecWithLogProb getWordSpecWithLogProb :: T.Text -> Either String WordSpecWithLogProb
getWordSpecWithLogProb t = getWordSpecWithLogProb t =
@ -77,6 +86,31 @@ parseDistributionFromWordList' nbOfBits seed specs = runST $ do
frozen <- V.freeze emp frozen <- V.freeze emp
return $ Right frozen return $ Right frozen
-- | Accept specs whose values are plain probabilities rather than log
-- probabilities.
--
-- When 'areProbs' recognises the values as probabilities they are
-- normalised with 'normalizeProbs' and converted with 'toLogProbs';
-- otherwise the specs are passed through unchanged. The result is always
-- a 'Right'; the 'Either' type only lets the function sit in the parsing
-- pipeline.
lookForProbs :: [WordSpecWithLogProb] -> Either String [WordSpecWithLogProb]
lookForProbs specs =
  if areProbs specs
    then Right (toLogProbs (normalizeProbs specs))
    else Right specs
-- | Heuristic telling plain probabilities apart from log probabilities.
--
-- The specs count as probabilities when every value lies in @[0, 1]@ and
-- at least one value is strictly positive. An empty list therefore does
-- not count as probabilities.
areProbs :: [WordSpecWithLogProb] -> Bool
areProbs specs = all withinUnitInterval values && any strictlyPositive values
  where
    values = [v | WordSpecWithLogProb _ v <- specs]
    withinUnitInterval v = v >= 0.0 && v <= 1.0
    strictlyPositive v = v > 0.0 && v <= 1.0
-- | Replace the value of every spec by its natural logarithm, keeping the
-- word and the order of the list.
toLogProbs :: [WordSpecWithLogProb] -> [WordSpecWithLogProb]
toLogProbs specs =
  [WordSpecWithLogProb w (log p) | WordSpecWithLogProb w p <- specs]
-- | Make a list of plain probabilities sum to one.
--
-- * If 'isProbTotalIncorrect' accepts the total, the specs are returned
--   unchanged.
-- * If the total exceeds 1.0, or an 'AnyWord' entry is already present,
--   every value is divided by the total.
-- * Otherwise (total too small and no 'AnyWord' entry) an 'AnyWord' entry
--   carrying the missing mass is prepended.
normalizeProbs :: [WordSpecWithLogProb] -> [WordSpecWithLogProb]
normalizeProbs specs
  | not (isProbTotalIncorrect total) = specs
  | total > 1.0 || any mentionsAnyWord specs = map rescale specs
  | otherwise = WordSpecWithLogProb AnyWord (1 - total) : specs
  where
    total = sum [p | WordSpecWithLogProb _ p <- specs]
    mentionsAnyWord (WordSpecWithLogProb w _) = isAnyWord w
    rescale (WordSpecWithLogProb w p) = WordSpecWithLogProb w (p / total)
-- | Apply 'updateDistro' (with the given bit count and seed) to every spec
-- in order, discarding the individual results.
addSpecs :: Word32 -> Word32 -> [WordSpecWithLogProb] -> DistroMonad s ()
addSpecs nbOfBits seed specs = mapM_ (updateDistro nbOfBits seed) specs

View File

@ -95,6 +95,10 @@ main = hspec $ do
runGEvalTest "log-loss-hashed-simple" `shouldReturnAlmost` 2.398479083333333 runGEvalTest "log-loss-hashed-simple" `shouldReturnAlmost` 2.398479083333333
it "example with unnormalized values" $ do it "example with unnormalized values" $ do
runGEvalTest "log-loss-hashed-not-normalized" `shouldReturnAlmost` 1.0468455186722887 runGEvalTest "log-loss-hashed-not-normalized" `shouldReturnAlmost` 1.0468455186722887
it "with probs instead of log probs" $ do
runGEvalTest "log-loss-hashed-probs" `shouldReturnAlmost` 4.11631293099392
it "with probs instead of log probs (with normalization)" $ do
runGEvalTest "log-loss-hashed-probs-normalized" `shouldReturnAlmost` 1.55537749098853
describe "reading options" $ do describe "reading options" $ do
it "can get the metric" $ do it "can get the metric" $ do
extractMetric "bleu-complex" `shouldReturn` (Just BLEU) extractMetric "bleu-complex" `shouldReturn` (Just BLEU)

View File

@ -0,0 +1,3 @@
1:0.5 2:0.6 3:1.0 4:1.0 5:0.9
1:0.3 2:0.2 3:0.3
1:0.2 :0.6
1 1:0.5 2:0.6 3:1.0 4:1.0 5:0.9
2 1:0.3 2:0.2 3:0.3
3 1:0.2 :0.6

View File

@ -0,0 +1 @@
--metric LogLossHashed10

View File

@ -0,0 +1,3 @@
1
3
1
1 1
2 3
3 1

View File

@ -0,0 +1,5 @@
A:0.6 B:0.4
C:0.2 A:0.1
A:0.4 C:0.4
D:1.0
C:0.4 B:0.5 :0.1
1 A:0.6 B:0.4
2 C:0.2 A:0.1
3 A:0.4 C:0.4
4 D:1.0
5 C:0.4 B:0.5 :0.1

View File

@ -0,0 +1 @@
--metric LogLossHashed10

View File

@ -0,0 +1,5 @@
A
A
B
D
A
1 A
2 A
3 B
4 D
5 A