implement softmax in LogLossHashed

This commit is contained in:
Filip Gralinski 2017-04-03 10:07:58 +02:00 committed by Filip Gralinski
parent e84a14d069
commit b058cd0095
5 changed files with 22 additions and 5 deletions

View File

@ -30,12 +30,22 @@ murmurSeedValue = 0
parseDistribution :: Word32 -> Word32 -> T.Text -> Either String HashedDistribution parseDistribution :: Word32 -> Word32 -> T.Text -> Either String HashedDistribution
parseDistribution nbOfBits seed distroSpec = parseDistribution nbOfBits seed distroSpec =
-- distribution spec is either... -- distribution spec is either...
if T.any (== ':') distroSpec normalizeDistribution <$> if T.any (== ':') distroSpec
-- a space-separated list of words with their log probs -- a space-separated list of words with their log probs
then parseDistributionFromWordList nbOfBits seed distroSpec then parseDistributionFromWordList nbOfBits seed distroSpec
-- a direct list of 2^nbOfBits log probs -- a direct list of 2^nbOfBits log probs
else parseDistributionFromLogProbList nbOfBits distroSpec else parseDistributionFromLogProbList nbOfBits distroSpec
normalizeDistribution :: HashedDistribution -> HashedDistribution
normalizeDistribution distro =
-- we do softmax (if needed)
if probSum > 1.0 || probSum < (1.0 - epsilon)
then normalized
else distro
where probSum = V.foldl' (\s l -> (s + exp l)) 0.0 distro
normalized = V.map (\v -> log ((exp v) / probSum)) distro
epsilon = 0.00000001
type DistroMonad s = ReaderT (VM.MVector s Double) (ST s) type DistroMonad s = ReaderT (VM.MVector s Double) (ST s)
data WordSpec = AnyWord | SpecificWord T.Text data WordSpec = AnyWord | SpecificWord T.Text

View File

@ -87,6 +87,8 @@ main = hspec $ do
describe "LogLossHashed challenge" $ do describe "LogLossHashed challenge" $ do
it "simple example" $ do it "simple example" $ do
runGEvalTest "log-loss-hashed-simple" `shouldReturnAlmost` 2.398479083333333 runGEvalTest "log-loss-hashed-simple" `shouldReturnAlmost` 2.398479083333333
it "example with unnormalized values" $ do
runGEvalTest "log-loss-hashed-not-normalized" `shouldReturnAlmost` 1.0468455186722887
describe "reading options" $ do describe "reading options" $ do
it "can get the metric" $ do it "can get the metric" $ do
extractMetric "bleu-complex" `shouldReturn` (Just BLEU) extractMetric "bleu-complex" `shouldReturn` (Just BLEU)

View File

@ -0,0 +1,2 @@
tak:10 nie:8.9
niebieski:0 żółty:1.5 czerwony:-0.5
1 tak:10 nie:8.9
2 niebieski:0 żółty:1.5 czerwony:-0.5

View File

@ -0,0 +1 @@
--metric LogLossHashed8

View File

@ -0,0 +1,2 @@
tak
niebieski
1 tak
2 niebieski