probs can be given for LogLossHashed
This commit is contained in:
parent
cea084c789
commit
06fd093349
@ -36,24 +36,33 @@ parseDistribution nbOfBits seed distroSpec =
|
|||||||
-- a direct list of 2^nbOfBits log probs
|
-- a direct list of 2^nbOfBits log probs
|
||||||
else parseDistributionFromLogProbList nbOfBits distroSpec
|
else parseDistributionFromLogProbList nbOfBits distroSpec
|
||||||
|
|
||||||
|
-- | Decide whether a total probability mass should be treated as invalid.
--
-- The total is flagged when it is strictly greater than 1.0, or when it
-- falls below 1.0 by more than a small tolerance (1e-8).  Note that the
-- tolerance is applied on the lower side only.
isProbTotalIncorrect :: Double -> Bool
isProbTotalIncorrect probTotal
  | probTotal > 1.0 = True
  | otherwise = probTotal < lowestAccepted
  where
    -- smallest total that is still accepted as "sums to one"
    lowestAccepted = 1.0 - 0.00000001
|
||||||
|
|
||||||
normalizeDistribution :: HashedDistribution -> HashedDistribution
|
normalizeDistribution :: HashedDistribution -> HashedDistribution
|
||||||
normalizeDistribution distro =
|
normalizeDistribution distro =
|
||||||
-- we do softmax (if needed)
|
-- we do softmax (if needed)
|
||||||
if probSum > 1.0 || probSum < (1.0 - epsilon)
|
if isProbTotalIncorrect probSum
|
||||||
then normalized
|
then normalized
|
||||||
else distro
|
else distro
|
||||||
where probSum = V.foldl' (\s l -> (s + exp l)) 0.0 distro
|
where probSum = V.foldl' (\s l -> (s + exp l)) 0.0 distro
|
||||||
normalized = V.map (\v -> log ((exp v) / probSum)) distro
|
normalized = V.map (\v -> log ((exp v) / probSum)) distro
|
||||||
epsilon = 0.00000001
|
|
||||||
|
|
||||||
type DistroMonad s = ReaderT (VM.MVector s Double) (ST s)
|
type DistroMonad s = ReaderT (VM.MVector s Double) (ST s)
|
||||||
|
|
||||||
data WordSpec = AnyWord | SpecificWord T.Text
|
data WordSpec = AnyWord | SpecificWord T.Text
|
||||||
|
deriving (Eq, Show)
|
||||||
|
|
||||||
|
-- | Check whether a word specification is the 'AnyWord' constructor.
--
-- Returns 'True' for 'AnyWord' and 'False' for every other 'WordSpec' value.
isAnyWord :: WordSpec -> Bool
isAnyWord AnyWord = True
isAnyWord _ = False
|
||||||
|
|
||||||
data WordSpecWithLogProb = WordSpecWithLogProb WordSpec Double
|
data WordSpecWithLogProb = WordSpecWithLogProb WordSpec Double
|
||||||
|
|
||||||
parseDistributionFromWordList :: Word32 -> Word32 -> T.Text -> Either String HashedDistribution
|
parseDistributionFromWordList :: Word32 -> Word32 -> T.Text -> Either String HashedDistribution
|
||||||
parseDistributionFromWordList nbOfBits seed distroSpec = (parseDistributionFromWordList' nbOfBits seed) =<< (
|
parseDistributionFromWordList nbOfBits seed distroSpec = (parseDistributionFromWordList' nbOfBits seed) =<<
|
||||||
processEithers $ map getWordSpecWithLogProb $ T.splitOn " " distroSpec)
|
lookForProbs =<<
|
||||||
|
(processEithers $ map getWordSpecWithLogProb $ T.splitOn " " distroSpec)
|
||||||
|
|
||||||
getWordSpecWithLogProb :: T.Text -> Either String WordSpecWithLogProb
|
getWordSpecWithLogProb :: T.Text -> Either String WordSpecWithLogProb
|
||||||
getWordSpecWithLogProb t =
|
getWordSpecWithLogProb t =
|
||||||
@ -77,6 +86,31 @@ parseDistributionFromWordList' nbOfBits seed specs = runST $ do
|
|||||||
frozen <- V.freeze emp
|
frozen <- V.freeze emp
|
||||||
return $ Right frozen
|
return $ Right frozen
|
||||||
|
|
||||||
|
-- | Detect whether the given specs carry plain probabilities and, if so,
-- convert them to log probabilities.
--
-- When 'areProbs' holds for the list, the values are first passed through
-- 'normalizeProbs' and then through 'toLogProbs'.  Otherwise the list is
-- returned untouched.  The result is always a 'Right'.
lookForProbs :: [WordSpecWithLogProb] -> Either String [WordSpecWithLogProb]
lookForProbs specs =
  if areProbs specs
    then Right (toLogProbs (normalizeProbs specs))
    else Right specs
|
||||||
|
|
||||||
|
-- | Heuristic telling whether the values attached to the specs look like
-- probabilities.
--
-- Holds when every value lies in the closed interval [0, 1] and at least
-- one value is strictly greater than zero.
areProbs :: [WordSpecWithLogProb] -> Bool
areProbs specs = all inUnitInterval values && any strictlyPositive values
  where
    -- the numeric payload of every spec, in the original order
    values = [v | WordSpecWithLogProb _ v <- specs]
    inUnitInterval v = v >= 0.0 && v <= 1.0
    strictlyPositive v = v > 0.0 && v <= 1.0
|
||||||
|
|
||||||
|
-- | Replace the value carried by each spec with its natural logarithm,
-- keeping the word part and the order of the list unchanged.
toLogProbs :: [WordSpecWithLogProb] -> [WordSpecWithLogProb]
toLogProbs specs =
  [ WordSpecWithLogProb word (log prob)
  | WordSpecWithLogProb word prob <- specs
  ]
|
||||||
|
|
||||||
|
-- | Adjust a list of probability specs so that their values sum to one.
--
-- * If the total is acceptable according to 'isProbTotalIncorrect', the
--   list is returned unchanged.
-- * If the total exceeds 1.0, or the list already contains an 'AnyWord'
--   entry, every value is divided by the total.
-- * Otherwise an 'AnyWord' entry carrying the missing mass (1 - total) is
--   prepended to the list.
normalizeProbs :: [WordSpecWithLogProb] -> [WordSpecWithLogProb]
normalizeProbs specs
  | not (isProbTotalIncorrect total) = specs
  | total > 1.0 || mentionsAnyWord = map rescale specs
  | otherwise = WordSpecWithLogProb AnyWord (1 - total) : specs
  where
    -- sum of all values attached to the specs
    total = sum [p | WordSpecWithLogProb _ p <- specs]
    mentionsAnyWord = any (\(WordSpecWithLogProb w _) -> isAnyWord w) specs
    rescale (WordSpecWithLogProb w p) = WordSpecWithLogProb w (p / total)
|
||||||
|
|
||||||
addSpecs :: Word32 -> Word32 -> [WordSpecWithLogProb] -> DistroMonad s ()
|
addSpecs :: Word32 -> Word32 -> [WordSpecWithLogProb] -> DistroMonad s ()
|
||||||
addSpecs nbOfBits seed = mapM_ (updateDistro nbOfBits seed)
|
addSpecs nbOfBits seed = mapM_ (updateDistro nbOfBits seed)
|
||||||
|
|
||||||
|
@ -95,6 +95,10 @@ main = hspec $ do
|
|||||||
runGEvalTest "log-loss-hashed-simple" `shouldReturnAlmost` 2.398479083333333
|
runGEvalTest "log-loss-hashed-simple" `shouldReturnAlmost` 2.398479083333333
|
||||||
it "example with unnormalized values" $ do
|
it "example with unnormalized values" $ do
|
||||||
runGEvalTest "log-loss-hashed-not-normalized" `shouldReturnAlmost` 1.0468455186722887
|
runGEvalTest "log-loss-hashed-not-normalized" `shouldReturnAlmost` 1.0468455186722887
|
||||||
|
it "with probs instead of log probs" $ do
|
||||||
|
runGEvalTest "log-loss-hashed-probs" `shouldReturnAlmost` 4.11631293099392
|
||||||
|
it "with probs instead of log probs (with normalization)" $ do
|
||||||
|
runGEvalTest "log-loss-hashed-probs-normalized" `shouldReturnAlmost` 1.55537749098853
|
||||||
describe "reading options" $ do
|
describe "reading options" $ do
|
||||||
it "can get the metric" $ do
|
it "can get the metric" $ do
|
||||||
extractMetric "bleu-complex" `shouldReturn` (Just BLEU)
|
extractMetric "bleu-complex" `shouldReturn` (Just BLEU)
|
||||||
|
@ -0,0 +1,3 @@
|
|||||||
|
1:0.5 2:0.6 3:1.0 4:1.0 5:0.9
|
||||||
|
1:0.3 2:0.2 3:0.3
|
||||||
|
1:0.2 :0.6
|
|
@ -0,0 +1 @@
|
|||||||
|
--metric LogLossHashed10
|
@ -0,0 +1,3 @@
|
|||||||
|
1
|
||||||
|
3
|
||||||
|
1
|
|
@ -0,0 +1,5 @@
|
|||||||
|
A:0.6 B:0.4
|
||||||
|
C:0.2 A:0.1
|
||||||
|
A:0.4 C:0.4
|
||||||
|
D:1.0
|
||||||
|
C:0.4 B:0.5 :0.1
|
|
@ -0,0 +1 @@
|
|||||||
|
--metric LogLossHashed10
|
@ -0,0 +1,5 @@
|
|||||||
|
A
|
||||||
|
A
|
||||||
|
B
|
||||||
|
D
|
||||||
|
A
|
|
Loading…
Reference in New Issue
Block a user