Start using dependent types

This commit is contained in:
Filip Gralinski 2019-09-21 15:27:57 +02:00 committed by Filip Graliński
parent 6e20e79f5b
commit d5d177bc8d
3 changed files with 139 additions and 89 deletions

View File

@ -50,7 +50,10 @@ module GEval.Core
import Debug.Trace
import Data.Singletons.TH
import GEval.Metric
import GEval.MetricsMechanics
import GEval.EvaluationScheme
import Data.Conduit
@ -509,12 +512,12 @@ gevalCoreOnSources Pearson _ = gevalCoreByCorrelationMeasure pearson
gevalCoreOnSources Spearman _ = gevalCoreByCorrelationMeasure spearman
gevalCoreOnSources (ProbabilisticMultiLabelFMeasure beta) _ = generalizedProbabilisticFMeasure beta
intoWords
SAProbabilisticMultiLabelFMeasure
(Right . (\(ProbList es) -> es) . parseIntoProbList)
where intoWords = Right . Data.Text.words
gevalCoreOnSources (ProbabilisticSoftFMeasure beta) _ = generalizedProbabilisticFMeasure beta
parseAnnotations
SAProbabilisticSoftFMeasure
parseObtainedAnnotations
-- and now more typical metrics, which:
@ -524,9 +527,9 @@ gevalCoreOnSources (ProbabilisticSoftFMeasure beta) _ = generalizedProbabilistic
-- 4) aggregate the results
-- 5) apply some final function on the aggregate
-- 6) create a graph using the aggregate (applicable only to some metrics)
gevalCoreOnSources Likelihood _ = gevalCoreWithoutInput doubleParser doubleParser itemLogLossError averageC logLossToLikehood noGraph
gevalCoreOnSources Likelihood _ = gevalCoreWithoutInput SALikelihood doubleParser itemLogLossError averageC logLossToLikehood noGraph
gevalCoreOnSources MultiLabelLikelihood _ = gevalCoreWithoutInput intoWords
gevalCoreOnSources MultiLabelLikelihood _ = gevalCoreWithoutInput SAMultiLabelLikelihood
(Right . parseIntoProbList)
(uncurry countLogLossOnProbList)
averageC
@ -535,18 +538,18 @@ gevalCoreOnSources MultiLabelLikelihood _ = gevalCoreWithoutInput intoWords
where
intoWords = Right . Data.Text.words
gevalCoreOnSources MSE _ = gevalCoreWithoutInput doubleParser doubleParser itemSquaredError averageC id noGraph
gevalCoreOnSources MSE _ = gevalCoreWithoutInput SAMSE doubleParser itemSquaredError averageC id noGraph
gevalCoreOnSources RMSE _ = gevalCoreWithoutInput doubleParser doubleParser itemSquaredError averageC (** 0.5) noGraph
gevalCoreOnSources RMSE _ = gevalCoreWithoutInput SARMSE doubleParser itemSquaredError averageC (** 0.5) noGraph
gevalCoreOnSources MAE _ = gevalCoreWithoutInput doubleParser doubleParser itemAbsoluteError averageC id noGraph
gevalCoreOnSources MAE _ = gevalCoreWithoutInput SAMAE doubleParser itemAbsoluteError averageC id noGraph
gevalCoreOnSources SMAPE _ = gevalCoreWithoutInput doubleParser doubleParser smape averageC (* 100.0) noGraph
gevalCoreOnSources SMAPE _ = gevalCoreWithoutInput SASMAPE doubleParser smape averageC (* 100.0) noGraph
where smape (exp, out) = (abs (exp-out)) `safeDoubleDiv` ((abs exp) + (abs out))
gevalCoreOnSources LogLoss _ = gevalCoreWithoutInput doubleParser doubleParser itemLogLossError averageC id noGraph
gevalCoreOnSources LogLoss _ = gevalCoreWithoutInput SALogLoss doubleParser itemLogLossError averageC id noGraph
gevalCoreOnSources BLEU _ = gevalCoreWithoutInput (Right . Prelude.map Prelude.words . DLS.splitOn "\t" . unpack) (Right . Prelude.words . unpack) bleuCombine bleuAgg bleuFinal noGraph
gevalCoreOnSources BLEU _ = gevalCoreWithoutInput SABLEU (Right . Prelude.words . unpack) bleuCombine bleuAgg bleuFinal noGraph
where bleuFinal (p1, p2, p3, p4, rl, l1, l2, l3, l4) = ((p1 /. l1) * (p2 /. l2) * (p3 /. l3) * (p4 /. l4)) ** 0.25 * (brevityPenalty l1 rl)
bleuCombine (refs, sen) = bleuStep refs sen
bleuAgg = CC.foldl bleuFuse (0, 0, 0, 0, 0, 0, 0, 0, 0)
@ -556,15 +559,15 @@ gevalCoreOnSources BLEU _ = gevalCoreWithoutInput (Right . Prelude.map Prelude.w
| c == 0 && r > 0 = 0.0
| otherwise = exp (1.0 - (r /. c))
gevalCoreOnSources GLEU _ = gevalCoreWithoutInput (Right . Prelude.map Prelude.words . DLS.splitOn "\t" . unpack) (Right . Prelude.words . unpack) gleuCombine gleuAgg gleuFinal noGraph
gevalCoreOnSources GLEU _ = gevalCoreWithoutInput SAGLEU (Right . Prelude.words . unpack) gleuCombine gleuAgg gleuFinal noGraph
where gleuFinal (m, t) = m /. t
gleuCombine (refs, sen) = gleuStep refs sen
gleuAgg = CC.foldl gleuFuse (0, 0)
gleuFuse (a1, a2) (b1, b2) = (a1+b1, a2+b2)
gevalCoreOnSources WER _ = gevalCoreWithoutInput (Right . Prelude.words . unpack) (Right . Prelude.words . unpack) (uncurry werStep) averageC id noGraph
gevalCoreOnSources WER _ = gevalCoreWithoutInput SAWER (Right . Prelude.words . unpack) (uncurry werStep) averageC id noGraph
gevalCoreOnSources Accuracy _ = gevalCoreWithoutInput (Right . strip) (Right . strip) hitOrMiss averageC id noGraph
gevalCoreOnSources Accuracy _ = gevalCoreWithoutInput SAAccuracy (Right . strip) hitOrMiss averageC id noGraph
where hitOrMiss (exp, got) =
-- first try to parse what we got as a probability distribution
-- (like the one used for Likelihood/LogLossHashed metric)
@ -589,12 +592,8 @@ gevalCoreOnSources Accuracy _ = gevalCoreWithoutInput (Right . strip) (Right . s
tryReadingAsFloat :: Text -> Maybe Float
tryReadingAsFloat = readMaybe . unpack
gevalCoreOnSources (FMeasure beta) _ = gevalCoreWithoutInput outParser outParser getCount countAgg (fMeasureOnCounts beta) noGraph
gevalCoreOnSources (FMeasure beta) _ = gevalCoreWithoutInput SAFMeasure outParser getCount countAgg (fMeasureOnCounts beta) noGraph
where outParser = detected <=< (getValue . TR.double)
expParser = expected <=< (getValue . TR.decimal)
expected 1 = Right True
expected 0 = Right False
expected _ = Left "expected 0 or 1"
-- output value could be a probability (for compatibility with other measures)
detected prob
| prob >= 0.0 && prob < detectionThreshold = Right False
@ -607,7 +606,7 @@ gevalCoreOnSources (FMeasure beta) _ = gevalCoreWithoutInput outParser outParser
getCount (False, True) = (0, 0, 1)
getCount (False, False) = (0, 0, 0)
gevalCoreOnSources (MacroFMeasure beta) _ = gevalCoreWithoutInput (Right . Just . strip) (Right . predicted . strip) getClassesInvolved gatherClassC macroAverageOnCounts noGraph
gevalCoreOnSources (MacroFMeasure beta) _ = gevalCoreWithoutInput SAMacroFMeasure (Right . predicted . strip) getClassesInvolved gatherClassC macroAverageOnCounts noGraph
where predicted got =
-- first try to parse what we got as a probability distribution
-- (like the one used for Likelihood/LogLossHashed metric)
@ -636,7 +635,7 @@ gevalCoreOnSources (MacroFMeasure beta) _ = gevalCoreWithoutInput (Right . Just
M.lookupDefault 0 c gotMap))
$ M.keys expectedMap) / (fromIntegral $ Prelude.length $ M.keys expectedMap)
gevalCoreOnSources (SoftFMeasure beta) _ = gevalCoreWithoutInput parseAnnotations
gevalCoreOnSources (SoftFMeasure beta) _ = gevalCoreWithoutInput SASoftFMeasure
parseObtainedAnnotations
getSoftCounts
countAgg
@ -646,7 +645,7 @@ gevalCoreOnSources (SoftFMeasure beta) _ = gevalCoreWithoutInput parseAnnotation
Prelude.length expected,
Prelude.length got)
gevalCoreOnSources (Soft2DFMeasure beta) _ = gevalCoreWithoutInput parseLabeledClippings
gevalCoreOnSources (Soft2DFMeasure beta) _ = gevalCoreWithoutInput SASoft2DFMeasure
parseLabeledClippings
count2DFScore
averageC
@ -659,30 +658,29 @@ gevalCoreOnSources (Soft2DFMeasure beta) _ = gevalCoreWithoutInput parseLabeledC
expArea = totalArea expected
gotArea = totalArea got
gevalCoreOnSources ClippEU _ = gevalCoreWithoutInput parseClippingSpecs parseClippings matchStep clippeuAgg finalStep noGraph
gevalCoreOnSources ClippEU _ = gevalCoreWithoutInput SAClippEU parseClippings matchStep clippeuAgg finalStep noGraph
where
parseClippings = controlledParse lineClippingsParser
parseClippingSpecs = controlledParse lineClippingSpecsParser
matchStep (clippingSpecs, clippings) = (maxMatch matchClippingToSpec clippingSpecs clippings,
Prelude.length clippingSpecs,
Prelude.length clippings)
clippeuAgg = CC.foldl countFolder (0, 0, 0)
finalStep counts = f2MeasureOnCounts counts
gevalCoreOnSources NMI _ = gevalCoreWithoutInput (Right . id) (Right . id) id (CC.foldl updateConfusionMatrix M.empty) normalizedMutualInformationFromConfusionMatrix noGraph
gevalCoreOnSources NMI _ = gevalCoreWithoutInput SANMI (Right . id) id (CC.foldl updateConfusionMatrix M.empty) normalizedMutualInformationFromConfusionMatrix noGraph
gevalCoreOnSources MAP _ = gevalCoreWithoutInput (Right . DLS.splitOn "\t" . unpack)
gevalCoreOnSources MAP _ = gevalCoreWithoutInput SAMAP
(Right . DLS.splitOn "\t" . unpack)
(\(e,g) -> calculateMAPForOneResult e g)
averageC
id
noGraph
gevalCoreOnSources BIOF1 _ = gevalCoreWithoutInput parseBioSequenceIntoEntities parseBioSequenceIntoEntities (uncurry gatherCountsForBIO) countAgg f1MeasureOnCounts noGraph
gevalCoreOnSources BIOF1 _ = gevalCoreWithoutInput SABIOF1 parseBioSequenceIntoEntities (uncurry gatherCountsForBIO) countAgg f1MeasureOnCounts noGraph
gevalCoreOnSources BIOF1Labels _ = gevalCoreWithoutInput parseBioSequenceIntoEntitiesWithoutNormalization parseBioSequenceIntoEntitiesWithoutNormalization (uncurry gatherCountsForBIO) countAgg f1MeasureOnCounts noGraph
gevalCoreOnSources BIOF1Labels _ = gevalCoreWithoutInput SABIOF1Labels parseBioSequenceIntoEntitiesWithoutNormalization (uncurry gatherCountsForBIO) countAgg f1MeasureOnCounts noGraph
gevalCoreOnSources TokenAccuracy _ = gevalCoreWithoutInput intoTokens
gevalCoreOnSources TokenAccuracy _ = gevalCoreWithoutInput SATokenAccuracy
intoTokens
countHitsAndTotals
hitsAndTotalsAgg
@ -703,7 +701,7 @@ gevalCoreOnSources TokenAccuracy _ = gevalCoreWithoutInput intoTokens
| otherwise = (h, t + 1)
hitsAndTotalsAgg = CC.foldl (\(h1, t1) (h2, t2) -> (h1 + h2, t1 + t2)) (0, 0)
gevalCoreOnSources MultiLabelLogLoss _ = gevalCoreWithoutInput intoWords
gevalCoreOnSources MultiLabelLogLoss _ = gevalCoreWithoutInput SAMultiLabelLogLoss
(Right . parseIntoProbList)
(uncurry countLogLossOnProbList)
averageC
@ -712,8 +710,6 @@ gevalCoreOnSources MultiLabelLogLoss _ = gevalCoreWithoutInput intoWords
where
intoWords = Right . Data.Text.words
doubleParser = getValue . TR.double
helperLogLossHashed nbOfBits finalStep expectedLineSource outLineSource =
gevalCore''' (ParserSpecWithoutInput (liftOp (Right . id)) (liftOp tentativeParser)) (\(lineNo, (t,d)) -> calculateLogLoss nbOfBits lineNo t (parseDistributionWrapper nbOfBits lineNo d)) averageC (finalStep . negate) noGraph (WithoutInput expectedLineSource outLineSource)
where -- Unfortunately, we're parsing the distribution twice. We need to
@ -724,7 +720,7 @@ helperLogLossHashed nbOfBits finalStep expectedLineSource outLineSource =
Right _ -> Right t
Left m -> Left m
generalizedProbabilisticFMeasure beta parseBareEntities parseEntities = gevalCoreWithoutInput parseBareEntities
generalizedProbabilisticFMeasure beta metric parseEntities = gevalCoreWithoutInput metric
parseEntities
getProbabilisticCounts
probabilisticSoftAgg
@ -754,7 +750,7 @@ gevalCoreByCorrelationMeasure :: (MonadUnliftIO m, MonadThrow m, MonadIO m) =>
LineSource (ResourceT m) -> -- ^ source to read the output
m (MetricOutput) -- ^ metric values for the output against the expected output
gevalCoreByCorrelationMeasure correlationFunction =
gevalCoreWithoutInput doubleParser doubleParser id correlationC finalStep noGraph
gevalCoreWithoutInput SAPearson doubleParser id correlationC finalStep noGraph
where correlationC = CC.foldl (flip (:)) []
finalStep pairs = correlationFunction $ V.fromList pairs
@ -770,9 +766,9 @@ skipLineNumber fun = fun . snd
-- | A helper function to run evaluation when the input is not needed to calculate the metric value.
gevalCoreWithoutInput :: (MonadUnliftIO m, MonadThrow m, MonadIO m)
=> (Text -> Either String a) -- ^ parser for values in the expected output
=> SAMetric t
-> (Text -> Either String b) -- ^ parser for values in the actual output
-> ((a, b) -> c) -- ^ function which combines parsed values into a single value
-> ((ParsedExpectedType t, b) -> c) -- ^ function which combines parsed values into a single value
-- (will be launched for each item, e.g. an error/cost function
-- could be calculated here)
-> (ConduitT c Void (ResourceT m) d) -- ^ a Conduit which aggregates all the combined values into
@ -782,8 +778,9 @@ gevalCoreWithoutInput :: (MonadUnliftIO m, MonadThrow m, MonadIO m)
-> LineSource (ResourceT m) -- ^ source to read the expected output
-> LineSource (ResourceT m) -- ^ source to read the output
-> m (MetricOutput) -- ^ metric values for the output against the expected output
gevalCoreWithoutInput expParser outParser itemStep aggregator finalStep generateGraph expectedLineStream outLineStream =
gevalCoreWithoutInput smetric outParser itemStep aggregator finalStep generateGraph expectedLineStream outLineStream =
gevalCoreWithoutInputOnItemTargets (liftOp expParser) (liftOp outParser) itemStep aggregator finalStep generateGraph expectedLineStream outLineStream
where expParser = expectedParser smetric
gevalCoreWithoutInputOnItemTargets :: (MonadUnliftIO m, MonadThrow m, MonadIO m)
=> (ItemTarget -> Either String a) -- ^ parser for values in the expected output
@ -961,15 +958,3 @@ itemLogLossError (exp, out)
| v >= 1.0 = 1.0
| v <= 0.0 = 0.0
| otherwise = v
-- | Unwrap the result of a numeric reader: the parsed number is accepted
-- only when nothing follows it or when the leftover text starts with a
-- tab character; any other leftover is reported as an error. A failed
-- read is passed through unchanged.
getValue :: Num a => Either String (a, Text) -> Either String a
getValue (Left s) = Left s
getValue (Right (x, remainder))
  | Data.Text.null remainder || Data.Text.head remainder == '\t' = Right x
  | otherwise = Left "number expected"
-- | Run a parser over a whole line with 'parseOnly', replacing whatever
-- error message the parser produced with a fixed, generic one.
controlledParse parser = relabel . parseOnly parser
  where
    relabel (Left _) = Left "cannot parse line"
    relabel (Right v) = Right v

View File

@ -4,6 +4,7 @@
{-# LANGUAGE GADTs #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE EmptyCase #-}
{-# LANGUAGE UndecidableInstances #-}
module GEval.MetricsMechanics
where
@ -16,12 +17,15 @@ import Data.Text
import Data.Text.Read as TR
import qualified Data.List.Split as DLS
import Data.Attoparsec.Text (parseOnly)
import Data.Maybe (catMaybes)
import Control.Monad ((<=<))
import GEval.Annotation (Annotation, parseAnnotations)
import GEval.Clippings (ClippingSpec, LabeledClipping, lineClippingsParser, lineClippingSpecsParser, lineLabeledClippingsParser)
import GEval.Annotation (Annotation, ObtainedAnnotation, parseAnnotations, parseObtainedAnnotations)
import GEval.Clippings (Clipping, ClippingSpec, LabeledClipping, lineClippingsParser, lineClippingSpecsParser, lineLabeledClippingsParser)
import GEval.BIO (TaggedEntity, parseBioSequenceIntoEntities, parseBioSequenceIntoEntitiesWithoutNormalization)
import GEval.LogLossHashed (parseWordSpecs, wordSpecToPair)
import GEval.ProbList (ProbList(..), parseIntoProbList, WordWithProb(..))
-- | Helper type so that singleton can be used.
-- | (The problem is that some metrics are parametrized by Double
@ -120,7 +124,7 @@ expectedParser SASoftFMeasure = parseAnnotations
expectedParser SAProbabilisticMultiLabelFMeasure = intoWords
expectedParser SAProbabilisticSoftFMeasure = parseAnnotations
expectedParser SASoft2DFMeasure = controlledParse lineLabeledClippingsParser
expectedParser SANMI = onlyStrip
expectedParser SANMI = Right . id
expectedParser SALogLossHashed = onlyStrip
expectedParser SALikelihoodHashed = onlyStrip
expectedParser SACharMatch = Right
@ -136,6 +140,49 @@ expectedParser SAMultiLabelFMeasure = intoWords
expectedParser SAMultiLabelLogLoss = intoWords
expectedParser SAMultiLabelLikelihood = intoWords
-- | Type of a single parsed line of the actual (system) output for a given
-- metric. Only metrics whose output representation differs from the
-- expected one are listed explicitly.
type family ParsedOutputType (t :: AMetric) :: * where
ParsedOutputType ABLEU = [Text]
ParsedOutputType AGLEU = [Text]
ParsedOutputType AClippEU = [Clipping]
ParsedOutputType AMacroFMeasure = Maybe Text
ParsedOutputType ASoftFMeasure = [ObtainedAnnotation]
ParsedOutputType AProbabilisticSoftFMeasure = [ObtainedAnnotation]
ParsedOutputType AProbabilisticMultiLabelFMeasure = [WordWithProb]
-- Catch-all: every other metric parses its output into the same type as
-- its expected value. Equations of a closed type family are tried top to
-- bottom, so this one must stay last.
ParsedOutputType t = ParsedExpectedType t
-- | Parser for a single line of the actual (system) output, selected by the
-- metric singleton. The result type is determined by 'ParsedOutputType';
-- parse failures are reported as @Left@ with a message.
--
-- Where the output is parsed exactly like the expected value, the clause
-- delegates to 'expectedParser' for the /same/ metric.
--
-- NOTE(review): for the multi-label log-loss/likelihood metrics this
-- function uses 'intoWords', whereas GEval.Core passes
-- @Right . parseIntoProbList@ as the output parser for these metrics --
-- confirm which is intended before switching callers to 'outputParser'.
outputParser :: SAMetric t -> Text -> Either String (ParsedOutputType t)
outputParser SARMSE = expectedParser SARMSE
outputParser SAMSE = expectedParser SAMSE
outputParser SAPearson = expectedParser SAPearson
outputParser SASpearman = expectedParser SASpearman
outputParser SABLEU = intoWords
outputParser SAGLEU = intoWords
outputParser SAWER = expectedParser SAWER
outputParser SAAccuracy = expectedParser SAAccuracy
outputParser SAClippEU = controlledParse lineClippingsParser
outputParser SAFMeasure = probToZeroOneParser
outputParser SAMacroFMeasure = Right . predictedParser . strip
outputParser SASoftFMeasure = parseObtainedAnnotations
outputParser SAProbabilisticMultiLabelFMeasure = (Right . (\(ProbList es) -> es) . parseIntoProbList)
outputParser SAProbabilisticSoftFMeasure = parseObtainedAnnotations
outputParser SASoft2DFMeasure = expectedParser SASoft2DFMeasure
outputParser SANMI = expectedParser SANMI
outputParser SALogLossHashed = onlyStrip
outputParser SALikelihoodHashed = onlyStrip
outputParser SACharMatch = Right
outputParser SAMAP = splitByTabs
outputParser SALogLoss = doubleParser
outputParser SALikelihood = doubleParser
outputParser SABIOF1 = parseBioSequenceIntoEntities
outputParser SABIOF1Labels = parseBioSequenceIntoEntitiesWithoutNormalization
outputParser SATokenAccuracy = intoWords
outputParser SAMAE = doubleParser
outputParser SASMAPE = doubleParser
outputParser SAMultiLabelFMeasure = intoWords
outputParser SAMultiLabelLogLoss = intoWords
outputParser SAMultiLabelLikelihood = intoWords
doubleParser = getValue . TR.double
intoWords = Right . Data.Text.words
@ -148,6 +195,16 @@ onlyStrip = Right . strip
justStrip = Right . Just . strip
-- | Work out the predicted class from an output line.
--
-- The line is first tried as a probability distribution (like the one
-- used for the Likelihood/LogLossHashed metrics): if that parses, the
-- entry that is maximal among the extracted pairs wins, or 'Nothing' is
-- returned when no pair could be extracted. If it does not parse as a
-- distribution, the line itself is taken as the class.
predictedParser got = choose (parseWordSpecs got)
  where
    choose (Left _) = Just got
    choose (Right wordSpecs) =
      best (catMaybes (Prelude.map wordSpecToPair wordSpecs))
    best [] = Nothing
    best scored = Just (snd (Prelude.maximum scored))
splitByTabs = Right . DLS.splitOn "\t" . unpack
zeroOneParser = expected <=< (getValue . TR.decimal)
@ -155,6 +212,14 @@ zeroOneParser = expected <=< (getValue . TR.decimal)
expected 0 = Right False
expected _ = Left "expected 0 or 1"
-- | Parse an output value as a number and turn it into a yes/no decision.
-- The output may be a probability (for compatibility with other
-- measures): values in [0, 1] below the threshold give @False@, the rest
-- of that range gives @True@, and anything outside it is an error.
probToZeroOneParser line = getValue (TR.double line) >>= classify
  where
    cutoff = 0.5
    classify p
      | inRange && p < cutoff = Right False
      | inRange = Right True
      | otherwise = Left "expected probability"
      where
        inRange = p >= 0.0 && p <= 1.0
getValue :: Num a => Either String (a, Text) -> Either String a
getValue (Right (x, reminder)) =
if Data.Text.null reminder || Data.Text.head reminder == '\t'

View File

@ -2,7 +2,7 @@
{-# LANGUAGE TypeFamilies #-}
module GEval.ProbList
(parseIntoProbList, selectByStandardThreshold, countLogLossOnProbList, ProbList(..))
(parseIntoProbList, selectByStandardThreshold, countLogLossOnProbList, ProbList(..), WordWithProb(..))
where
import qualified Data.Text as T