189 lines
7.4 KiB
Haskell
189 lines
7.4 KiB
Haskell
{-# LANGUAGE OverloadedStrings #-}
|
|
|
|
module GEval.LogLossHashed
|
|
(HashedDistribution, parseDistribution, calculateLogLoss, parseWordSpecs, wordSpecToPair)
|
|
where
|
|
|
|
import qualified Data.Vector as V
|
|
import Control.Monad.Reader
|
|
import qualified Data.Vector.Mutable as VM
|
|
import Control.Monad.ST
|
|
|
|
import Data.Hash.Murmur
|
|
|
|
import Data.Either
|
|
|
|
import Data.Bits
|
|
import Data.Word
|
|
|
|
import qualified Data.Text as T
|
|
import qualified Data.Text.Encoding as TE
|
|
import qualified Data.Text.Read as TR
|
|
|
|
import GEval.Common (textToDouble)
|
|
|
|
-- | A distribution over hash buckets, stored as a vector of natural-log
-- probabilities: element @i@ holds the log probability shared by every
-- word whose fingerprint (see 'getFingerprint') is @i@.
type HashedDistribution = V.Vector Double
|
|
|
|
-- | A seed value of zero for the Murmur hash. Nothing in this module
-- references it; the parsing and lookup functions here take the seed as an
-- explicit argument instead.
murmurSeedValue :: Word32
murmurSeedValue = 0x00000000
|
|
|
|
-- | Parse a distribution specification and pass the result through
-- 'normalizeDistribution'.
--
-- Which of the two accepted formats is used depends only on whether the
-- spec contains a colon anywhere:
--
-- * with a colon: a space-separated list of @word:value@ entries
--   ('parseDistributionFromWordList');
-- * without: a space-separated list of exactly @2^nbOfBits@ log
--   probabilities ('parseDistributionFromLogProbList').
parseDistribution :: Word32 -> Word32 -> T.Text -> Either String HashedDistribution
parseDistribution nbOfBits seed distroSpec = fmap normalizeDistribution parsed
  where
    isWordList = T.any (== ':') distroSpec
    parsed
      | isWordList = parseDistributionFromWordList nbOfBits seed distroSpec
      | otherwise = parseDistributionFromLogProbList nbOfBits distroSpec
|
|
|
|
-- | Decide whether a total probability mass is "off". A total is incorrect
-- when it is above 1.0 or more than @1e-8@ below 1.0. A NaN total
-- compares false on both sides and therefore counts as correct.
isProbTotalIncorrect :: Double -> Bool
isProbTotalIncorrect probTotal = exceedsOne || fallsShort
  where
    tolerance = 1.0e-8
    exceedsOne = probTotal > 1.0
    fallsShort = probTotal < 1.0 - tolerance
|
|
|
|
-- | Renormalise a log-probability vector so that its probabilities sum to 1
-- (a softmax over the entries). The vector is returned unchanged when its
-- mass already sums to 1 within the tolerance of 'isProbTotalIncorrect'.
--
-- Renormalisation works in log space with the log-sum-exp trick: every
-- entry is shifted by the maximum before exponentiating. This keeps large
-- scores from overflowing @exp@ (which used to make the total Infinity and
-- the result NaN). It also keeps very negative entries from underflowing
-- to @-Infinity@.
--
-- Empty vectors, and vectors whose maximum is infinite, are returned
-- unchanged; in the latter case there is no finite mass to rescale.
normalizeDistribution :: HashedDistribution -> HashedDistribution
normalizeDistribution distro
  | V.null distro = distro
  | not (isProbTotalIncorrect probSum) = distro
  | isInfinite maxLogProb = distro
  | otherwise = V.map (subtract logProbSum) distro
  where
    -- Plain probability mass. It is used only to decide whether to
    -- renormalise, and is computed exactly as before so that
    -- already-normalised input is left untouched.
    probSum = V.foldl' (\s l -> s + exp l) 0.0 distro
    -- V.maximum is safe here: the empty case is handled by the first guard.
    maxLogProb = V.maximum distro
    -- log (sum (exp l)), evaluated stably by factoring out exp maxLogProb.
    logProbSum =
      maxLogProb + log (V.foldl' (\s l -> s + exp (l - maxLogProb)) 0.0 distro)
|
|
|
|
-- | Monad used while filling in a distribution from a word list: a reader
-- whose environment is a mutable vector of log probabilities, running in
-- 'ST'.
type DistroMonad s = ReaderT (VM.MVector s Double) (ST s)
|
|
|
|
-- | Left-hand side of a @word:value@ entry. It is either a specific word or
-- 'AnyWord', which is written with an empty word (e.g. @:-1.5@) and whose
-- mass is spread uniformly over all buckets.
data WordSpec
  = AnyWord
  | SpecificWord T.Text
  deriving (Show, Eq)
|
|
|
|
-- | True only for the catch-all 'AnyWord' spec.
isAnyWord :: WordSpec -> Bool
isAnyWord AnyWord = True
isAnyWord (SpecificWord _) = False
|
|
|
|
-- | A parsed @word:value@ entry. The 'Double' is meant to be a log
-- probability. Immediately after parsing it may still hold a plain
-- probability; 'lookForProbs' detects that case and converts it.
data WordSpecWithLogProb = WordSpecWithLogProb WordSpec Double
|
|
|
|
-- | Convert an entry into a @(logProb, word)@ pair. The catch-all
-- 'AnyWord' entry has no word, so it yields 'Nothing'.
wordSpecToPair :: WordSpecWithLogProb -> Maybe (Double, T.Text)
wordSpecToPair (WordSpecWithLogProb spec logProb) = case spec of
  SpecificWord word -> Just (logProb, word)
  AnyWord -> Nothing
|
|
|
|
-- | Build a distribution from a space-separated list of @word:value@
-- entries. The steps are:
--
-- 1. parse the entries;
-- 2. convert plain probabilities to log probabilities if needed
--    ('lookForProbs');
-- 3. possibly prepend an 'AnyWord' entry for missing mass
--    ('normalizeLogProbs');
-- 4. hash the entries into buckets.
--
-- The first failing step's error is returned.
parseDistributionFromWordList :: Word32 -> Word32 -> T.Text -> Either String HashedDistribution
parseDistributionFromWordList nbOfBits seed distroSpec = do
  parsedSpecs <- parseWordSpecs distroSpec
  logSpecs <- lookForProbs parsedSpecs
  completeSpecs <- normalizeLogProbs logSpecs
  parseDistributionFromWordList' nbOfBits seed completeSpecs
|
|
|
|
-- | Parse entries separated by single spaces. The first malformed entry's
-- error becomes the result.
--
-- The input is split on every single space, so consecutive or trailing
-- spaces produce empty entries. Those entries are rejected with
-- "colon expected".
parseWordSpecs :: T.Text -> Either String [WordSpecWithLogProb]
parseWordSpecs = processEithers . map getWordSpecWithLogProb . T.splitOn " "
|
|
|
|
-- | Parse a single @word:value@ entry.
--
-- The entry is split at its LAST colon, so a word may itself contain
-- colons; only spaces are forbidden. An entry with nothing before the
-- colon (@:value@) denotes 'AnyWord', meaning any word, including OOVs.
--
-- Errors:
--
-- * "colon expected" when the entry contains no colon at all;
-- * 'textToDouble''s error when the value part is not a number.
getWordSpecWithLogProb :: T.Text -> Either String WordSpecWithLogProb
getWordSpecWithLogProb t
  | T.null prefix = Left "colon expected"
  | otherwise = WordSpecWithLogProb (toWordSpec prefix) <$> textToDouble valuePart
  where
    -- prefix ends with the last colon; it is empty iff t has no colon
    (prefix, valuePart) = T.breakOnEnd ":" t
    toWordSpec ":" = AnyWord
    toWordSpec withColon = SpecificWord (T.dropEnd 1 withColon)
|
|
|
|
|
|
-- | Accumulate entries into a fresh vector of @2^nbOfBits@ buckets. Every
-- bucket starts at log probability -Infinity (no mass). Always returns
-- 'Right'.
parseDistributionFromWordList' :: Word32 -> Word32 -> [WordSpecWithLogProb] -> Either String HashedDistribution
parseDistributionFromWordList' nbOfBits seed specs = runST build
  where
    build :: ST s (Either String HashedDistribution)
    build = do
      buckets <- VM.replicate (hashDistributionSize nbOfBits) minusInfinity
      runReaderT (addSpecs nbOfBits seed specs) buckets
      Right <$> V.freeze buckets
|
|
|
|
-- | Detect entries whose values are plain probabilities (see 'areProbs').
-- Such entries are normalised ('normalizeProbs') and converted to log
-- probabilities. Otherwise the values are assumed to be log probabilities
-- already and are passed through. Never returns 'Left'.
lookForProbs :: [WordSpecWithLogProb] -> Either String [WordSpecWithLogProb]
lookForProbs specs =
  if areProbs specs
    then Right asLogProbs
    else Right specs
  where
    asLogProbs = toLogProbs (normalizeProbs specs)
|
|
|
|
-- | True when every value lies in @[0, 1]@ and at least one is strictly
-- positive. Such values look like plain probabilities rather than log
-- probabilities.
areProbs :: [WordSpecWithLogProb] -> Bool
areProbs specs = all inUnitInterval values && any isPositiveProb values
  where
    values = [v | WordSpecWithLogProb _ v <- specs]
    inUnitInterval v = v >= 0.0 && v <= 1.0
    isPositiveProb p = p > 0.0 && p <= 1.0
|
|
|
|
-- | Replace each entry's probability by its natural logarithm.
toLogProbs :: [WordSpecWithLogProb] -> [WordSpecWithLogProb]
toLogProbs specs = [WordSpecWithLogProb w (log p) | WordSpecWithLogProb w p <- specs]
|
|
|
|
-- | Complete log-probability entries whose total mass falls short of 1 by
-- prepending an 'AnyWord' entry that carries the missing mass. This
-- happens only when all of the following hold:
--
-- * the total lies strictly between 0 and @1 - 1e-8@;
-- * no 'AnyWord' entry is present;
-- * every value is @<= 0@.
--
-- Otherwise the entries are returned unchanged. Never returns 'Left'.
normalizeLogProbs :: [WordSpecWithLogProb] -> Either String [WordSpecWithLogProb]
normalizeLogProbs specs
  | needsRemainder = Right (WordSpecWithLogProb AnyWord (log (1 - probTotal)) : specs)
  | otherwise = Right specs
  where
    logProbs = [lp | WordSpecWithLogProb _ lp <- specs]
    probTotal = sum (map exp logProbs)
    hasAnyWord = any (\(WordSpecWithLogProb w _) -> isAnyWord w) specs
    needsRemainder =
      isProbTotalIncorrect probTotal
        && probTotal < 1.0
        && probTotal > 0.0
        && not hasAnyWord
        && all (<= 0) logProbs
|
|
|
|
-- | Make plain probabilities sum to 1. Totals already within tolerance
-- (see 'isProbTotalIncorrect') are left alone. Otherwise:
--
-- * if the total exceeds 1, or an 'AnyWord' entry already exists, every
--   value is divided by the total;
-- * else the missing mass is prepended as a new 'AnyWord' entry.
normalizeProbs :: [WordSpecWithLogProb] -> [WordSpecWithLogProb]
normalizeProbs specs
  | not (isProbTotalIncorrect probTotal) = specs
  | probTotal > 1.0 || hasAnyWord = map rescale specs
  | otherwise = WordSpecWithLogProb AnyWord (1 - probTotal) : specs
  where
    probTotal = sum [p | WordSpecWithLogProb _ p <- specs]
    hasAnyWord = any (\(WordSpecWithLogProb w _) -> isAnyWord w) specs
    rescale (WordSpecWithLogProb w p) = WordSpecWithLogProb w (p / probTotal)
|
|
|
|
-- | Add each entry's probability mass, in list order, into the
-- distribution held in the reader environment.
addSpecs :: Word32 -> Word32 -> [WordSpecWithLogProb] -> DistroMonad s ()
addSpecs nbOfBits seed specs = mapM_ addOne specs
  where
    addOne spec = updateDistro nbOfBits seed spec
|
|
|
|
-- | Add one entry's mass to the distribution; the log-space addition
-- itself is done by 'updateDistro''.
--
-- * A specific word adds its whole log probability to the bucket given by
--   its fingerprint.
-- * 'AnyWord' spreads its mass uniformly: @logProb - log hSize@ is added to
--   each of the @hSize@ buckets.
updateDistro :: Word32 -> Word32 -> WordSpecWithLogProb -> DistroMonad s ()
updateDistro nbOfBits seed (WordSpecWithLogProb (SpecificWord w) logProb) =
  updateDistro' logProb (getFingerprint nbOfBits seed w)
updateDistro nbOfBits _ (WordSpecWithLogProb AnyWord logProb) =
  mapM_ (updateDistro' perBucketLogProb) [0 .. hSize - 1]
  where
    hSize = hashDistributionSize nbOfBits
    -- p / hSize in log space; computed once rather than per bucket
    perBucketLogProb = logProb - log (fromIntegral hSize)
|
|
|
|
-- | Add @exp logDelta@ to the probability stored at bucket @ix@. The work is
-- done in log space: the stored value becomes
-- @log (exp current + exp logDelta)@.
--
-- The sum uses the log-sum-exp trick, factoring out the larger term. This
-- has two benefits:
--
-- * two very small log probabilities no longer underflow to @-Infinity@
--   when exponentiated;
-- * adding to an empty bucket (@-Infinity@) stores @logDelta@ exactly.
--
-- When the larger term is infinite it is returned as is; this covers
-- @-Infinity@ + @-Infinity@, i.e. zero + zero.
updateDistro' :: Double -> Int -> DistroMonad s ()
updateDistro' logDelta ix = do
  d <- ask
  currentLogProb <- lift $ VM.read d ix
  lift $ VM.write d ix (logAddExp currentLogProb logDelta)
  where
    logAddExp a b
      | isInfinite hi = hi
      | otherwise = hi + log (1 + exp (lo - hi))
      where
        hi = max a b
        lo = min a b
|
|
|
|
-- | Negative infinity, i.e. the log of probability zero. Every bucket
-- starts at this value when a distribution is built from a word list.
minusInfinity :: Double
minusInfinity = negate (1 / 0)
|
|
|
|
-- | Parse a direct list of log probabilities. Exactly @2^nbOfBits@ values
-- are required; any other count fails with
-- "unexpected number of log probs".
parseDistributionFromLogProbList :: Word32 -> T.Text -> Either String HashedDistribution
parseDistributionFromLogProbList nbOfBits distroSpec =
  parseDistributionFromLogProbList' distroSpec >>= checkSize
  where
    expectedSize = hashDistributionSize nbOfBits
    checkSize parsed
      | V.length parsed == expectedSize = Right parsed
      | otherwise = Left "unexpected number of log probs"
|
|
|
|
-- | Parse single-space-separated numbers into a vector. Fails with the
-- error for the first token 'textToDouble' rejects.
parseDistributionFromLogProbList' :: T.Text -> Either String HashedDistribution
parseDistributionFromLogProbList' distroSpec = fmap V.fromList parsedValues
  where
    parsedValues = processEithers [textToDouble token | token <- T.splitOn " " distroSpec]
|
|
|
|
|
|
-- | Collect all 'Right' values, or return the first 'Left' in list order.
--
-- This is exactly 'sequenceA' for 'Either'. Unlike the previous
-- 'partitionEithers'-based version, it stops at the first 'Left' instead of
-- partitioning the whole list.
processEithers :: [Either a b] -> Either a [b]
processEithers = sequenceA
|
|
|
|
-- | Look up the log probability that the distribution assigns to @term@'s
-- hash bucket. The value is returned as is, not negated.
--
-- Indexing is bounds-checked ('V.!'). It errors if the fingerprint is out
-- of range, e.g. when @distrib@ was built with a smaller @nbOfBits@.
calculateLogLoss :: Word32 -> Word32 -> T.Text -> HashedDistribution -> Double
calculateLogLoss nbOfBits seed term distrib = distrib V.! bucket
  where
    bucket = getFingerprint nbOfBits seed term
|
|
|
|
|
|
-- | Number of hash buckets, @2^nbOfBits@.
hashDistributionSize :: Word32 -> Int
hashDistributionSize nbOfBits = 1 `shiftL` fromIntegral nbOfBits
|
|
|
|
-- | Bucket index of a text: its hash (see 'getHash') modulo the number of
-- buckets. The result lies in @[0, 2^nbOfBits)@ as long as @2^nbOfBits@
-- fits in a positive 'Int'.
--
-- The modulo is taken in 'Int', not 'Word32'. Converting the bucket count
-- to 'Word32' would wrap it to 0 once @nbOfBits >= 32@, turning the
-- modulo into a divide-by-zero error. For @nbOfBits < 32@ the result is
-- the same as reducing in 'Word32'.
getFingerprint :: Word32 -> Word32 -> T.Text -> Int
getFingerprint nbOfBits seed text = fromIntegral h `mod` hashDistributionSize nbOfBits
  where
    h = getHash seed text
|
|
|
|
-- | 32-bit Murmur hash (via 'murmur3') of the UTF-8 encoding of a text,
-- computed with the given seed.
getHash :: Word32 -> T.Text -> Word32
getHash seed = murmur3 seed . TE.encodeUtf8
|