Refactor towards dependent types

This commit is contained in:
Filip Gralinski 2019-09-10 06:39:31 +02:00
parent 601dbcf677
commit f68bd8df74

View File

@ -459,6 +459,8 @@ isEmptyFileSource _ = return False
logLossToLikehood logLoss = exp (-logLoss) logLossToLikehood logLoss = exp (-logLoss)
data LineInFile = LineInFile SourceSpec Word32 Text
-- | Runs evaluation for a given metric using the sources given -- | Runs evaluation for a given metric using the sources given
-- for input, expected output and output. Returns the metric value. -- for input, expected output and output. Returns the metric value.
-- Throws @GEvalException@ if something was wrong in the data (e.g. -- Throws @GEvalException@ if something was wrong in the data (e.g.
@ -467,69 +469,43 @@ logLossToLikehood logLoss = exp (-logLoss)
-- The difference between this and @gevalCore@ is that it operates on Conduit -- The difference between this and @gevalCore@ is that it operates on Conduit
-- sources (rather than source specification). -- sources (rather than source specification).
-- --
-- This could be specialised for particular metrics, if they could be -- Metrics are starting to be really defined here, though when the
-- calculated from other metrics in a trivial fashion (e.g. @RMSE@ -- input is not needed for doing the evaluation (which is not in most
-- which is just a square root of @MSE@). Otherwise a metric should be -- cases), the work is delegated to @gevalCoreWithoutInput@ function.
-- defined in @gevalCore'@ and @gevalCoreWithoutInput@ helper
-- functions.
gevalCoreOnSources :: (MonadIO m, MonadThrow m, MonadUnliftIO m) => gevalCoreOnSources :: (MonadIO m, MonadThrow m, MonadUnliftIO m) =>
Metric -- ^ evaluation metric Metric -- ^ evaluation metric
-> LineSource (ResourceT m) -- ^ source of the input values -> LineSource (ResourceT m) -- ^ source of the input values
-> LineSource (ResourceT m) -- ^ source to read the expected output -> LineSource (ResourceT m) -- ^ source to read the expected output
-> LineSource (ResourceT m) -- ^ source to read the output -> LineSource (ResourceT m) -- ^ source to read the output
-> m (MetricOutput) -- ^ metric values for the output against the expected output -> m (MetricOutput) -- ^ metric values for the output against the expected output
gevalCoreOnSources RMSE inputLineSource expectedLineSource outLineSource = do gevalCoreOnSources Likelihood _ = gevalCoreWithoutInput doubleParser doubleParser itemLogLossError averageC logLossToLikehood noGraph
MetricOutput mse g <- gevalCoreOnSources MSE inputLineSource expectedLineSource outLineSource
return $ MetricOutput (mse ** 0.5) g
gevalCoreOnSources Likelihood inputLineSource expectedLineSource outLineSource = do gevalCoreOnSources (LikelihoodHashed nbOfBits) _ = helperLogLossHashed nbOfBits logLossToLikehood
MetricOutput logLoss g <- gevalCoreOnSources LogLoss inputLineSource expectedLineSource outLineSource
return $ MetricOutput (logLossToLikehood logLoss) g
gevalCoreOnSources (LikelihoodHashed b) inputLineSource expectedLineSource outLineSource = do gevalCoreOnSources MultiLabelLikelihood _ = gevalCoreWithoutInput intoWords
MetricOutput logLoss g <- gevalCoreOnSources (LogLossHashed b) inputLineSource expectedLineSource outLineSource (Right . parseIntoProbList)
return $ MetricOutput (logLossToLikehood logLoss) g (uncurry countLogLossOnProbList)
averageC
logLossToLikehood
noGraph
where
intoWords = Right . Data.Text.words
gevalCoreOnSources MultiLabelLikelihood inputLineSource expectedLineSource outLineSource = do gevalCoreOnSources MSE _ = gevalCoreWithoutInput doubleParser doubleParser itemSquaredError averageC id noGraph
MetricOutput logLoss g <- gevalCoreOnSources MultiLabelLogLoss inputLineSource expectedLineSource outLineSource
return $ MetricOutput (logLossToLikehood logLoss) g
gevalCoreOnSources metric inputLineSource expectedLineSource outLineSource = do gevalCoreOnSources RMSE _ = gevalCoreWithoutInput doubleParser doubleParser itemSquaredError averageC (** 0.5) noGraph
gevalCore' metric inputLineSource expectedLineSource outLineSource
data LineInFile = LineInFile SourceSpec Word32 Text gevalCoreOnSources MAE _ = gevalCoreWithoutInput doubleParser doubleParser itemAbsoluteError averageC id noGraph
-- | Runs evaluation for a given metric using the sources given gevalCoreOnSources SMAPE _ = gevalCoreWithoutInput doubleParser doubleParser smape averageC (* 100.0) noGraph
-- for input, expected output and output. Returns the metric value. where smape (exp, out) = (abs (exp-out)) `safeDoubleDiv` ((abs exp) + (abs out))
-- Throws @GEvalException@ if something was wrong in the data (e.g.
-- inconsistent number of lines in the sources).
--
-- Metrics are starting to be really defined here, though when the
-- input is not needed for doing the evaluation (which is not in most
-- cases), the work is delegated to @gevalCoreWithoutInput@ function.
gevalCore' :: (MonadIO m, MonadThrow m, MonadUnliftIO m) =>
Metric -- ^ evaluation metric
-> LineSource (ResourceT m) -- ^ source of the input values
-> LineSource (ResourceT m) -- ^ source to read the expected output
-> LineSource (ResourceT m) -- ^ source to read the output
-> m (MetricOutput) -- ^ metric values for the output against the expected output
gevalCore' MSE _ = gevalCoreWithoutInput outParser outParser itemSquaredError averageC id noGraph
where outParser = getValue . TR.double
gevalCore' MAE _ = gevalCoreWithoutInput outParser outParser itemAbsoluteError averageC id noGraph gevalCoreOnSources Pearson _ = gevalCoreByCorrelationMeasure pearson
where outParser = getValue . TR.double gevalCoreOnSources Spearman _ = gevalCoreByCorrelationMeasure spearman
gevalCore' SMAPE _ = gevalCoreWithoutInput outParser outParser smape averageC (* 100.0) noGraph gevalCoreOnSources LogLoss _ = gevalCoreWithoutInput doubleParser doubleParser itemLogLossError averageC id noGraph
where outParser = getValue . TR.double
smape (exp, out) = (abs (exp-out)) `safeDoubleDiv` ((abs exp) + (abs out))
gevalCore' Pearson _ = gevalCoreByCorrelationMeasure pearson gevalCoreOnSources BLEU _ = gevalCoreWithoutInput (Right . Prelude.map Prelude.words . DLS.splitOn "\t" . unpack) (Right . Prelude.words . unpack) bleuCombine bleuAgg bleuFinal noGraph
gevalCore' Spearman _ = gevalCoreByCorrelationMeasure spearman
gevalCore' LogLoss _ = gevalCoreWithoutInput outParser outParser itemLogLossError averageC id noGraph
where outParser = getValue . TR.double
gevalCore' BLEU _ = gevalCoreWithoutInput (Right . Prelude.map Prelude.words . DLS.splitOn "\t" . unpack) (Right . Prelude.words . unpack) bleuCombine bleuAgg bleuFinal noGraph
where bleuFinal (p1, p2, p3, p4, rl, l1, l2, l3, l4) = ((p1 /. l1) * (p2 /. l2) * (p3 /. l3) * (p4 /. l4)) ** 0.25 * (brevityPenalty l1 rl) where bleuFinal (p1, p2, p3, p4, rl, l1, l2, l3, l4) = ((p1 /. l1) * (p2 /. l2) * (p3 /. l3) * (p4 /. l4)) ** 0.25 * (brevityPenalty l1 rl)
bleuCombine (refs, sen) = bleuStep refs sen bleuCombine (refs, sen) = bleuStep refs sen
bleuAgg = CC.foldl bleuFuse (0, 0, 0, 0, 0, 0, 0, 0, 0) bleuAgg = CC.foldl bleuFuse (0, 0, 0, 0, 0, 0, 0, 0, 0)
@ -539,15 +515,15 @@ gevalCore' BLEU _ = gevalCoreWithoutInput (Right . Prelude.map Prelude.words . D
| c == 0 && r > 0 = 0.0 | c == 0 && r > 0 = 0.0
| otherwise = exp (1.0 - (r /. c)) | otherwise = exp (1.0 - (r /. c))
gevalCore' GLEU _ = gevalCoreWithoutInput (Right . Prelude.map Prelude.words . DLS.splitOn "\t" . unpack) (Right . Prelude.words . unpack) gleuCombine gleuAgg gleuFinal noGraph gevalCoreOnSources GLEU _ = gevalCoreWithoutInput (Right . Prelude.map Prelude.words . DLS.splitOn "\t" . unpack) (Right . Prelude.words . unpack) gleuCombine gleuAgg gleuFinal noGraph
where gleuFinal (m, t) = m /. t where gleuFinal (m, t) = m /. t
gleuCombine (refs, sen) = gleuStep refs sen gleuCombine (refs, sen) = gleuStep refs sen
gleuAgg = CC.foldl gleuFuse (0, 0) gleuAgg = CC.foldl gleuFuse (0, 0)
gleuFuse (a1, a2) (b1, b2) = (a1+b1, a2+b2) gleuFuse (a1, a2) (b1, b2) = (a1+b1, a2+b2)
gevalCore' WER _ = gevalCoreWithoutInput (Right . Prelude.words . unpack) (Right . Prelude.words . unpack) (uncurry werStep) averageC id noGraph gevalCoreOnSources WER _ = gevalCoreWithoutInput (Right . Prelude.words . unpack) (Right . Prelude.words . unpack) (uncurry werStep) averageC id noGraph
gevalCore' Accuracy _ = gevalCoreWithoutInput (Right . strip) (Right . strip) hitOrMiss averageC id noGraph gevalCoreOnSources Accuracy _ = gevalCoreWithoutInput (Right . strip) (Right . strip) hitOrMiss averageC id noGraph
where hitOrMiss (exp, got) = where hitOrMiss (exp, got) =
-- first try to parse what we got as a probability distribution -- first try to parse what we got as a probability distribution
-- (like the one used for Likelikehood/LogLossHashed metric) -- (like the one used for Likelikehood/LogLossHashed metric)
@ -572,7 +548,7 @@ gevalCore' Accuracy _ = gevalCoreWithoutInput (Right . strip) (Right . strip) hi
tryReadingAsFloat :: Text -> Maybe Float tryReadingAsFloat :: Text -> Maybe Float
tryReadingAsFloat = readMaybe . unpack tryReadingAsFloat = readMaybe . unpack
gevalCore' (FMeasure beta) _ = gevalCoreWithoutInput outParser outParser getCount countAgg (fMeasureOnCounts beta) noGraph gevalCoreOnSources (FMeasure beta) _ = gevalCoreWithoutInput outParser outParser getCount countAgg (fMeasureOnCounts beta) noGraph
where outParser = detected <=< (getValue . TR.double) where outParser = detected <=< (getValue . TR.double)
expParser = expected <=< (getValue . TR.decimal) expParser = expected <=< (getValue . TR.decimal)
expected 1 = Right True expected 1 = Right True
@ -590,7 +566,7 @@ gevalCore' (FMeasure beta) _ = gevalCoreWithoutInput outParser outParser getCoun
getCount (False, True) = (0, 0, 1) getCount (False, True) = (0, 0, 1)
getCount (False, False) = (0, 0, 0) getCount (False, False) = (0, 0, 0)
gevalCore' (MacroFMeasure beta) _ = gevalCoreWithoutInput (Right . Just . strip) (Right . predicted . strip) getClassesInvolved gatherClassC macroAverageOnCounts noGraph gevalCoreOnSources (MacroFMeasure beta) _ = gevalCoreWithoutInput (Right . Just . strip) (Right . predicted . strip) getClassesInvolved gatherClassC macroAverageOnCounts noGraph
where predicted got = where predicted got =
-- first try to parse what we got as a probability distribution -- first try to parse what we got as a probability distribution
-- (like the one used for Likelikehood/LogLossHashed metric) -- (like the one used for Likelikehood/LogLossHashed metric)
@ -619,7 +595,7 @@ gevalCore' (MacroFMeasure beta) _ = gevalCoreWithoutInput (Right . Just . strip)
M.lookupDefault 0 c gotMap)) M.lookupDefault 0 c gotMap))
$ M.keys expectedMap) / (fromIntegral $ Prelude.length $ M.keys expectedMap) $ M.keys expectedMap) / (fromIntegral $ Prelude.length $ M.keys expectedMap)
gevalCore' (SoftFMeasure beta) _ = gevalCoreWithoutInput parseAnnotations gevalCoreOnSources (SoftFMeasure beta) _ = gevalCoreWithoutInput parseAnnotations
parseObtainedAnnotations parseObtainedAnnotations
getSoftCounts getSoftCounts
countAgg countAgg
@ -629,16 +605,16 @@ gevalCore' (SoftFMeasure beta) _ = gevalCoreWithoutInput parseAnnotations
Prelude.length expected, Prelude.length expected,
Prelude.length got) Prelude.length got)
gevalCore' (ProbabilisticMultiLabelFMeasure beta) _ = generalizedProbabilisticFMeasure beta gevalCoreOnSources (ProbabilisticMultiLabelFMeasure beta) _ = generalizedProbabilisticFMeasure beta
intoWords intoWords
(Right . (\(ProbList es) -> es) . parseIntoProbList) (Right . (\(ProbList es) -> es) . parseIntoProbList)
where intoWords = Right . Data.Text.words where intoWords = Right . Data.Text.words
gevalCore' (ProbabilisticSoftFMeasure beta) _ = generalizedProbabilisticFMeasure beta gevalCoreOnSources (ProbabilisticSoftFMeasure beta) _ = generalizedProbabilisticFMeasure beta
parseAnnotations parseAnnotations
parseObtainedAnnotations parseObtainedAnnotations
gevalCore' (Soft2DFMeasure beta) _ = gevalCoreWithoutInput parseLabeledClippings gevalCoreOnSources (Soft2DFMeasure beta) _ = gevalCoreWithoutInput parseLabeledClippings
parseLabeledClippings parseLabeledClippings
count2DFScore count2DFScore
averageC averageC
@ -651,7 +627,7 @@ gevalCore' (Soft2DFMeasure beta) _ = gevalCoreWithoutInput parseLabeledClippings
expArea = totalArea expected expArea = totalArea expected
gotArea = totalArea got gotArea = totalArea got
gevalCore' ClippEU _ = gevalCoreWithoutInput parseClippingSpecs parseClippings matchStep clippeuAgg finalStep noGraph gevalCoreOnSources ClippEU _ = gevalCoreWithoutInput parseClippingSpecs parseClippings matchStep clippeuAgg finalStep noGraph
where where
parseClippings = controlledParse lineClippingsParser parseClippings = controlledParse lineClippingsParser
parseClippingSpecs = controlledParse lineClippingSpecsParser parseClippingSpecs = controlledParse lineClippingSpecsParser
@ -661,28 +637,18 @@ gevalCore' ClippEU _ = gevalCoreWithoutInput parseClippingSpecs parseClippings m
clippeuAgg = CC.foldl countFolder (0, 0, 0) clippeuAgg = CC.foldl countFolder (0, 0, 0)
finalStep counts = f2MeasureOnCounts counts finalStep counts = f2MeasureOnCounts counts
gevalCore' NMI _ = gevalCoreWithoutInput (Right . id) (Right . id) id (CC.foldl updateConfusionMatrix M.empty) normalizedMutualInformationFromConfusionMatrix noGraph gevalCoreOnSources NMI _ = gevalCoreWithoutInput (Right . id) (Right . id) id (CC.foldl updateConfusionMatrix M.empty) normalizedMutualInformationFromConfusionMatrix noGraph
gevalCore' MAP _ = gevalCoreWithoutInput (Right . DLS.splitOn "\t" . unpack) gevalCoreOnSources MAP _ = gevalCoreWithoutInput (Right . DLS.splitOn "\t" . unpack)
(Right . DLS.splitOn "\t" . unpack) (Right . DLS.splitOn "\t" . unpack)
(\(e,g) -> calculateMAPForOneResult e g) (\(e,g) -> calculateMAPForOneResult e g)
averageC averageC
id id
noGraph noGraph
gevalCore' (LogLossHashed nbOfBits) _ = helper nbOfBits gevalCoreOnSources (LogLossHashed nbOfBits) _ = helperLogLossHashed nbOfBits id
-- for LogLossHashed we "salt" each hash with the line number
where helper nbOfBits expectedLineSource outLineSource =
gevalCore''' (ParserSpecWithoutInput (liftOp (Right . id)) (liftOp tentativeParser)) (\(lineNo, (t,d)) -> calculateLogLoss nbOfBits lineNo t (parseDistributionWrapper nbOfBits lineNo d)) averageC negate noGraph (WithoutInput expectedLineSource outLineSource)
-- Unfortunately, we're parsing the distribution twice. We need to
-- tentatively parse the distribution when the line number is unknown
-- (so we just set it to 1)
-- @TODO Fix this.
tentativeParser t = case parseDistribution nbOfBits 1 t of
Right _ -> Right t
Left m -> Left m
gevalCore' CharMatch inputLineSource = helper inputLineSource gevalCoreOnSources CharMatch inputLineSource = helper inputLineSource
where where
helper inputLineSource expectedLineSource outputLineSource = do helper inputLineSource expectedLineSource outputLineSource = do
gevalCoreGeneralized (ParserSpecWithInput justUnpack justUnpack justUnpack) step countAgg (fMeasureOnCounts charMatchBeta) noGraph (WithInput inputLineSource expectedLineSource outputLineSource) gevalCoreGeneralized (ParserSpecWithInput justUnpack justUnpack justUnpack) step countAgg (fMeasureOnCounts charMatchBeta) noGraph (WithInput inputLineSource expectedLineSource outputLineSource)
@ -690,14 +656,14 @@ gevalCore' CharMatch inputLineSource = helper inputLineSource
justUnpack = liftOp (Right . unpack) justUnpack = liftOp (Right . unpack)
gevalCore' BIOF1 _ = gevalCoreWithoutInput parseBioSequenceIntoEntities parseBioSequenceIntoEntities (uncurry gatherCountsForBIO) countAgg f1MeasureOnCounts noGraph gevalCoreOnSources BIOF1 _ = gevalCoreWithoutInput parseBioSequenceIntoEntities parseBioSequenceIntoEntities (uncurry gatherCountsForBIO) countAgg f1MeasureOnCounts noGraph
gevalCore' BIOF1Labels _ = gevalCoreWithoutInput parseBioSequenceIntoEntitiesWithoutNormalization parseBioSequenceIntoEntitiesWithoutNormalization (uncurry gatherCountsForBIO) countAgg f1MeasureOnCounts noGraph gevalCoreOnSources BIOF1Labels _ = gevalCoreWithoutInput parseBioSequenceIntoEntitiesWithoutNormalization parseBioSequenceIntoEntitiesWithoutNormalization (uncurry gatherCountsForBIO) countAgg f1MeasureOnCounts noGraph
where parseBioSequenceIntoEntitiesWithoutNormalization s = do where parseBioSequenceIntoEntitiesWithoutNormalization s = do
entities <- parseBioSequenceIntoEntities s entities <- parseBioSequenceIntoEntities s
return $ Prelude.map eraseNormalisation entities return $ Prelude.map eraseNormalisation entities
gevalCore' TokenAccuracy _ = gevalCoreWithoutInput intoTokens gevalCoreOnSources TokenAccuracy _ = gevalCoreWithoutInput intoTokens
intoTokens intoTokens
countHitsAndTotals countHitsAndTotals
hitsAndTotalsAgg hitsAndTotalsAgg
@ -719,7 +685,7 @@ gevalCore' TokenAccuracy _ = gevalCoreWithoutInput intoTokens
hitsAndTotalsAgg = CC.foldl (\(h1, t1) (h2, t2) -> (h1 + h2, t1 + t2)) (0, 0) hitsAndTotalsAgg = CC.foldl (\(h1, t1) (h2, t2) -> (h1 + h2, t1 + t2)) (0, 0)
-- only MultiLabel-F1 handled for JSONs for the time being... -- only MultiLabel-F1 handled for JSONs for the time being...
gevalCore' (MultiLabelFMeasure beta) _ = gevalCoreWithoutInputOnItemTargets (Right . intoWords) gevalCoreOnSources (MultiLabelFMeasure beta) _ = gevalCoreWithoutInputOnItemTargets (Right . intoWords)
(Right . getWords) (Right . getWords)
(getCounts (==)) (getCounts (==))
countAgg countAgg
@ -731,7 +697,7 @@ gevalCore' (MultiLabelFMeasure beta) _ = gevalCoreWithoutInputOnItemTargets (Rig
intoWords (RawItemTarget t) = Prelude.map unpack $ Data.Text.words t intoWords (RawItemTarget t) = Prelude.map unpack $ Data.Text.words t
intoWords (PartiallyParsedItemTarget ts) = Prelude.map unpack ts intoWords (PartiallyParsedItemTarget ts) = Prelude.map unpack ts
gevalCore' MultiLabelLogLoss _ = gevalCoreWithoutInput intoWords gevalCoreOnSources MultiLabelLogLoss _ = gevalCoreWithoutInput intoWords
(Right . parseIntoProbList) (Right . parseIntoProbList)
(uncurry countLogLossOnProbList) (uncurry countLogLossOnProbList)
averageC averageC
@ -740,6 +706,18 @@ gevalCore' MultiLabelLogLoss _ = gevalCoreWithoutInput intoWords
where where
intoWords = Right . Data.Text.words intoWords = Right . Data.Text.words
doubleParser = getValue . TR.double
helperLogLossHashed nbOfBits finalStep expectedLineSource outLineSource =
gevalCore''' (ParserSpecWithoutInput (liftOp (Right . id)) (liftOp tentativeParser)) (\(lineNo, (t,d)) -> calculateLogLoss nbOfBits lineNo t (parseDistributionWrapper nbOfBits lineNo d)) averageC (finalStep . negate) noGraph (WithoutInput expectedLineSource outLineSource)
where -- Unfortunately, we're parsing the distribution twice. We need to
-- tentatively parse the distribution when the line number is unknown
-- (so we just set it to 1)
-- @TODO Fix this.
tentativeParser t = case parseDistribution nbOfBits 1 t of
Right _ -> Right t
Left m -> Left m
generalizedProbabilisticFMeasure beta parseBareEntities parseEntities = gevalCoreWithoutInput parseBareEntities generalizedProbabilisticFMeasure beta parseBareEntities parseEntities = gevalCoreWithoutInput parseBareEntities
parseEntities parseEntities
getProbabilisticCounts getProbabilisticCounts
@ -770,9 +748,8 @@ gevalCoreByCorrelationMeasure :: (MonadUnliftIO m, MonadThrow m, MonadIO m) =>
LineSource (ResourceT m) -> -- ^ source to read the output LineSource (ResourceT m) -> -- ^ source to read the output
m (MetricOutput) -- ^ metric values for the output against the expected output m (MetricOutput) -- ^ metric values for the output against the expected output
gevalCoreByCorrelationMeasure correlationFunction = gevalCoreByCorrelationMeasure correlationFunction =
gevalCoreWithoutInput outParser outParser id correlationC finalStep noGraph gevalCoreWithoutInput doubleParser doubleParser id correlationC finalStep noGraph
where outParser = getValue . TR.double where correlationC = CC.foldl (flip (:)) []
correlationC = CC.foldl (flip (:)) []
finalStep pairs = correlationFunction $ V.fromList pairs finalStep pairs = correlationFunction $ V.fromList pairs
parseDistributionWrapper :: Word32 -> Word32 -> Text -> HashedDistribution parseDistributionWrapper :: Word32 -> Word32 -> Text -> HashedDistribution