start working on LogLossHashed
This commit is contained in:
parent
073d92a4e7
commit
0e9c44a5b5
@ -23,6 +23,7 @@ library
|
||||
, GEval.PrecisionRecall
|
||||
, GEval.ClusteringMetrics
|
||||
, GEval.Common
|
||||
, GEval.LogLossHashed
|
||||
build-depends: base >= 4.7 && < 5
|
||||
, cond
|
||||
, conduit
|
||||
@ -41,6 +42,9 @@ library
|
||||
, attoparsec
|
||||
, unordered-containers
|
||||
, hashable
|
||||
, murmur3
|
||||
, vector
|
||||
, mtl
|
||||
default-language: Haskell2010
|
||||
|
||||
executable geval
|
||||
|
@ -7,6 +7,8 @@ import Data.Text
|
||||
import Control.Applicative
|
||||
import Control.Exception
|
||||
|
||||
import GEval.Common
|
||||
|
||||
-- | Wrapper around 'Int' for page numbers, keeping them distinct from other integers.
newtype PageNumber = PageNumber Int
|
||||
deriving (Eq, Show)
|
||||
|
||||
@ -44,9 +46,6 @@ clippingParser = do
|
||||
-- | Parses a sequence of clipping specs; separation by whitespace and the
-- surrounding whitespace handling are delegated to 'sepByWhitespaces'.
lineClippingSpecsParser :: Parser [ClippingSpec]
|
||||
lineClippingSpecsParser = sepByWhitespaces clippingSpecParser
|
||||
|
||||
sepByWhitespaces :: Parser a -> Parser [a]
|
||||
sepByWhitespaces parser = possibleWhitespace *> parser `sepBy` whitespace <* possibleWhitespace <* endOfInput
|
||||
|
||||
clippingSpecParser :: Parser ClippingSpec
|
||||
clippingSpecParser = do
|
||||
pageNo <- PageNumber <$> decimal
|
||||
@ -56,10 +55,6 @@ clippingSpecParser = do
|
||||
margin <- decimal
|
||||
return $ ClippingSpec pageNo (smallerRectangle margin rectangle) (extendedRectangle margin rectangle)
|
||||
|
||||
possibleWhitespace = many' (satisfy isHorizontalSpace)
|
||||
|
||||
whitespace = many1 (satisfy isHorizontalSpace)
|
||||
|
||||
extendedRectangle :: Int -> Rectangle -> Rectangle
|
||||
extendedRectangle margin (Rectangle (Point x0 y0) (Point x1 y1)) =
|
||||
Rectangle (Point (x0 `nonNegativeDiff` margin) (y0 `nonNegativeDiff` margin))
|
||||
|
@ -1,6 +1,11 @@
|
||||
module GEval.Common
|
||||
where
|
||||
|
||||
import qualified Data.Text as T
|
||||
import Data.Text.Read as TR
|
||||
|
||||
import Data.Attoparsec.Text
|
||||
|
||||
-- | Divides two integral values, yielding a 'Double'.
--
-- A zero denominator is not an error here: the result is then @1.0@.
(/.) :: (Eq a, Integral a) => a -> a -> Double
numerator /. denominator
  | denominator == 0 = 1.0
  | otherwise = fromIntegral numerator / fromIntegral denominator
|
||||
@ -17,3 +22,19 @@ entropyWithTotalGiven total distribution = - (sum $ map (entropyCount total) dis
|
||||
-- | Computes @p * log2 p@ where @p@ is @count /. total@.
-- No sign change is applied here.
entropyCount :: Int -> Int -> Double
entropyCount total count =
  let prob = count /. total
  in prob * log2 prob
|
||||
|
||||
-- | Converts a text to a 'Double', requiring that the whole text is consumed.
--
-- Returns 'Left' with the message produced by 'TR.double' when the number
-- cannot be read, and 'Left' with a fixed message when anything is left over
-- after the number.
textToDouble :: T.Text -> Either String Double
textToDouble t = TR.double t >>= requireFullyConsumed
  where
    requireFullyConsumed (value, remainder)
      | T.null remainder = Right value
      | otherwise = Left "number text found after a number"
|
||||
|
||||
-- | Runs the given parser on items separated by horizontal whitespace.
--
-- Leading and trailing horizontal whitespace is skipped and the parser
-- succeeds only if the whole input is consumed ('endOfInput').
sepByWhitespaces :: Parser a -> Parser [a]
sepByWhitespaces parser = possibleWhitespace *> parser `sepBy` whitespace <* possibleWhitespace <* endOfInput

-- | Zero or more horizontal-space characters.
possibleWhitespace :: Parser String
possibleWhitespace = many' (satisfy isHorizontalSpace)

-- | One or more horizontal-space characters.
whitespace :: Parser String
whitespace = many1 (satisfy isHorizontalSpace)
|
||||
|
@ -42,12 +42,21 @@ import GEval.Common
|
||||
import GEval.ClippEU
|
||||
import GEval.PrecisionRecall
|
||||
import GEval.ClusteringMetrics
|
||||
import GEval.LogLossHashed
|
||||
|
||||
import qualified Data.HashMap.Strict as M
|
||||
|
||||
import Data.Word
|
||||
|
||||
type MetricValue = Double
|
||||
|
||||
data Metric = RMSE | MSE | BLEU | Accuracy | ClippEU | FMeasure Double | NMI
|
||||
-- | Default number of hash bits for the 'LogLossHashed' metric; it is the
-- value assumed when the metric name carries no explicit number, and the
-- value for which 'show' omits the number.
defaultLogLossHashedSize :: Word32
|
||||
defaultLogLossHashedSize = 12
|
||||
|
||||
-- | Seed passed both to the distribution parser and to 'calculateLogLoss'
-- when evaluating the 'LogLossHashed' metric.
defaultLogLossSeed :: Word32
|
||||
defaultLogLossSeed = 0
|
||||
|
||||
-- | Supported evaluation metrics. 'FMeasure' carries its beta parameter and
-- 'LogLossHashed' carries the number of hash bits.
data Metric = RMSE | MSE | BLEU | Accuracy | ClippEU | FMeasure Double | NMI | LogLossHashed Word32
|
||||
deriving (Eq)
|
||||
|
||||
instance Show Metric where
|
||||
@ -58,6 +67,12 @@ instance Show Metric where
|
||||
show ClippEU = "ClippEU"
|
||||
show (FMeasure beta) = "F" ++ (show beta)
|
||||
show NMI = "NMI"
|
||||
show (LogLossHashed nbOfBits) = "LogLossHashed" ++ (if
|
||||
nbOfBits == defaultLogLossHashedSize
|
||||
then
|
||||
""
|
||||
else
|
||||
(show nbOfBits))
|
||||
|
||||
instance Read Metric where
|
||||
readsPrec _ ('R':'M':'S':'E':theRest) = [(RMSE, theRest)]
|
||||
@ -69,6 +84,9 @@ instance Read Metric where
|
||||
readsPrec p ('F':theRest) = case readsPrec p theRest of
|
||||
[(beta, theRest)] -> [(FMeasure beta, theRest)]
|
||||
_ -> []
|
||||
readsPrec p ('L':'o':'g':'L':'o':'s':'s':'H':'a':'s':'h':'e':'d':theRest) = case readsPrec p theRest of
|
||||
[(nbOfBits, theRest)] -> [(LogLossHashed nbOfBits, theRest)]
|
||||
_ -> [(LogLossHashed defaultLogLossHashedSize, theRest)]
|
||||
|
||||
|
||||
data MetricOrdering = TheLowerTheBetter | TheHigherTheBetter
|
||||
@ -81,6 +99,7 @@ getMetricOrdering Accuracy = TheHigherTheBetter
|
||||
getMetricOrdering ClippEU = TheHigherTheBetter
|
||||
getMetricOrdering (FMeasure _) = TheHigherTheBetter
|
||||
getMetricOrdering NMI = TheHigherTheBetter
|
||||
getMetricOrdering (LogLossHashed _) = TheLowerTheBetter
|
||||
|
||||
-- NOTE(review): judging by the names, these are the default output directory
-- and the default test name; their use is not visible in this excerpt — confirm.
defaultOutDirectory = "."
|
||||
defaultTestName = "test-A"
|
||||
@ -229,6 +248,14 @@ gevalCore' ClippEU = gevalCore'' parseClippingSpecs parseClippings matchStep cli
|
||||
|
||||
gevalCore' NMI = gevalCore'' id id id (CC.foldl updateConfusionMatrix M.empty) normalizedMutualInformationFromConfusionMatrix
|
||||
|
||||
-- LogLossHashed: the first text of each pair is kept verbatim (it is the term
-- looked up), the second is parsed into a hashed distribution; the value found
-- by 'calculateLogLoss' for each pair is averaged and the average is negated.
gevalCore' (LogLossHashed nbOfBits) =
  gevalCore'' id parseDistro lookupLogProb averageC negate
  where
    parseDistro = parseDistributionWrapper nbOfBits defaultLogLossSeed
    lookupLogProb = uncurry (calculateLogLoss nbOfBits defaultLogLossSeed)
|
||||
|
||||
-- | Like 'parseDistribution', but throws 'UnexpectedData' (carrying the
-- parser's message) instead of returning 'Left'.
parseDistributionWrapper :: Word32 -> Word32 -> Text -> HashedDistribution
parseDistributionWrapper nbOfBits seed distroSpec =
  either (throw . UnexpectedData) id $ parseDistribution nbOfBits seed distroSpec
|
||||
|
||||
-- | Either an item ('Got') or an end marker ('Done').
-- NOTE(review): presumably used to signal exhaustion of an input source; usage not visible here — confirm.
data SourceItem a = Got a | Done
|
||||
|
||||
gevalCore'' :: (Text -> a) -> (Text -> b) -> ((a, b) -> c) -> (Sink c (ResourceT IO) d) -> (d -> Double) -> String -> String -> IO (MetricValue)
|
||||
|
125
src/GEval/LogLossHashed.hs
Normal file
125
src/GEval/LogLossHashed.hs
Normal file
@ -0,0 +1,125 @@
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
|
||||
module GEval.LogLossHashed
|
||||
(HashedDistribution, parseDistribution, calculateLogLoss)
|
||||
where
|
||||
|
||||
import qualified Data.Vector as V
|
||||
import Control.Monad.Reader
|
||||
import qualified Data.Vector.Mutable as VM
|
||||
import Control.Monad.ST
|
||||
|
||||
import Data.Hash.Murmur
|
||||
|
||||
import Data.Either
|
||||
|
||||
import Data.Bits
|
||||
import Data.Word
|
||||
|
||||
import qualified Data.Text as T
|
||||
import qualified Data.Text.Encoding as TE
|
||||
import qualified Data.Text.Read as TR
|
||||
|
||||
import GEval.Common (textToDouble)
|
||||
|
||||
-- | A distribution stored as a vector of log probabilities, indexed by the
-- hash fingerprint of a word (see 'getFingerprint').
type HashedDistribution = V.Vector Double
|
||||
|
||||
-- | Murmur seed constant, equal to 0.
-- NOTE(review): nothing in the visible part of this module references it (the seed is passed around
-- as a parameter instead) and it is not exported — confirm and consider removing.
murmurSeedValue :: Word32
|
||||
murmurSeedValue = 0
|
||||
|
||||
-- | Parses a distribution specification into a 'HashedDistribution'.
--
-- Two formats are recognised, told apart by the presence of a colon anywhere
-- in the specification.
parseDistribution :: Word32 -> Word32 -> T.Text -> Either String HashedDistribution
parseDistribution nbOfBits seed distroSpec
  -- with a colon: a space-separated list of words with their log probs
  | hasColon = parseDistributionFromWordList nbOfBits seed distroSpec
  -- without a colon: a direct list of 2^nbOfBits log probs
  | otherwise = parseDistributionFromLogProbList nbOfBits distroSpec
  where
    hasColon = T.any (== ':') distroSpec
|
||||
|
||||
-- | 'ST' computation with read access to the mutable vector of log probabilities being filled in.
type DistroMonad s = ReaderT (VM.MVector s Double) (ST s)
|
||||
|
||||
-- | The word part of a word-list entry: a concrete word, or 'AnyWord' for an entry
-- written with an empty word (e.g. @:-1.23456@).
data WordSpec = AnyWord | SpecificWord T.Text
|
||||
-- | A word spec paired with its log probability.
data WordSpecWithLogProb = WordSpecWithLogProb WordSpec Double
|
||||
|
||||
-- | Parses a space-separated list of @word:logprob@ entries and turns it into
-- a 'HashedDistribution' of @2^nbOfBits@ entries.
--
-- Returns 'Left' with the message of the first entry that cannot be parsed.
parseDistributionFromWordList :: Word32 -> Word32 -> T.Text -> Either String HashedDistribution
parseDistributionFromWordList nbOfBits seed distroSpec =
  mapM getWordSpecWithLogProb wordSpecs >>= parseDistributionFromWordList' nbOfBits seed
  where
    -- Leading, trailing or repeated spaces make splitOn yield empty
    -- fragments; they carry no entry, so drop them instead of failing
    -- with "colon expected".
    wordSpecs = filter (not . T.null) $ T.splitOn " " distroSpec
|
||||
|
||||
-- | Parses a single @word:logprob@ entry.
--
-- Fails with @"colon expected"@ when the entry has no colon and with the
-- message of 'textToDouble' when the part after the last colon is not a number.
getWordSpecWithLogProb :: T.Text -> Either String WordSpecWithLogProb
getWordSpecWithLogProb t
  | T.null wordPart = Left "colon expected"
  | otherwise = fmap (WordSpecWithLogProb (toWordSpec wordPart)) (textToDouble logProbPart)
  where
    -- The entry is split at the LAST colon, so a colon _can_ be a part of a
    -- word; the only character forbidden in a word is a space. Note that
    -- wordPart still ends with that colon.
    (wordPart, logProbPart) = T.breakOnEnd ":" t
    -- a specification like ":-1.23456" means we add probability for OOVs
    -- (or actually any word)
    toWordSpec ":" = AnyWord
    toWordSpec ws = SpecificWord $ T.dropEnd 1 ws
|
||||
|
||||
|
||||
-- | Builds a 'HashedDistribution' from already parsed entries: starts from a
-- vector of @2^nbOfBits@ minus-infinity log probs, adds every entry and
-- freezes the result. Always returns 'Right'.
parseDistributionFromWordList' :: Word32 -> Word32 -> [WordSpecWithLogProb] -> Either String HashedDistribution
parseDistributionFromWordList' nbOfBits seed specs = Right frozenDistro
  where
    frozenDistro = runST $ do
      distro <- VM.replicate (hashDistributionSize nbOfBits) minusInfinity
      runReaderT (addSpecs nbOfBits seed specs) distro
      V.freeze distro
|
||||
|
||||
-- | Adds all the given entries, in order, to the distribution being built.
addSpecs :: Word32 -> Word32 -> [WordSpecWithLogProb] -> DistroMonad s ()
addSpecs nbOfBits seed specs = mapM_ addSpec specs
  where
    addSpec = updateDistro nbOfBits seed
|
||||
|
||||
-- | Adds one entry to the distribution being built.
--
-- A specific word adds its probability to the bucket selected by the word's
-- fingerprint; 'AnyWord' spreads its probability uniformly over all buckets.
updateDistro :: Word32 -> Word32 -> WordSpecWithLogProb -> DistroMonad s ()
updateDistro nbOfBits seed (WordSpecWithLogProb (SpecificWord w) logProb) =
  updateDistro' logProb (getFingerprint nbOfBits seed w)
updateDistro nbOfBits _ (WordSpecWithLogProb AnyWord logProb) =
  -- spread probability uniformly
  mapM_ (updateDistro' logProbPerBucket) [0 .. hSize - 1]
  where
    hSize = hashDistributionSize nbOfBits
    -- dividing a probability by hSize is subtracting log hSize in log space
    logProbPerBucket = logProb - log (fromIntegral hSize)
|
||||
|
||||
-- | Adds the probability given as a log prob (@logDelta@) to the bucket at
-- index @ix@, i.e. replaces the stored log prob @c@ with
-- @log (exp c + exp logDelta)@.
updateDistro' :: Double -> Int -> DistroMonad s ()
updateDistro' logDelta ix = do
  d <- ask
  currentLogProb <- lift $ VM.read d ix
  lift $ VM.write d ix (logSumExp currentLogProb logDelta)
  where
    -- log (exp a + exp b), computed relative to the larger argument so that
    -- very negative log probs do not underflow to zero inside exp (which
    -- would silently turn them into minus infinity).
    logSumExp a b
      -- both operands minus infinity (or one plus infinity): the larger one
      -- is the answer, and it avoids the NaN from infinity - infinity
      | isInfinite m = m
      | otherwise = m + log (exp (a - m) + exp (b - m))
      where
        m = max a b
|
||||
|
||||
-- | Negative infinity, i.e. the log of probability zero.
minusInfinity :: Double
minusInfinity = negate (1 / 0)
|
||||
|
||||
-- | Parses a direct list of log probs and checks that it has exactly
-- @2^nbOfBits@ elements; otherwise fails with
-- @"unexpected number of log probs"@. Parse errors are passed through.
parseDistributionFromLogProbList :: Word32 -> T.Text -> Either String HashedDistribution
parseDistributionFromLogProbList nbOfBits distroSpec = do
  distro <- parseDistributionFromLogProbList' distroSpec
  if V.length distro == hashDistributionSize nbOfBits
    then Right distro
    else Left "unexpected number of log probs"
|
||||
|
||||
-- | Parses a space-separated list of numbers into a vector, without checking
-- its length. Returns 'Left' with the message of the first number that
-- cannot be parsed.
parseDistributionFromLogProbList' :: T.Text -> Either String HashedDistribution
parseDistributionFromLogProbList' distroSpec =
  V.fromList <$> mapM textToDouble logProbTexts
  where
    -- Leading, trailing or repeated spaces make splitOn yield empty
    -- fragments; skip them rather than failing to parse "" as a number.
    logProbTexts = filter (not . T.null) $ T.splitOn " " distroSpec
|
||||
|
||||
|
||||
-- | Collects all the 'Right' values, in order, unless there is a 'Left', in
-- which case the first 'Left' is returned.
--
-- This is exactly 'sequence' in the 'Either' monad.
processEithers :: [Either a b] -> Either a [b]
processEithers = sequence
|
||||
|
||||
-- | Looks up the value stored in the distribution for the bucket selected by
-- the fingerprint of the given term. The value is returned as is; no sign
-- change is applied here.
calculateLogLoss :: Word32 -> Word32 -> T.Text -> HashedDistribution -> Double
calculateLogLoss nbOfBits seed term distrib = distrib V.! bucket
  where
    bucket = getFingerprint nbOfBits seed term
|
||||
|
||||
|
||||
-- | Number of buckets of a hashed distribution: @2^nbOfBits@.
hashDistributionSize :: Word32 -> Int
hashDistributionSize nbOfBits = 1 `shift` bitCount
  where
    bitCount = fromIntegral nbOfBits
|
||||
|
||||
-- | Maps a text to its bucket index: the hash of the text reduced modulo the
-- number of buckets (@2^nbOfBits@).
--
-- The reduction is done in 'Int': narrowing the bucket count to 'Word32'
-- would turn @2^32@ into 0 and make @mod@ fail with a division by zero for
-- @nbOfBits == 32@.
getFingerprint :: Word32 -> Word32 -> T.Text -> Int
getFingerprint nbOfBits seed text = fromIntegral h `mod` hashDistributionSize nbOfBits
  where
    h = getHash seed text
|
||||
|
||||
-- | Murmur3 hash (with the given seed) of the UTF-8 encoding of a text.
getHash :: Word32 -> T.Text -> Word32
getHash seed = murmur3 seed . TE.encodeUtf8
|
@ -1,5 +1,5 @@
|
||||
flags: {}
|
||||
packages:
|
||||
- '.'
|
||||
extra-deps: [cond-0.4.1.1]
|
||||
extra-deps: [cond-0.4.1.1,murmur3-1.0.3]
|
||||
resolver: lts-6.26
|
||||
|
@ -84,6 +84,9 @@ main = hspec $ do
|
||||
describe "NMI challenge" $ do
|
||||
it "complex test" $ do
|
||||
runGEvalTest "nmi-complex" `shouldReturnAlmost` 0.36456
|
||||
describe "LogLossHashed challenge" $ do
|
||||
it "simple example" $ do
|
||||
runGEvalTest "log-loss-hashed-simple" `shouldReturnAlmost` 2.398479083333333
|
||||
describe "reading options" $ do
|
||||
it "can get the metric" $ do
|
||||
extractMetric "bleu-complex" `shouldReturn` (Just BLEU)
|
||||
|
@ -0,0 +1,3 @@
|
||||
-6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -0.916290731874155 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 
-6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -2.30258509299405 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 
-6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848 -6.23048144757848
|
||||
źdźbło:-0.6931471805599453 foo:-1.6094379124341003 nd:-1.2039728043259361
|
||||
złoty:-0.916290731874155 :-0.5108256237659907
|
|
@ -0,0 +1 @@
|
||||
--metric LogLossHashed8
|
@ -0,0 +1,3 @@
|
||||
ma
|
||||
źdźbło
|
||||
dolar
|
|
Loading…
Reference in New Issue
Block a user