implement softmax in LogLossHashed
This commit is contained in:
parent
e84a14d069
commit
b058cd0095
@ -30,12 +30,22 @@ murmurSeedValue = 0
|
|||||||
parseDistribution :: Word32 -> Word32 -> T.Text -> Either String HashedDistribution
|
parseDistribution :: Word32 -> Word32 -> T.Text -> Either String HashedDistribution
|
||||||
parseDistribution nbOfBits seed distroSpec =
|
parseDistribution nbOfBits seed distroSpec =
|
||||||
-- distribution spec is either...
|
-- distribution spec is either...
|
||||||
if T.any (== ':') distroSpec
|
normalizeDistribution <$> if T.any (== ':') distroSpec
|
||||||
-- a space-separated list of words with their log probs
|
-- a space-separated list of words with their log probs
|
||||||
then parseDistributionFromWordList nbOfBits seed distroSpec
|
then parseDistributionFromWordList nbOfBits seed distroSpec
|
||||||
-- a direct list of 2^nbOfBits log probs
|
-- a direct list of 2^nbOfBits log probs
|
||||||
else parseDistributionFromLogProbList nbOfBits distroSpec
|
else parseDistributionFromLogProbList nbOfBits distroSpec
|
||||||
|
|
||||||
|
-- | Normalize a hashed distribution of log probabilities so that the
-- corresponding probabilities sum to one, applying softmax only if needed.
--
-- The distribution is returned unchanged when its probabilities already sum
-- to a value within @[1 - epsilon, 1]@. Otherwise every log probability is
-- shifted by the logarithm of the total probability mass, i.e.
-- @l - log (sum (exp l_i))@, which is the log-softmax of the input.
--
-- The total mass is computed with the log-sum-exp trick (subtracting the
-- largest log value before exponentiating), so large inputs do not overflow
-- to Infinity/NaN and very negative inputs do not all underflow to zero.
--
-- An empty distribution, or one without a finite maximum (e.g. all entries
-- are @-Infinity@), cannot be rescaled meaningfully and is returned as is.
--
-- NOTE(review): assumes 'HashedDistribution' is a vector of 'Double' log
-- probabilities, as suggested by the surrounding parser comments — confirm.
normalizeDistribution :: HashedDistribution -> HashedDistribution
normalizeDistribution distro
  -- No finite maximum: covers the empty vector (the fold returns its
  -- -Infinity seed) and vectors dominated by an infinite entry.
  | isInfinite maxLogProb = distro
  -- we do softmax (if needed); the test is the log-domain equivalent of
  -- probSum > 1.0 || probSum < (1.0 - epsilon)
  | logProbSum > 0.0 || logProbSum < log (1.0 - epsilon) = normalized
  | otherwise = distro
  where
    -- Largest log probability, used as the shift that keeps 'exp' in range.
    maxLogProb = V.foldl' max (-1 / 0) distro
    -- log (sum (exp l_i)), evaluated stably around 'maxLogProb'.
    logProbSum =
      maxLogProb + log (V.foldl' (\s l -> s + exp (l - maxLogProb)) 0.0 distro)
    -- Equivalent to log (exp v / probSum), without leaving the log domain.
    normalized = V.map (subtract logProbSum) distro
    -- Tolerance below 1.0 for which the input still counts as normalized.
    epsilon = 0.00000001
|
||||||
|
|
||||||
type DistroMonad s = ReaderT (VM.MVector s Double) (ST s)
|
type DistroMonad s = ReaderT (VM.MVector s Double) (ST s)
|
||||||
|
|
||||||
data WordSpec = AnyWord | SpecificWord T.Text
|
data WordSpec = AnyWord | SpecificWord T.Text
|
||||||
|
@ -87,6 +87,8 @@ main = hspec $ do
|
|||||||
describe "LogLossHashed challenge" $ do
|
describe "LogLossHashed challenge" $ do
|
||||||
it "simple example" $ do
|
it "simple example" $ do
|
||||||
runGEvalTest "log-loss-hashed-simple" `shouldReturnAlmost` 2.398479083333333
|
runGEvalTest "log-loss-hashed-simple" `shouldReturnAlmost` 2.398479083333333
|
||||||
|
it "example with unnormalized values" $ do
|
||||||
|
runGEvalTest "log-loss-hashed-not-normalized" `shouldReturnAlmost` 1.0468455186722887
|
||||||
describe "reading options" $ do
|
describe "reading options" $ do
|
||||||
it "can get the metric" $ do
|
it "can get the metric" $ do
|
||||||
extractMetric "bleu-complex" `shouldReturn` (Just BLEU)
|
extractMetric "bleu-complex" `shouldReturn` (Just BLEU)
|
||||||
|
@ -0,0 +1,2 @@
|
|||||||
|
tak:10 nie:8.9
|
||||||
|
niebieski:0 żółty:1.5 czerwony:-0.5
|
|
@ -0,0 +1 @@
|
|||||||
|
--metric LogLossHashed8
|
@ -0,0 +1,2 @@
|
|||||||
|
tak
|
||||||
|
niebieski
|
|
Loading…
Reference in New Issue
Block a user