|
|
|
@ -36,24 +36,33 @@ parseDistribution nbOfBits seed distroSpec =
|
|
|
|
|
-- a direct list of 2^nbOfBits log probs
|
|
|
|
|
else parseDistributionFromLogProbList nbOfBits distroSpec
|
|
|
|
|
|
|
|
|
|
isProbTotalIncorrect :: Double -> Bool
|
|
|
|
|
isProbTotalIncorrect probTotal = probTotal > 1.0 || probTotal < (1.0 - epsilon)
|
|
|
|
|
where epsilon = 0.00000001
|
|
|
|
|
|
|
|
|
|
normalizeDistribution :: HashedDistribution -> HashedDistribution
|
|
|
|
|
normalizeDistribution distro =
|
|
|
|
|
-- we do softmax (if needed)
|
|
|
|
|
if probSum > 1.0 || probSum < (1.0 - epsilon)
|
|
|
|
|
if isProbTotalIncorrect probSum
|
|
|
|
|
then normalized
|
|
|
|
|
else distro
|
|
|
|
|
where probSum = V.foldl' (\s l -> (s + exp l)) 0.0 distro
|
|
|
|
|
normalized = V.map (\v -> log ((exp v) / probSum)) distro
|
|
|
|
|
epsilon = 0.00000001
|
|
|
|
|
|
|
|
|
|
type DistroMonad s = ReaderT (VM.MVector s Double) (ST s)
|
|
|
|
|
|
|
|
|
|
data WordSpec = AnyWord | SpecificWord T.Text
|
|
|
|
|
deriving (Eq, Show)
|
|
|
|
|
|
|
|
|
|
isAnyWord AnyWord = True
|
|
|
|
|
isAnyWord _ = False
|
|
|
|
|
|
|
|
|
|
data WordSpecWithLogProb = WordSpecWithLogProb WordSpec Double
|
|
|
|
|
|
|
|
|
|
parseDistributionFromWordList :: Word32 -> Word32 -> T.Text -> Either String HashedDistribution
|
|
|
|
|
parseDistributionFromWordList nbOfBits seed distroSpec = (parseDistributionFromWordList' nbOfBits seed) =<< (
|
|
|
|
|
processEithers $ map getWordSpecWithLogProb $ T.splitOn " " distroSpec)
|
|
|
|
|
parseDistributionFromWordList nbOfBits seed distroSpec = (parseDistributionFromWordList' nbOfBits seed) =<<
|
|
|
|
|
lookForProbs =<<
|
|
|
|
|
(processEithers $ map getWordSpecWithLogProb $ T.splitOn " " distroSpec)
|
|
|
|
|
|
|
|
|
|
getWordSpecWithLogProb :: T.Text -> Either String WordSpecWithLogProb
|
|
|
|
|
getWordSpecWithLogProb t =
|
|
|
|
@ -77,6 +86,31 @@ parseDistributionFromWordList' nbOfBits seed specs = runST $ do
|
|
|
|
|
frozen <- V.freeze emp
|
|
|
|
|
return $ Right frozen
|
|
|
|
|
|
|
|
|
|
lookForProbs :: [WordSpecWithLogProb] -> Either String [WordSpecWithLogProb]
|
|
|
|
|
lookForProbs specs
|
|
|
|
|
| areProbs specs = Right $ toLogProbs $ normalizeProbs specs
|
|
|
|
|
| otherwise = Right $ specs
|
|
|
|
|
|
|
|
|
|
areProbs :: [WordSpecWithLogProb] -> Bool
|
|
|
|
|
areProbs specs = all isProb specs && any isPositiveProb specs
|
|
|
|
|
where isProb (WordSpecWithLogProb _ v) = v >= 0.0 && v <= 1.0
|
|
|
|
|
isPositiveProb (WordSpecWithLogProb _ p) = p > 0.0 && p <= 1.0
|
|
|
|
|
|
|
|
|
|
toLogProbs :: [WordSpecWithLogProb] -> [WordSpecWithLogProb]
|
|
|
|
|
toLogProbs = map (\(WordSpecWithLogProb w p) -> (WordSpecWithLogProb w (log p)))
|
|
|
|
|
|
|
|
|
|
normalizeProbs :: [WordSpecWithLogProb] -> [WordSpecWithLogProb]
|
|
|
|
|
normalizeProbs specs = if isProbTotalIncorrect probTotal
|
|
|
|
|
then
|
|
|
|
|
if probTotal > 1.0 || any (\(WordSpecWithLogProb w _) -> isAnyWord w) specs
|
|
|
|
|
then
|
|
|
|
|
map (\(WordSpecWithLogProb w p) -> WordSpecWithLogProb w (p / probTotal)) specs
|
|
|
|
|
else
|
|
|
|
|
((WordSpecWithLogProb AnyWord (1-probTotal)):specs)
|
|
|
|
|
else
|
|
|
|
|
specs
|
|
|
|
|
where probTotal = sum $ map (\(WordSpecWithLogProb _ p) -> p) specs
|
|
|
|
|
|
|
|
|
|
addSpecs :: Word32 -> Word32 -> [WordSpecWithLogProb] -> DistroMonad s ()
|
|
|
|
|
addSpecs nbOfBits seed = mapM_ (updateDistro nbOfBits seed)
|
|
|
|
|
|
|
|
|
|