diff --git a/src/GEval/Core.hs b/src/GEval/Core.hs
index b7cdf0f..6f1cc36 100644
--- a/src/GEval/Core.hs
+++ b/src/GEval/Core.hs
@@ -471,7 +471,6 @@ gevalCoreOnSources Spearman _ = gevalCoreByCorrelationMeasure spearman
 
 gevalCoreOnSources (ProbabilisticMultiLabelFMeasure beta) _ = generalizedProbabilisticFMeasure beta
                                                                                         SAProbabilisticMultiLabelFMeasure
-  where intoWords = Right . Data.Text.words
 
 gevalCoreOnSources (ProbabilisticSoftFMeasure beta) _ = generalizedProbabilisticFMeasure beta
                                                                                     SAProbabilisticSoftFMeasure
@@ -489,8 +488,6 @@ gevalCoreOnSources MultiLabelLikelihood _ = gevalCoreWithoutInput SAMultiLabelLi
                                                                       averageC
                                                                       logLossToLikehood
                                                                       noGraph
-  where
-    intoWords = Right . Data.Text.words
 
 gevalCoreOnSources MSE _ = gevalCoreWithoutInput SAMSE averageC id noGraph
 
@@ -695,6 +692,122 @@ gevalCoreGeneralized' parserSpec itemStep aggregator finalStep generateGraph con
        <*> (ZipSource $ recordSource context parserSpec)) .| CL.map (checkStep (Proxy :: Proxy m) itemStep)) .| CL.catMaybes .| aggregator)
    return $ MetricOutput (SimpleRun $ finalStep v) (generateGraph v)
 
+
+-- | Sink finishing the calculation of a metric: it consumes values of type
+-- @ItemIntermediateRepresentationType t@, aggregates them and yields the
+-- final 'MetricOutput' (see 'defineContinuation').
+--
+-- The metric is given twice: as an 'SAMetric' value, whose type index fixes
+-- the input type, and as the plain 'Metric', which carries run-time
+-- parameters such as @beta@.  Every clause matches them in that order and
+-- passes exactly three arguments to 'defineContinuation'.
+continueGEvalCalculations :: (MonadIO m) =>
+                            SAMetric t
+                            -> Metric
+                            -> ConduitT (ItemIntermediateRepresentationType t) Void (ResourceT m) MetricOutput
+
+continueGEvalCalculations SALikelihood Likelihood = defineContinuation averageC logLossToLikehood noGraph
+
+continueGEvalCalculations SAMultiLabelLikelihood MultiLabelLikelihood = defineContinuation averageC
+                                                                                            logLossToLikehood
+                                                                                            noGraph
+
+continueGEvalCalculations SAMSE MSE = defineContinuation averageC id noGraph
+
+continueGEvalCalculations SARMSE RMSE = defineContinuation averageC (** 0.5) noGraph
+
+continueGEvalCalculations SAMAE MAE = defineContinuation averageC id noGraph
+
+continueGEvalCalculations SASMAPE SMAPE = defineContinuation averageC (* 100.0) noGraph
+
+continueGEvalCalculations SALogLoss LogLoss = defineContinuation averageC id noGraph
+
+continueGEvalCalculations SABLEU BLEU = defineContinuation bleuAgg bleuFinal noGraph
+  where bleuFinal (p1, p2, p3, p4, rl, l1, l2, l3, l4) = ((p1 /. l1) * (p2 /. l2) * (p3 /. l3) * (p4 /. l4)) ** 0.25 * (brevityPenalty l1 rl)
+        bleuAgg = CC.foldl bleuFuse (0, 0, 0, 0, 0, 0, 0, 0, 0)
+        bleuFuse (a1, a2, a3, a4, a5, a6, a7, a8, a9) (b1, b2, b3, b4, b5, b6, b7, b8, b9) = (a1+b1, a2+b2, a3+b3, a4+b4, a5+b5, a6+b6, a7+b7, a8+b8, a9+b9)
+        brevityPenalty c r
+          | c >= r = 1.0
+          | c == 0 && r > 0 = 0.0
+          | otherwise = exp (1.0 - (r /. c))
+
+continueGEvalCalculations SAGLEU GLEU = defineContinuation gleuAgg gleuFinal noGraph
+  where gleuFinal (m, t) = m /. t
+        gleuAgg = CC.foldl gleuFuse (0, 0)
+        gleuFuse (a1, a2) (b1, b2) = (a1+b1, a2+b2)
+
+continueGEvalCalculations SAWER WER = defineContinuation averageC id noGraph
+
+continueGEvalCalculations SAAccuracy Accuracy = defineContinuation averageC id noGraph
+
+continueGEvalCalculations SAFMeasure (FMeasure beta) = defineContinuation countAgg (fMeasureOnCounts beta) noGraph
+
+continueGEvalCalculations SAMacroFMeasure (MacroFMeasure beta) = defineContinuation gatherClassC macroAverageOnCounts noGraph
+  where gatherClassC = CC.foldl gatherClassCombiner (M.empty, M.empty, M.empty)
+        gatherClassCombiner (tpMap, expectedMap, gotMap) (tp, expected, got) =
+          (insertMaybeToMap tp tpMap,
+           insertMaybeToMap expected expectedMap,
+           insertMaybeToMap got gotMap)
+        insertMaybeToMap Nothing m = m
+        insertMaybeToMap (Just c) m = M.insertWith (+) c 1 m
+        macroAverageOnCounts (tpMap, expectedMap, gotMap) =
+          (Prelude.sum
+           $ Prelude.map (\c -> fMeasureOnCounts beta (M.lookupDefault (0::Int) c tpMap,
+                                                       M.lookupDefault 0 c expectedMap,
+                                                       M.lookupDefault 0 c gotMap))
+           $ M.keys expectedMap) / (fromIntegral $ Prelude.length $ M.keys expectedMap)
+
+continueGEvalCalculations SASoftFMeasure (SoftFMeasure beta) = defineContinuation
+                                                                 countAgg
+                                                                 (fMeasureOnCounts beta)
+                                                                 noGraph
+
+continueGEvalCalculations SASoft2DFMeasure (Soft2DFMeasure beta) = defineContinuation
+                                                                     (CC.map (fMeasureOnCounts beta) .| averageC)
+                                                                     id
+                                                                     noGraph
+
+continueGEvalCalculations SAClippEU ClippEU = defineContinuation clippeuAgg finalStep noGraph
+  where
+    clippeuAgg = CC.foldl countFolder (0, 0, 0)
+    finalStep counts = f2MeasureOnCounts counts
+
+continueGEvalCalculations SANMI NMI = defineContinuation (CC.foldl updateConfusionMatrix M.empty) normalizedMutualInformationFromConfusionMatrix noGraph
+
+continueGEvalCalculations SAMAP MAP = defineContinuation
+                                        averageC
+                                        id
+                                        noGraph
+
+continueGEvalCalculations SABIOF1 BIOF1 = defineContinuation countAgg f1MeasureOnCounts noGraph
+
+continueGEvalCalculations SABIOF1Labels BIOF1Labels = defineContinuation countAgg f1MeasureOnCounts noGraph
+
+continueGEvalCalculations SATokenAccuracy TokenAccuracy = defineContinuation
+                                                            hitsAndTotalsAgg
+                                                            (\(hits, total) -> hits /. total)
+                                                            noGraph
+  where
+    hitsAndTotalsAgg = CC.foldl (\(h1, t1) (h2, t2) -> (h1 + h2, t1 + t2)) (0, 0)
+
+continueGEvalCalculations SAMultiLabelLogLoss MultiLabelLogLoss = defineContinuation
+                                                                    averageC
+                                                                    id
+                                                                    noGraph
+
+
+
+-- | Build a sink that runs the aggregating conduit and then turns the
+-- aggregated value into a 'MetricOutput' (final score plus optional graph).
+defineContinuation :: (ConduitT c Void (ResourceT m) d) -- ^ a Conduit which aggregates all the combined values into
+                                                       -- a "total" value
+                      -> (d -> Double)                  -- ^ function to transform the "total" value into the final score
+                      -> (d -> Maybe GraphSeries)       -- ^ function to derive an optional graph from the "total" value
+                      -> ConduitT c Void (ResourceT m) MetricOutput
+defineContinuation aggregator finalStep generateGraph = do
+  v <- aggregator
+  return $ MetricOutput (SimpleRun $ finalStep v) (generateGraph v)
+
 -- | A type family to handle all the evaluation "context".
 --
 -- This is needed as for some metrics the output and the expected metric is enough