OutputParser handled by dependent types

This commit is contained in:
Filip Graliński 2019-11-02 09:31:16 +01:00
parent d5d177bc8d
commit 32969bb56a
2 changed files with 30 additions and 46 deletions

View File

@ -527,10 +527,9 @@ gevalCoreOnSources (ProbabilisticSoftFMeasure beta) _ = generalizedProbabilistic
-- 4) aggregate the results
-- 5) apply some final funtion on the aggregate
-- 6) create a graph using the aggregate (applicable only to some metrics)
gevalCoreOnSources Likelihood _ = gevalCoreWithoutInput SALikelihood doubleParser itemLogLossError averageC logLossToLikehood noGraph
gevalCoreOnSources Likelihood _ = gevalCoreWithoutInput SALikelihood itemLogLossError averageC logLossToLikehood noGraph
gevalCoreOnSources MultiLabelLikelihood _ = gevalCoreWithoutInput SAMultiLabelLikelihood
(Right . parseIntoProbList)
(uncurry countLogLossOnProbList)
averageC
logLossToLikehood
@ -538,18 +537,18 @@ gevalCoreOnSources MultiLabelLikelihood _ = gevalCoreWithoutInput SAMultiLabelLi
where
intoWords = Right . Data.Text.words
gevalCoreOnSources MSE _ = gevalCoreWithoutInput SAMSE doubleParser itemSquaredError averageC id noGraph
gevalCoreOnSources MSE _ = gevalCoreWithoutInput SAMSE itemSquaredError averageC id noGraph
gevalCoreOnSources RMSE _ = gevalCoreWithoutInput SARMSE doubleParser itemSquaredError averageC (** 0.5) noGraph
gevalCoreOnSources RMSE _ = gevalCoreWithoutInput SARMSE itemSquaredError averageC (** 0.5) noGraph
gevalCoreOnSources MAE _ = gevalCoreWithoutInput SAMAE doubleParser itemAbsoluteError averageC id noGraph
gevalCoreOnSources MAE _ = gevalCoreWithoutInput SAMAE itemAbsoluteError averageC id noGraph
gevalCoreOnSources SMAPE _ = gevalCoreWithoutInput SASMAPE doubleParser smape averageC (* 100.0) noGraph
gevalCoreOnSources SMAPE _ = gevalCoreWithoutInput SASMAPE smape averageC (* 100.0) noGraph
where smape (exp, out) = (abs (exp-out)) `safeDoubleDiv` ((abs exp) + (abs out))
gevalCoreOnSources LogLoss _ = gevalCoreWithoutInput SALogLoss doubleParser itemLogLossError averageC id noGraph
gevalCoreOnSources LogLoss _ = gevalCoreWithoutInput SALogLoss itemLogLossError averageC id noGraph
gevalCoreOnSources BLEU _ = gevalCoreWithoutInput SABLEU (Right . Prelude.words . unpack) bleuCombine bleuAgg bleuFinal noGraph
gevalCoreOnSources BLEU _ = gevalCoreWithoutInput SABLEU bleuCombine bleuAgg bleuFinal noGraph
where bleuFinal (p1, p2, p3, p4, rl, l1, l2, l3, l4) = ((p1 /. l1) * (p2 /. l2) * (p3 /. l3) * (p4 /. l4)) ** 0.25 * (brevityPenalty l1 rl)
bleuCombine (refs, sen) = bleuStep refs sen
bleuAgg = CC.foldl bleuFuse (0, 0, 0, 0, 0, 0, 0, 0, 0)
@ -559,15 +558,15 @@ gevalCoreOnSources BLEU _ = gevalCoreWithoutInput SABLEU (Right . Prelude.words
| c == 0 && r > 0 = 0.0
| otherwise = exp (1.0 - (r /. c))
gevalCoreOnSources GLEU _ = gevalCoreWithoutInput SAGLEU (Right . Prelude.words . unpack) gleuCombine gleuAgg gleuFinal noGraph
gevalCoreOnSources GLEU _ = gevalCoreWithoutInput SAGLEU gleuCombine gleuAgg gleuFinal noGraph
where gleuFinal (m, t) = m /. t
gleuCombine (refs, sen) = gleuStep refs sen
gleuAgg = CC.foldl gleuFuse (0, 0)
gleuFuse (a1, a2) (b1, b2) = (a1+b1, a2+b2)
gevalCoreOnSources WER _ = gevalCoreWithoutInput SAWER (Right . Prelude.words . unpack) (uncurry werStep) averageC id noGraph
gevalCoreOnSources WER _ = gevalCoreWithoutInput SAWER (uncurry werStep) averageC id noGraph
gevalCoreOnSources Accuracy _ = gevalCoreWithoutInput SAAccuracy (Right . strip) hitOrMiss averageC id noGraph
gevalCoreOnSources Accuracy _ = gevalCoreWithoutInput SAAccuracy hitOrMiss averageC id noGraph
where hitOrMiss (exp, got) =
-- first try to parse what we got as a probability distribution
-- (like the one used for Likelikehood/LogLossHashed metric)
@ -592,21 +591,15 @@ gevalCoreOnSources Accuracy _ = gevalCoreWithoutInput SAAccuracy (Right . strip)
tryReadingAsFloat :: Text -> Maybe Float
tryReadingAsFloat = readMaybe . unpack
gevalCoreOnSources (FMeasure beta) _ = gevalCoreWithoutInput SAFMeasure outParser getCount countAgg (fMeasureOnCounts beta) noGraph
where outParser = detected <=< (getValue . TR.double)
-- output value could be a probability (for compatibility with other measures)
detected prob
| prob >= 0.0 && prob < detectionThreshold = Right False
| prob >= detectionThreshold && prob <= 1.0 = Right True
| otherwise = Left "expected probability"
detectionThreshold = 0.5
gevalCoreOnSources (FMeasure beta) _ = gevalCoreWithoutInput SAFMeasure getCount countAgg (fMeasureOnCounts beta) noGraph
where -- output value could be a probability (for compatibility with other measures)
getCount :: (Bool, Bool) -> (Int, Int, Int)
getCount (True, True) = (1, 1, 1)
getCount (True, False) = (0, 1, 0)
getCount (False, True) = (0, 0, 1)
getCount (False, False) = (0, 0, 0)
gevalCoreOnSources (MacroFMeasure beta) _ = gevalCoreWithoutInput SAMacroFMeasure (Right . predicted . strip) getClassesInvolved gatherClassC macroAverageOnCounts noGraph
gevalCoreOnSources (MacroFMeasure beta) _ = gevalCoreWithoutInput SAMacroFMeasure getClassesInvolved gatherClassC macroAverageOnCounts noGraph
where predicted got =
-- first try to parse what we got as a probability distribution
-- (like the one used for Likelikehood/LogLossHashed metric)
@ -636,7 +629,6 @@ gevalCoreOnSources (MacroFMeasure beta) _ = gevalCoreWithoutInput SAMacroFMeasur
$ M.keys expectedMap) / (fromIntegral $ Prelude.length $ M.keys expectedMap)
gevalCoreOnSources (SoftFMeasure beta) _ = gevalCoreWithoutInput SASoftFMeasure
parseObtainedAnnotations
getSoftCounts
countAgg
(fMeasureOnCounts beta)
@ -646,48 +638,42 @@ gevalCoreOnSources (SoftFMeasure beta) _ = gevalCoreWithoutInput SASoftFMeasure
Prelude.length got)
gevalCoreOnSources (Soft2DFMeasure beta) _ = gevalCoreWithoutInput SASoft2DFMeasure
parseLabeledClippings
count2DFScore
averageC
id
noGraph
where
parseLabeledClippings = controlledParse lineLabeledClippingsParser
count2DFScore (expected, got) = fMeasureOnCounts beta (tpArea, expArea, gotArea)
where tpArea = coveredBy expected got
expArea = totalArea expected
gotArea = totalArea got
gevalCoreOnSources ClippEU _ = gevalCoreWithoutInput SAClippEU parseClippings matchStep clippeuAgg finalStep noGraph
gevalCoreOnSources ClippEU _ = gevalCoreWithoutInput SAClippEU matchStep clippeuAgg finalStep noGraph
where
parseClippings = controlledParse lineClippingsParser
matchStep (clippingSpecs, clippings) = (maxMatch matchClippingToSpec clippingSpecs clippings,
Prelude.length clippingSpecs,
Prelude.length clippings)
clippeuAgg = CC.foldl countFolder (0, 0, 0)
finalStep counts = f2MeasureOnCounts counts
gevalCoreOnSources NMI _ = gevalCoreWithoutInput SANMI (Right . id) id (CC.foldl updateConfusionMatrix M.empty) normalizedMutualInformationFromConfusionMatrix noGraph
gevalCoreOnSources NMI _ = gevalCoreWithoutInput SANMI id (CC.foldl updateConfusionMatrix M.empty) normalizedMutualInformationFromConfusionMatrix noGraph
gevalCoreOnSources MAP _ = gevalCoreWithoutInput SAMAP
(Right . DLS.splitOn "\t" . unpack)
(\(e,g) -> calculateMAPForOneResult e g)
averageC
id
noGraph
gevalCoreOnSources BIOF1 _ = gevalCoreWithoutInput SABIOF1 parseBioSequenceIntoEntities (uncurry gatherCountsForBIO) countAgg f1MeasureOnCounts noGraph
gevalCoreOnSources BIOF1 _ = gevalCoreWithoutInput SABIOF1 (uncurry gatherCountsForBIO) countAgg f1MeasureOnCounts noGraph
gevalCoreOnSources BIOF1Labels _ = gevalCoreWithoutInput SABIOF1Labels parseBioSequenceIntoEntitiesWithoutNormalization (uncurry gatherCountsForBIO) countAgg f1MeasureOnCounts noGraph
gevalCoreOnSources BIOF1Labels _ = gevalCoreWithoutInput SABIOF1Labels (uncurry gatherCountsForBIO) countAgg f1MeasureOnCounts noGraph
gevalCoreOnSources TokenAccuracy _ = gevalCoreWithoutInput SATokenAccuracy
intoTokens
countHitsAndTotals
hitsAndTotalsAgg
(\(hits, total) -> hits /. total)
noGraph
where intoTokens = Right . Data.Text.words
countHitsAndTotals :: ([Text], [Text]) -> (Int, Int)
where countHitsAndTotals :: ([Text], [Text]) -> (Int, Int)
countHitsAndTotals (es, os) =
if Prelude.length os /= Prelude.length es
then throw $ OtherException "wrong number of tokens"
@ -702,13 +688,10 @@ gevalCoreOnSources TokenAccuracy _ = gevalCoreWithoutInput SATokenAccuracy
hitsAndTotalsAgg = CC.foldl (\(h1, t1) (h2, t2) -> (h1 + h2, t1 + t2)) (0, 0)
gevalCoreOnSources MultiLabelLogLoss _ = gevalCoreWithoutInput SAMultiLabelLogLoss
(Right . parseIntoProbList)
(uncurry countLogLossOnProbList)
averageC
id
noGraph
where
intoWords = Right . Data.Text.words
helperLogLossHashed nbOfBits finalStep expectedLineSource outLineSource =
gevalCore''' (ParserSpecWithoutInput (liftOp (Right . id)) (liftOp tentativeParser)) (\(lineNo, (t,d)) -> calculateLogLoss nbOfBits lineNo t (parseDistributionWrapper nbOfBits lineNo d)) averageC (finalStep . negate) noGraph (WithoutInput expectedLineSource outLineSource)
@ -721,7 +704,6 @@ helperLogLossHashed nbOfBits finalStep expectedLineSource outLineSource =
Left m -> Left m
generalizedProbabilisticFMeasure beta metric parseEntities = gevalCoreWithoutInput metric
parseEntities
getProbabilisticCounts
probabilisticSoftAgg
(fMeasureOnProbabilisticCounts beta)
@ -750,7 +732,7 @@ gevalCoreByCorrelationMeasure :: (MonadUnliftIO m, MonadThrow m, MonadIO m) =>
LineSource (ResourceT m) -> -- ^ source to read the output
m (MetricOutput) -- ^ metric values for the output against the expected output
gevalCoreByCorrelationMeasure correlationFunction =
gevalCoreWithoutInput SAPearson doubleParser id correlationC finalStep noGraph
gevalCoreWithoutInput SAPearson id correlationC finalStep noGraph
where correlationC = CC.foldl (flip (:)) []
finalStep pairs = correlationFunction $ V.fromList pairs
@ -767,8 +749,7 @@ skipLineNumber fun = fun . snd
-- | A helper function to run evaluation when the input is not needed to calculate the metric value.
gevalCoreWithoutInput :: (MonadUnliftIO m, MonadThrow m, MonadIO m)
=> SAMetric t
-> (Text -> Either String b) -- ^ parser for values in the actual output
-> ((ParsedExpectedType t, b) -> c) -- ^ function which combines parsed values into a single value
-> ((ParsedExpectedType t, ParsedOutputType t) -> c) -- ^ function which combines parsed values into a single value
-- (will be launched for each item, e.g. an error/cost function
-- could be calculated here)
-> (ConduitT c Void (ResourceT m) d) -- ^ a Conduit which aggregates all the combined values into
@ -778,9 +759,10 @@ gevalCoreWithoutInput :: (MonadUnliftIO m, MonadThrow m, MonadIO m)
-> LineSource (ResourceT m) -- ^ source to read the expected output
-> LineSource (ResourceT m) -- ^ source to read the output
-> m (MetricOutput) -- ^ metric values for the output against the expected output
gevalCoreWithoutInput smetric outParser itemStep aggregator finalStep generateGraph expectedLineStream outLineStream =
gevalCoreWithoutInput smetric itemStep aggregator finalStep generateGraph expectedLineStream outLineStream =
gevalCoreWithoutInputOnItemTargets (liftOp expParser) (liftOp outParser) itemStep aggregator finalStep generateGraph expectedLineStream outLineStream
where expParser = expectedParser smetric
outParser = outputParser smetric
gevalCoreWithoutInputOnItemTargets :: (MonadUnliftIO m, MonadThrow m, MonadIO m)
=> (ItemTarget -> Either String a) -- ^ parser for values in the expected output

View File

@ -141,13 +141,15 @@ expectedParser SAMultiLabelLogLoss = intoWords
expectedParser SAMultiLabelLikelihood = intoWords
type family ParsedOutputType (t :: AMetric) :: * where
ParsedOutputType ABLEU = [Text]
ParsedOutputType AGLEU = [Text]
ParsedOutputType ABLEU = [String]
ParsedOutputType AGLEU = [String]
ParsedOutputType AClippEU = [Clipping]
ParsedOutputType AMacroFMeasure = Maybe Text
ParsedOutputType ASoftFMeasure = [ObtainedAnnotation]
ParsedOutputType AProbabilisticSoftFMeasure = [ObtainedAnnotation]
ParsedOutputType AProbabilisticMultiLabelFMeasure = [WordWithProb]
ParsedOutputType AMultiLabelLikelihood = ProbList
ParsedOutputType AMultiLabelLogLoss = ProbList
ParsedOutputType t = ParsedExpectedType t
outputParser :: SAMetric t -> Text -> Either String (ParsedOutputType t)
@ -155,8 +157,8 @@ outputParser SARMSE = expectedParser SARMSE
outputParser SAMSE = expectedParser SARMSE
outputParser SAPearson = expectedParser SAPearson
outputParser SASpearman = expectedParser SASpearman
outputParser SABLEU = intoWords
outputParser SAGLEU = intoWords
outputParser SABLEU = Right . Prelude.words . unpack
outputParser SAGLEU = Right . Prelude.words . unpack
outputParser SAWER = expectedParser SAWER
outputParser SAAccuracy = expectedParser SAAccuracy
outputParser SAClippEU = controlledParse lineClippingsParser
@ -179,8 +181,8 @@ outputParser SATokenAccuracy = intoWords
outputParser SAMAE = doubleParser
outputParser SASMAPE = doubleParser
outputParser SAMultiLabelFMeasure = intoWords
outputParser SAMultiLabelLogLoss = intoWords
outputParser SAMultiLabelLikelihood = intoWords
outputParser SAMultiLabelLogLoss = Right . parseIntoProbList
outputParser SAMultiLabelLikelihood = Right . parseIntoProbList
doubleParser = getValue . TR.double