diff --git a/geval.cabal b/geval.cabal index d9b0811..6b439dd 100644 --- a/geval.cabal +++ b/geval.cabal @@ -59,6 +59,8 @@ library , GEval.Model , GEval.ModelTraining , GEval.MatchingSpecification + , GEval.Confidence + , Data.Conduit.Utils , Paths_geval build-depends: base >= 4.7 && < 5 , cond diff --git a/src/Data/Conduit/AutoDecompress.hs b/src/Data/Conduit/AutoDecompress.hs index b959d6f..24a0af3 100644 --- a/src/Data/Conduit/AutoDecompress.hs +++ b/src/Data/Conduit/AutoDecompress.hs @@ -7,6 +7,7 @@ module Data.Conduit.AutoDecompress import Data.Conduit import Data.Conduit.Combinators +import Data.Conduit.Utils import Data.ByteString import Data.Conduit.Zlib import Data.Word8 @@ -40,6 +41,3 @@ lookAtMagicNumbers (31, 139) = ungzip lookAtMagicNumbers (66, 90) = BZ.bunzip2 lookAtMagicNumbers (253, 55) = XZ.decompress Nothing lookAtMagicNumbers _ = doNothing - -doNothing :: Monad m => ConduitT a a m () -doNothing = Data.Conduit.Combinators.filter (const True) diff --git a/src/Data/Conduit/Utils.hs b/src/Data/Conduit/Utils.hs new file mode 100644 index 0000000..b5c950c --- /dev/null +++ b/src/Data/Conduit/Utils.hs @@ -0,0 +1,17 @@ + +module Data.Conduit.Utils + (gobbleAndDo, + doNothing) + where + +import Data.Conduit +import Data.Conduit.Combinators + + +doNothing :: Monad m => ConduitT a a m () +doNothing = Data.Conduit.Combinators.filter (const True) + +gobbleAndDo :: Monad m => ([a] -> [b]) -> ConduitT a b m () +gobbleAndDo fun = do + l <- sinkList + yieldMany $ fun l diff --git a/src/GEval/Confidence.hs b/src/GEval/Confidence.hs new file mode 100644 index 0000000..67fd8c6 --- /dev/null +++ b/src/GEval/Confidence.hs @@ -0,0 +1,13 @@ +module GEval.Confidence + (totalLineConfidence) + where + +import Data.Text +import GEval.ProbList +import GEval.Probability + +totalLineConfidence :: Text -> Double +totalLineConfidence t = case parseIntoProbList t of + ProbList [] -> 1.0 + ProbList wordWithProbs -> (product $ Prelude.map (\(WordWithProb _ p) -> getP p) wordWithProbs) ** (1/l) + where l = fromIntegral $ Prelude.length wordWithProbs diff --git a/src/GEval/Core.hs b/src/GEval/Core.hs index a13210b..114daf9 100644 --- a/src/GEval/Core.hs +++ b/src/GEval/Core.hs @@ -66,6 +66,7 @@ import GEval.Model (ModelType) import Data.Conduit import Data.Conduit.Combinators as CC +import Data.Conduit.Utils import qualified Data.Conduit.Text as CT import Control.Monad.Trans.Resource import qualified Data.Conduit.List as CL @@ -133,6 +134,7 @@ isInputNeeded (EvaluationScheme _ ops) = hasFiltering ops hasFiltering :: [PreprocessingOperation] -> Bool hasFiltering [] = False hasFiltering ((FeatureFilter _):_) = True +hasFiltering ((TopConfidence _):_) = True hasFiltering (_:ops) = hasFiltering ops -- | Could output be preprocessable @@ -306,7 +308,7 @@ extensionsHandled = ["tsv", "jsonl"] data LineSource m = LineSource (ConduitT () Text m ()) (Text -> ItemTarget) (Text -> Text) SourceSpec Word32 data LineSourcesSpecification m = LineSourcesSpecification { - lineSourcesFilter :: Filter, + lineSourcesFilter :: GeneralizedFilter, lineSourcesInputSource :: LineSource m, lineSourcesExpectedSource :: LineSource m, lineSourcesOutputSource :: LineSource m } @@ -359,7 +361,7 @@ addSchemeSpecifics :: EvaluationScheme -> DataSource -> DataSource addSchemeSpecifics scheme@(EvaluationScheme metric _) dataSource = dataSource { dataSourceChallengeData = (dataSourceChallengeData dataSource) { - challengeDataSourceFilter = getFilterForScheme (challengeDataSourceInHeader $ dataSourceChallengeData dataSource) scheme, + challengeDataSourceFilter = getGeneralizedFilterForScheme (challengeDataSourceInHeader $ dataSourceChallengeData dataSource) scheme, challengeDataSourceOutPreprocess = outPreprocess, challengeDataSourceInPreprocess = inPreprocess }} @@ -1039,12 +1041,14 @@ defineContinuation aggregator finalStep generateGraph = do return $ MetricOutput (SimpleRun $ finalStep v) (generateGraph v) fromSpecificationToWithoutInput :: LineSourcesSpecification (ResourceT m) -> WithoutInput m e o -fromSpecificationToWithoutInput lsSpec = case lineSourcesFilter lsSpec of - NoFilter -> WithoutInput expectedSource outSource - theFilter -> WithoutInputButFiltered theFilter inputSource expectedSource outSource +fromSpecificationToWithoutInput lsSpec = + if hasNoFilter theFilter + then WithoutInput expectedSource outSource + else WithoutInputButFiltered theFilter inputSource expectedSource outSource where expectedSource = lineSourcesExpectedSource lsSpec outSource = lineSourcesOutputSource lsSpec inputSource = lineSourcesInputSource lsSpec + theFilter = (lineSourcesFilter lsSpec) fromSpecificationToWithInput :: LineSourcesSpecification (ResourceT m) -> WithInput m i e o fromSpecificationToWithInput lsSpec = WithInput theFilter inpSource expectedSource outSource @@ -1070,7 +1074,7 @@ class EvaluationContext ctxt m where checkStepM :: ((Word32, ParsedRecord ctxt) -> (ResourceT m) c) -> (Word32, WrappedParsedRecord ctxt) -> (ResourceT m) (Maybe c) data WithoutInput m e o = WithoutInput (LineSource (ResourceT m)) (LineSource (ResourceT m)) - | WithoutInputButFiltered Filter (LineSource (ResourceT m)) (LineSource (ResourceT m)) (LineSource (ResourceT m)) + | WithoutInputButFiltered GeneralizedFilter (LineSource (ResourceT m)) (LineSource (ResourceT m)) (LineSource (ResourceT m)) instance (MonadUnliftIO m, MonadIO m, MonadThrow m) => EvaluationContext (WithoutInput m e o) m where data ParserSpec (WithoutInput m e o) = ParserSpecWithoutInput (ItemTarget -> Either String e) (ItemTarget -> Either String o) @@ -1091,7 +1095,8 @@ instance (MonadUnliftIO m, MonadIO m, MonadThrow m) => EvaluationContext (Withou <*> ZipSource (getZipSource $ (,) <$> ZipSource (sourceItems expectedLineSource) <*> ZipSource (sourceItems outLineSource))) - .| CC.filter (applyFilter theFilter) + .| CC.filter (applyFilter $ generalizedFilterFilter theFilter) + .| topperConduit (generalizedFilterTopper theFilter) .| CC.map (\(TargetRecord _ y z) -> WrappedParsedRecordWithoutInput (applyParser expParser y) (applyParser outParser z)) checkStep _ step (lineNo, WrappedParsedRecordWithoutInput (Got expectedItem) (Got outItem)) = Just $ step (lineNo, ParsedRecordWithoutInput expectedItem outItem) @@ -1110,7 +1115,7 @@ instance (MonadUnliftIO m, MonadIO m, MonadThrow m) => EvaluationContext (Withou -data WithInput m i e o = WithInput Filter (LineSource (ResourceT m)) (LineSource (ResourceT m)) (LineSource (ResourceT m)) +data WithInput m i e o = WithInput GeneralizedFilter (LineSource (ResourceT m)) (LineSource (ResourceT m)) (LineSource (ResourceT m)) getInputFilePath :: WithInput m i e o -> SourceSpec getInputFilePath (WithInput _ (LineSource _ _ _ inputFilePath _) _ _) = inputFilePath @@ -1128,7 +1133,8 @@ instance (MonadUnliftIO m, MonadIO m, MonadThrow m) => EvaluationContext (WithIn <*> ZipSource (getZipSource $ (,) <$> ZipSource (sourceItems expectedLineSource) <*> ZipSource (sourceItems outLineSource))) - .| CC.filter (applyFilter theFilter) + .| CC.filter (applyFilter $ generalizedFilterFilter theFilter) + .| topperConduit (generalizedFilterTopper theFilter) .| CC.map (\(TargetRecord x y z) -> WrappedParsedRecordWithInput (applyParser inpParser x) (applyParser expParser y) (applyParser outParser z)) checkStep _ step (lineNo, WrappedParsedRecordWithInput (Got inputItem) (Got expectedItem) (Got outItem)) = Just $ step (lineNo, ParsedRecordWithInput inputItem expectedItem outItem) checkStep _ _ (lineNo, WrappedParsedRecordWithInput _ _ (Wrong m)) = throw $ UnexpectedData lineNo m @@ -1157,7 +1163,7 @@ threeLineSource (WithInput theFilter inputLineSource expectedLineSource outLineS <*> (ZipSource $ getZipSource $ (,) <$> ZipSource (linesAsItems expectedLineSource) <*> ZipSource (linesAsItems outLineSource))) - .| (CC.filter (applyFilterToSourceItems theFilter)) + .| (CC.filter (applyFilterToSourceItems $ generalizedFilterFilter theFilter)) .| (CC.map (\(x, (y,z)) -> WrappedParsedRecordWithInput x y z)) averageC :: MonadResource m => ConduitT Double Void m Double @@ -1201,3 +1207,15 @@ applyFilterToSourceItems filter (Got x, (Got y, Got z)) = applyFilter filter tar (Got (RawItemTarget y)) (Got (RawItemTarget z)) applyFilterToSourceItems _ special = True + +topperConduit NoTopper = doNothing +topperConduit (Topper percentage scorer) = gobbleAndDo (findTop percentage scorer) + +findTop :: Double -> (TargetRecord -> Double) -> [TargetRecord] -> [TargetRecord] +findTop percentage scorer records = + Prelude.map fst + $ Prelude.take n + $ sortBy (\a b -> snd b `compare` snd a) + $ Prelude.map (\r -> (r, scorer r)) records + where n = ceiling ((percentage * (fromIntegral $ (Prelude.length records) - 1)) / 100.0) + -- -1 due to the special entry Done diff --git a/src/GEval/DataSource.hs b/src/GEval/DataSource.hs index 309aca6..1752c30 100644 --- a/src/GEval/DataSource.hs +++ b/src/GEval/DataSource.hs @@ -5,8 +5,11 @@ module GEval.DataSource DataSource(..), TargetRecord(..), Filter(..), + Topper(..), + GeneralizedFilter(..), noFilter, - getFilterForScheme, + hasNoFilter, + getGeneralizedFilterForScheme, applyFilter) where @@ -15,6 +18,7 @@ import GEval.Selector (ItemTarget(..), TargetRecord(..)) import GEval.FeatureExtractor (getFeatures) import GEval.BlackBoxDebugging import GEval.EvaluationScheme +import GEval.Confidence (totalLineConfidence) import Data.Text @@ -25,19 +29,45 @@ import GEval.Selector import qualified Data.Set as S data Filter = NoFilter | FilterByFeatures (Maybe TabularHeader) (S.Set String) +data Topper = NoTopper | Topper Double (TargetRecord -> Double) -noFilter :: Filter -noFilter = NoFilter +data GeneralizedFilter = GeneralizedFilter { + generalizedFilterFilter :: Filter, + generalizedFilterTopper :: Topper } + +noFilter :: GeneralizedFilter +noFilter = GeneralizedFilter { + generalizedFilterFilter = NoFilter, + generalizedFilterTopper = NoTopper } + +hasNoFilter :: GeneralizedFilter -> Bool +hasNoFilter gfilter = case generalizedFilterFilter gfilter of + NoFilter -> case generalizedFilterTopper gfilter of + NoTopper -> True + _ -> False + _ -> False applyFilter :: Filter -> TargetRecord -> Bool applyFilter NoFilter _ = True applyFilter (FilterByFeatures mInHeader featureSpec) tR = applyFeatureFilter mInHeader featureSpec tR +getGeneralizedFilterForScheme :: Maybe TabularHeader -> EvaluationScheme -> GeneralizedFilter +getGeneralizedFilterForScheme mTabHeader scheme = GeneralizedFilter { + generalizedFilterFilter = getFilterForScheme mTabHeader scheme, + generalizedFilterTopper = getTopperForScheme scheme + } + getFilterForScheme :: Maybe TabularHeader -> EvaluationScheme -> Filter getFilterForScheme mTabHeader (EvaluationScheme _ ops) = case findFilter ops of [] -> NoFilter fs -> FilterByFeatures mTabHeader (S.fromList $ Prelude.map (unpack . fixIndex) fs) +getTopperForScheme :: EvaluationScheme -> Topper +getTopperForScheme (EvaluationScheme _ ops) = case findTopper ops of + [] -> NoTopper + [topper] -> topper + _ -> error "only one topper expected" + fixIndex = replace "[" "<" . replace "]" ">" findFilter :: [PreprocessingOperation] -> [Text] @@ -45,6 +75,18 @@ findFilter [] = [] findFilter ((FeatureFilter f):ops) = (f:(findFilter ops)) findFilter (_:ops) = findFilter ops +findTopper :: [PreprocessingOperation] -> [Topper] +findTopper [] = [] +findTopper ((TopConfidence percentage):ops) = [Topper percentage totalRecordConfidence] +findTopper (_:ops) = findTopper ops + +totalRecordConfidence :: TargetRecord -> Double +totalRecordConfidence (TargetRecord _ _ (Got out)) = case out of + RawItemTarget t -> totalLineConfidence t +totalRecordConfidence (TargetRecord _ _ (Wrong _)) = 9999.0 +totalRecordConfidence (TargetRecord _ _ Done) = -9999 + + applyFeatureFilter :: Maybe TabularHeader -> S.Set String -> TargetRecord -> Bool applyFeatureFilter mInHeader featureSpec tR = featureSpec `S.isSubsetOf` (S.fromList $ Prelude.map show fs) @@ -71,7 +113,7 @@ data ChallengeDataSource = ChallengeDataSource { challengeDataSourceSelector :: Maybe Selector, challengeDataSourceOutPreprocess :: Text -> Text, challengeDataSourceInPreprocess :: Text -> Text, - challengeDataSourceFilter :: Filter, + challengeDataSourceFilter :: GeneralizedFilter, challengeDataSourceInHeader :: Maybe TabularHeader, challengeDataSourceOutHeader :: Maybe TabularHeader, -- whether the data will be shown preprocessed (not only diff --git a/src/GEval/EvaluationScheme.hs b/src/GEval/EvaluationScheme.hs index 4d6e79d..42027ca 100644 --- a/src/GEval/EvaluationScheme.hs +++ b/src/GEval/EvaluationScheme.hs @@ -36,6 +36,7 @@ data PreprocessingOperation = RegexpMatch Regex | SetPriority Int | RegexpSubstition Regex Text | FeatureFilter Text + | TopConfidence Double deriving (Eq) leftParameterBracket :: Char @@ -67,6 +68,7 @@ readOps ('N':theRest) = handleParametrizedOp (SetName . pack) theRest readOps ('P':theRest) = handleParametrizedOp (SetPriority . read) theRest readOps ('s':theRest) = handleParametrizedBinaryOp (\a b -> RegexpSubstition (compile (BSU.fromString a) []) (pack b)) theRest readOps ('f':theRest) = handleParametrizedOp (FeatureFilter . pack) theRest +readOps ('p':theRest) = handleParametrizedOp (TopConfidence . read) theRest -- this is not the right way to do this, but try catch at least unknown flags readOps t@(c:_) = if isLetter c then throw $ UnknownFlags t @@ -101,7 +103,7 @@ parseParameter [] = (Nothing, []) parseParameter t@(fChar:theRest) = if fChar == leftParameterBracket then case break (== rightParameterBracket) theRest of - (s, []) -> throw $ UnknownFlags t + (_, []) -> throw $ UnknownFlags t (param, (_:theRest')) -> (Just param, theRest') else throw $ UnknownFlags t @@ -145,6 +147,7 @@ instance Show PreprocessingOperation where show (SetPriority p) = parametrizedOperation "P" (show p) show (RegexpSubstition (Regex _ regexp) s) = "s" ++ (formatParameter $ BSU.toString regexp) ++ (formatParameter $ unpack s) show (FeatureFilter featureSpec) = parametrizedOperation "f" (unpack featureSpec) + show (TopConfidence percentage) = parametrizedOperation "p" (show percentage) applySubstitution :: Regex -> Text -> Text -> Text applySubstitution r substitution t = @@ -181,3 +184,4 @@ applyPreprocessingOperation (SetName _) = id applyPreprocessingOperation (SetPriority _) = id applyPreprocessingOperation (RegexpSubstition regex substition) = applySubstitution regex substition applyPreprocessingOperation (FeatureFilter _) = id +applyPreprocessingOperation (TopConfidence _) = id diff --git a/src/GEval/LineByLine.hs b/src/GEval/LineByLine.hs index 7a93da4..3d96a23 100644 --- a/src/GEval/LineByLine.hs +++ b/src/GEval/LineByLine.hs @@ -33,7 +33,7 @@ import Text.Tokenizer import System.IO -import Data.Conduit.AutoDecompress (doNothing) +import Data.Conduit.Utils import Data.Conduit import qualified Data.Conduit.List as CL @@ -480,11 +480,6 @@ runLineByLineGeneralized ordering spec consum = do compareScores (LineRecord _ _ _ _ s1) (LineRecord _ _ _ _ s2) = s1 `compare` s2 mReferences = Nothing -gobbleAndDo :: Monad m => ([a] -> [b]) -> ConduitT a b m () -gobbleAndDo fun = do - l <- CC.sinkList - CC.yieldMany $ fun l - runDiff :: ResultOrdering -> Maybe String -> FilePath -> GEvalSpecification -> BlackBoxDebuggingOptions -> IO () runDiff ordering featureFilter otherOut spec bbdo = do mInHeader <- readHeaderFileWrapper $ getInHeader spec diff --git a/test/Spec.hs b/test/Spec.hs index e773a31..67af1a2 100644 --- a/test/Spec.hs +++ b/test/Spec.hs @@ -31,6 +31,7 @@ import GEval.FeatureExtractor import GEval.Selector import GEval.CreateChallenge import GEval.Validation +import GEval.Confidence import Data.Conduit.Bootstrap import Data.Map.Strict @@ -373,6 +374,8 @@ main = hspec $ do runGEvalTest "multilabel-f1-ie-fuzzy-harden" `shouldReturnAlmost` 0.555555555 it "information extraction" $ do runGEvalTest "multilabel-f1-ie-probs" `shouldReturnAlmost` 0.1111111111 + it "with top confident" $ do + runGEvalTest "top-confidence" `shouldReturnAlmost` 0.857142857142857 describe "Mean/MultiLabel-F" $ do it "simple" $ do runGEvalTest "mean-multilabel-f1-simple" `shouldReturnAlmost` 0.5 @@ -801,6 +804,11 @@ main = hspec $ do loess (DVU.fromList [0.2, 0.6, 1.0]) (DVU.fromList [-0.6, 0.2, 1.0]) 0.4 `shouldBeAlmost` (-0.2) + describe "Confidence for a line" $ do + it "simple case" $ do + totalLineConfidence "foo:0.6" `shouldBeAlmost` 0.6 + it "more than one" $ do + totalLineConfidence "foo:1.0 bar:0.6 bar:0.5" `shouldBeAlmost` 0.669432952768766 describe "Calibration" $ do it "empty list" $ do calibration [] [] `shouldBeAlmost` 1.0 diff --git a/test/top-confidence/top-confidence-solution/test-A/out.tsv b/test/top-confidence/top-confidence-solution/test-A/out.tsv new file mode 100644 index 0000000..36d3962 --- /dev/null +++ b/test/top-confidence/top-confidence-solution/test-A/out.tsv @@ -0,0 +1,5 @@ +foo:0.96 qqq:0.001 bar:0.95 +foo:0.7 baz:0.6 +bar:0.7 +tok:0.9999 +baz:0.8 diff --git a/test/top-confidence/top-confidence/config.txt b/test/top-confidence/top-confidence/config.txt new file mode 100644 index 0000000..3d73c4d --- /dev/null +++ b/test/top-confidence/top-confidence/config.txt @@ -0,0 +1 @@ +--metric MultiLabel-F1:fs<>p<50> diff --git a/test/top-confidence/top-confidence/test-A/expected.tsv b/test/top-confidence/top-confidence/test-A/expected.tsv new file mode 100644 index 0000000..90d2e81 --- /dev/null +++ b/test/top-confidence/top-confidence/test-A/expected.tsv @@ -0,0 +1,5 @@ +foo bar baq +foo baz baq baq baq baq baq +bat +kwa kwa kwa +baz diff --git a/test/top-confidence/top-confidence/test-A/in.tsv b/test/top-confidence/top-confidence/test-A/in.tsv new file mode 100644 index 0000000..af2fd17 --- /dev/null +++ b/test/top-confidence/top-confidence/test-A/in.tsv @@ -0,0 +1,5 @@ +yep +yep +yep +no +yep