From bfcd5aa631ed3609b775152ea4f6d4ad1fef8f05 Mon Sep 17 00:00:00 2001 From: Filip Gralinski Date: Sat, 25 Jan 2020 19:26:57 +0100 Subject: [PATCH] Most evaluation metrics are handled with dependency types --- src/GEval/Core.hs | 114 ++++++++-------------------------------------- 1 file changed, 19 insertions(+), 95 deletions(-) diff --git a/src/GEval/Core.hs b/src/GEval/Core.hs index c68240d..0ca1489 100644 --- a/src/GEval/Core.hs +++ b/src/GEval/Core.hs @@ -522,103 +522,25 @@ gevalCoreOnSources (ProbabilisticSoftFMeasure beta) _ = generalizedProbabilistic -- 4) aggregate the results -- 5) apply some final funtion on the aggregate -- 6) create a graph using the aggregate (applicable only to some metrics) -gevalCoreOnSources Likelihood _ = gevalCoreWithoutInput SALikelihood averageC logLossToLikehood noGraph -gevalCoreOnSources MultiLabelLikelihood _ = gevalCoreWithoutInput SAMultiLabelLikelihood - averageC - logLossToLikehood - noGraph +gevalCoreOnSources metric _ = gevalCoreOnSourcesStandardWay metric -gevalCoreOnSources MSE _ = gevalCoreWithoutInput SAMSE averageC id noGraph - -gevalCoreOnSources RMSE _ = gevalCoreWithoutInput SARMSE averageC (** 0.5) noGraph - -gevalCoreOnSources MAE _ = gevalCoreWithoutInput SAMAE averageC id noGraph - -gevalCoreOnSources SMAPE _ = gevalCoreWithoutInput SASMAPE averageC (* 100.0) noGraph - -gevalCoreOnSources LogLoss _ = gevalCoreWithoutInput SALogLoss averageC id noGraph - -gevalCoreOnSources BLEU _ = gevalCoreWithoutInput SABLEU bleuAgg bleuFinal noGraph - where bleuFinal (p1, p2, p3, p4, rl, l1, l2, l3, l4) = ((p1 /. l1) * (p2 /. l2) * (p3 /. l3) * (p4 /. l4)) ** 0.25 * (brevityPenalty l1 rl) - bleuAgg = CC.foldl bleuFuse (0, 0, 0, 0, 0, 0, 0, 0, 0) - bleuFuse (a1, a2, a3, a4, a5, a6, a7, a8, a9) (b1, b2, b3, b4, b5, b6, b7, b8, b9) = (a1+b1, a2+b2, a3+b3, a4+b4, a5+b5, a6+b6, a7+b7, a8+b8, a9+b9) - brevityPenalty c r - | c >= r = 1.0 - | c == 0 && r > 0 = 0.0 - | otherwise = exp (1.0 - (r /. c)) - -gevalCoreOnSources GLEU _ = gevalCoreWithoutInput SAGLEU gleuAgg gleuFinal noGraph - where gleuFinal (m, t) = m /. t - gleuAgg = CC.foldl gleuFuse (0, 0) - gleuFuse (a1, a2) (b1, b2) = (a1 + b1, a2 + b2) - -gevalCoreOnSources WER _ = gevalCoreWithoutInput SAWER werAgg werFinal noGraph - where werAgg = CC.foldl werFuse (0, 0) - werFuse (a1, a2) (b1, b2) = (a1 + b1, a2 + b2) - werFinal (errors, ref) = errors /. ref - -gevalCoreOnSources Accuracy _ = gevalCoreWithoutInput SAAccuracy averageC id noGraph - -gevalCoreOnSources (FMeasure beta) _ = gevalCoreWithoutInput SAFMeasure countAgg (fMeasureOnCounts beta) noGraph - -gevalCoreOnSources (MacroFMeasure beta) _ = gevalCoreWithoutInput SAMacroFMeasure gatherClassC macroAverageOnCounts noGraph - where gatherClassC = CC.foldl gatherClassCombiner (M.empty, M.empty, M.empty) - gatherClassCombiner (tpMap, expectedMap, gotMap) (tp, expected, got) = - (insertMaybeToMap tp tpMap, - insertMaybeToMap expected expectedMap, - insertMaybeToMap got gotMap) - insertMaybeToMap Nothing m = m - insertMaybeToMap (Just c) m = M.insertWith (+) c 1 m - macroAverageOnCounts (tpMap, expectedMap, gotMap) = - (Prelude.sum - $ Prelude.map (\c -> fMeasureOnCounts beta (M.lookupDefault (0::Int) c tpMap, - M.lookupDefault 0 c expectedMap, - M.lookupDefault 0 c gotMap)) - $ M.keys expectedMap) / (fromIntegral $ Prelude.length $ M.keys expectedMap) - -gevalCoreOnSources (SoftFMeasure beta) _ = gevalCoreWithoutInput SASoftFMeasure - countAgg - (fMeasureOnCounts beta) - noGraph - -gevalCoreOnSources (Soft2DFMeasure beta) _ = gevalCoreWithoutInput SASoft2DFMeasure - (CC.map (fMeasureOnCounts beta) .| averageC) - id - noGraph - -gevalCoreOnSources ClippEU _ = gevalCoreWithoutInput SAClippEU clippeuAgg finalStep noGraph - where - clippeuAgg = CC.foldl countFolder (0, 0, 0) - finalStep counts = f2MeasureOnCounts counts - -gevalCoreOnSources NMI _ = gevalCoreWithoutInput SANMI (CC.foldl updateConfusionMatrix M.empty) normalizedMutualInformationFromConfusionMatrix noGraph - -gevalCoreOnSources MAP _ = gevalCoreWithoutInput SAMAP - averageC - id - noGraph - -gevalCoreOnSources BIOF1 _ = gevalCoreWithoutInput SABIOF1 countAgg f1MeasureOnCounts noGraph - -gevalCoreOnSources BIOF1Labels _ = gevalCoreWithoutInput SABIOF1Labels countAgg f1MeasureOnCounts noGraph - -gevalCoreOnSources TokenAccuracy _ = gevalCoreWithoutInput SATokenAccuracy - hitsAndTotalsAgg - (\(hits, total) -> hits /. total) - noGraph - where - hitsAndTotalsAgg = CC.foldl (\(h1, t1) (h2, t2) -> (h1 + h2, t1 + t2)) (0, 0) - -gevalCoreOnSources SegmentAccuracy _ = gevalCoreWithoutInput SASegmentAccuracy - averageC - id - noGraph - -gevalCoreOnSources MultiLabelLogLoss _ = gevalCoreWithoutInput SAMultiLabelLogLoss - averageC - id - noGraph +gevalCoreOnSourcesStandardWay :: (MonadIO m, MonadThrow m, MonadUnliftIO m) => + Metric -- ^ evaluation metric + -> LineSource (ResourceT m) -- ^ source to read the expected output + -> LineSource (ResourceT m) -- ^ source to read the output + -> m (MetricOutput) -- ^ metric values for the output against the expected output +gevalCoreOnSourcesStandardWay metric expectedLineStream outLineStream = + case toSing $ toHelper metric of + SomeSing smetric -> gevalRunPipeline parserSpec (trans step) finalPipeline context + where parserSpec = (ParserSpecWithoutInput (liftOp expParser) (liftOp outParser)) + context = (WithoutInput expectedLineStream outLineStream) + step = itemStep smetric + expParser = expectedParser smetric + outParser = outputParser smetric + finalPipeline = continueGEvalCalculations smetric metric + trans :: ((a, b) -> c) -> ParsedRecord (WithoutInput m a b) -> c + trans step (ParsedRecordWithoutInput x y) = step (x, y) helperLogLossHashed nbOfBits finalStep expectedLineSource outLineSource = gevalCore''' (ParserSpecWithoutInput (liftOp (Right . id)) (liftOp tentativeParser)) (\(lineNo, (t,d)) -> calculateLogLoss nbOfBits lineNo t (parseDistributionWrapper nbOfBits lineNo d)) averageC (finalStep . negate) noGraph (WithoutInput expectedLineSource outLineSource) @@ -843,6 +765,8 @@ continueGEvalCalculations SABIOF1 BIOF1 = defineContinuation countAgg f1MeasureO continueGEvalCalculations SABIOF1Labels BIOF1Labels = defineContinuation countAgg f1MeasureOnCounts noGraph +continueGEvalCalculations SASegmentAccuracy SegmentAccuracy = defineContinuation averageC id noGraph + continueGEvalCalculations SATokenAccuracy TokenAccuracy = defineContinuation hitsAndTotalsAgg (\(hits, total) -> hits /. total) noGraph