From b058cd0095d442849f459515c0bfb7741bf1df6a Mon Sep 17 00:00:00 2001
From: Filip Gralinski
Date: Mon, 3 Apr 2017 10:07:58 +0200
Subject: [PATCH] implement softmax in LogLossHashed

Every parsed distribution is passed through normalizeDistribution, which
applies a log-softmax when the probabilities do not sum to 1. The
normalizer is computed with log-sum-exp so that large unnormalized scores
do not overflow.

---
 src/GEval/LogLossHashed.hs                    | 31 ++++++++++++---
 test/Spec.hs                                  |  2 ++
 .../test-A/out.tsv                            |  2 ++
 .../log-loss-hashed-not-normalized/config.txt |  1 +
 .../test-A/expected.tsv                       |  2 ++
 5 files changed, 33 insertions(+), 5 deletions(-)
 create mode 100644 test/log-loss-hashed-not-normalized/log-loss-hashed-not-normalized-solution/test-A/out.tsv
 create mode 100644 test/log-loss-hashed-not-normalized/log-loss-hashed-not-normalized/config.txt
 create mode 100644 test/log-loss-hashed-not-normalized/log-loss-hashed-not-normalized/test-A/expected.tsv

diff --git a/src/GEval/LogLossHashed.hs b/src/GEval/LogLossHashed.hs
index dd96573..3e092ab 100644
--- a/src/GEval/LogLossHashed.hs
+++ b/src/GEval/LogLossHashed.hs
@@ -30,11 +30,32 @@ murmurSeedValue = 0
 parseDistribution :: Word32 -> Word32 -> T.Text -> Either String HashedDistribution
 parseDistribution nbOfBits seed distroSpec =
   -- distribution spec is either...
-  if T.any (== ':') distroSpec
-  -- a space-separated list of words with their log probs
-  then parseDistributionFromWordList nbOfBits seed distroSpec
-  -- a direct list of 2^nbOfBits log probs
-  else parseDistributionFromLogProbList nbOfBits distroSpec
+  normalizeDistribution <$> if T.any (== ':') distroSpec
+                            -- a space-separated list of words with their log probs
+                            then parseDistributionFromWordList nbOfBits seed distroSpec
+                            -- a direct list of 2^nbOfBits log probs
+                            else parseDistributionFromLogProbList nbOfBits distroSpec
+
+-- | Rescale a vector of log-probabilities so that the corresponding
+-- probabilities sum to 1 (log-softmax). A vector whose probabilities already
+-- sum to 1 (within @epsilon@ from below) is returned unchanged.
+normalizeDistribution :: HashedDistribution -> HashedDistribution
+normalizeDistribution distro =
+  -- we do softmax (if needed)
+  if probSum > 1.0 || probSum < (1.0 - epsilon)
+  then normalized
+  else distro
+  where
+    -- Log-sum-exp: shift by the largest entry before exponentiating, so that
+    -- large unnormalized scores cannot overflow 'exp' and turn into NaN.
+    maxLogProb = V.foldl' max (-1.0 / 0.0) distro
+    shiftedSum = V.foldl' (\s l -> s + exp (l - maxLogProb)) 0.0 distro
+    logProbSum = maxLogProb + log shiftedSum
+    probSum = exp logProbSum
+    -- Subtracting in log space avoids @log (exp v / probSum)@, where @exp v@
+    -- underflows to 0 for very negative @v@.
+    normalized = V.map (subtract logProbSum) distro
+    epsilon = 0.00000001
 
 type DistroMonad s = ReaderT (VM.MVector s Double) (ST s)
 
diff --git a/test/Spec.hs b/test/Spec.hs
index 0b91148..b41722a 100644
--- a/test/Spec.hs
+++ b/test/Spec.hs
@@ -87,6 +87,8 @@ main = hspec $ do
   describe "LogLossHashed challenge" $ do
     it "simple example" $ do
       runGEvalTest "log-loss-hashed-simple" `shouldReturnAlmost` 2.398479083333333
+    it "example with unnormalized values" $ do
+      runGEvalTest "log-loss-hashed-not-normalized" `shouldReturnAlmost` 1.0468455186722887
   describe "reading options" $ do
     it "can get the metric" $ do
       extractMetric "bleu-complex" `shouldReturn` (Just BLEU)
diff --git a/test/log-loss-hashed-not-normalized/log-loss-hashed-not-normalized-solution/test-A/out.tsv b/test/log-loss-hashed-not-normalized/log-loss-hashed-not-normalized-solution/test-A/out.tsv
new file mode 100644
index 0000000..02c35e7
--- /dev/null
+++ b/test/log-loss-hashed-not-normalized/log-loss-hashed-not-normalized-solution/test-A/out.tsv
@@ -0,0 +1,2 @@
+tak:10 nie:8.9
+niebieski:0 żółty:1.5 czerwony:-0.5
diff --git a/test/log-loss-hashed-not-normalized/log-loss-hashed-not-normalized/config.txt b/test/log-loss-hashed-not-normalized/log-loss-hashed-not-normalized/config.txt
new file mode 100644
index 0000000..b5ace1e
--- /dev/null
+++ b/test/log-loss-hashed-not-normalized/log-loss-hashed-not-normalized/config.txt
@@ -0,0 +1 @@
+--metric LogLossHashed8
diff --git a/test/log-loss-hashed-not-normalized/log-loss-hashed-not-normalized/test-A/expected.tsv b/test/log-loss-hashed-not-normalized/log-loss-hashed-not-normalized/test-A/expected.tsv
new file mode 100644
index 0000000..70d9d14
--- /dev/null
+++ b/test/log-loss-hashed-not-normalized/log-loss-hashed-not-normalized/test-A/expected.tsv
@@ -0,0 +1,2 @@
+tak
+niebieski