diff --git a/geval.cabal b/geval.cabal index c275f2d..d8b8c51 100644 --- a/geval.cabal +++ b/geval.cabal @@ -55,6 +55,8 @@ library , GEval.Formatting , Data.Conduit.Header , GEval.DataSource + , GEval.Model + , GEval.ModelTraining , Paths_geval build-depends: base >= 4.7 && < 5 , cond @@ -111,6 +113,7 @@ library , ordered-containers , random , rainbow + , yaml default-language: Haskell2010 executable geval diff --git a/src/GEval/Core.hs b/src/GEval/Core.hs index 9d6aa19..74610c1 100644 --- a/src/GEval/Core.hs +++ b/src/GEval/Core.hs @@ -63,6 +63,7 @@ import Data.Singletons.TH import GEval.Metric import GEval.MetricsMechanics import GEval.EvaluationScheme +import GEval.Model (ModelType) import Data.Conduit import Data.Conduit.Combinators as CC @@ -227,6 +228,8 @@ data GEvalSpecialCommand = Init | PrintVersion | JustTokenize | Submit | Validate | ListMetrics | OracleItemBased + | TrainModel ModelType + | Infer FilePath data ResultOrdering = KeepTheOriginalOrder | FirstTheWorst | FirstTheBest diff --git a/src/GEval/EvaluationScheme.hs b/src/GEval/EvaluationScheme.hs index dd44eb5..135f9d8 100644 --- a/src/GEval/EvaluationScheme.hs +++ b/src/GEval/EvaluationScheme.hs @@ -4,6 +4,7 @@ module GEval.EvaluationScheme applyPreprocessingOperations, evaluationSchemeName, evaluationSchemePriority, + getRegexpMatch, PreprocessingOperation(..)) where @@ -59,6 +60,12 @@ readOps ('s':theRest) = handleParametrizedBinaryOp (\a b -> RegexpSubstition (fr readOps ('f':theRest) = handleParametrizedOp (FeatureFilter . pack) theRest readOps s = ([], s) +getRegexpMatch :: EvaluationScheme -> Maybe Regex +getRegexpMatch (EvaluationScheme _ ops) = getRegexpMatch' ops + where getRegexpMatch' [] = Nothing + getRegexpMatch' ((RegexpMatch regex):_) = Just regex + getRegexpMatch' (_:ops) = getRegexpMatch' ops + handleParametrizedOp :: (String -> PreprocessingOperation) -> String -> ([PreprocessingOperation], String) handleParametrizedOp constructor theRest = case parseParameter theRest of diff --git a/src/GEval/LineByLine.hs b/src/GEval/LineByLine.hs index eefcb29..1326444 100644 --- a/src/GEval/LineByLine.hs +++ b/src/GEval/LineByLine.hs @@ -21,7 +21,8 @@ module GEval.LineByLine ResultOrdering(..), justTokenize, worstFeaturesPipeline, - runOracleItemBased + runOracleItemBased, + runMultiOutputGeneralizedForEvaluationScheme ) where import GEval.Core @@ -516,7 +517,11 @@ runOracleItemBased spec = runMultiOutputGeneralized spec consum metric = gesMainMetric spec runMultiOutputGeneralized :: GEvalSpecification -> ConduitT [LineRecord] Void (ResourceT IO) () -> IO () -runMultiOutputGeneralized spec consum = do +runMultiOutputGeneralized spec consum = runMultiOutputGeneralizedForEvaluationScheme spec mainScheme consum + where mainScheme = gesMainScheme spec + +runMultiOutputGeneralizedForEvaluationScheme :: GEvalSpecification -> EvaluationScheme -> ConduitT [LineRecord] Void (ResourceT IO) () -> IO () +runMultiOutputGeneralizedForEvaluationScheme spec scheme@(EvaluationScheme metric _) consum = do dataSource' <- checkAndGetDataSource True spec let dataSource = addSchemeSpecifics scheme dataSource' let (Just altOuts) = gesAltOutFiles spec @@ -530,8 +535,6 @@ runMultiOutputGeneralized spec consum = do dataSourceOut = s}) sourceSpecs runResourceT $ runConduit $ (sequenceSources sources .| consum) - where metric = gesMainMetric spec - scheme = gesMainScheme spec runMostWorseningFeatures :: ResultOrdering -> FilePath -> GEvalSpecification -> BlackBoxDebuggingOptions -> IO () runMostWorseningFeatures ordering otherOut spec bbdo = do diff --git a/src/GEval/Model.hs b/src/GEval/Model.hs new file mode 100644 index 0000000..183a149 --- /dev/null +++ b/src/GEval/Model.hs @@ -0,0 +1,48 @@ +{-# LANGUAGE OverloadedStrings #-} + +module GEval.Model( + ModelType(..), + SimpleMixEnsembleModel(..), + SimpleMixEnsembleModelRule(..)) + where + +import qualified Data.ByteString.UTF8 as BSU +import Text.Regex.PCRE.Heavy +import Text.Regex.PCRE.Light.Base (Regex(..)) + +import qualified Data.Vector as V + +import Data.Either (fromRight) + +import Data.Yaml + +data ModelType = SimpleMixEnsemble + deriving (Show, Read, Eq) + +data SimpleMixEnsembleModel = SimpleMixEnsembleModel [SimpleMixEnsembleModelRule] + deriving (Show) + +data SimpleMixEnsembleModelRule = SimpleMixEnsembleModelRule { + simpleMixEnsembleModelRuleRegex :: Regex, + simpleMixEnsembleModelRuleOut :: String } + deriving (Eq, Show) + +instance ToJSON SimpleMixEnsembleModelRule where + toJSON rule = object ["regex" .= (formatRegexp $ simpleMixEnsembleModelRuleRegex rule), + "out" .= simpleMixEnsembleModelRuleOut rule] + +instance FromJSON SimpleMixEnsembleModelRule where + parseJSON = withObject "SimpleMixEnsembleModelRule" $ \v -> SimpleMixEnsembleModelRule + <$> (toRegexp <$> (v .: "regex")) + <*> v .: "out" + +instance FromJSON SimpleMixEnsembleModel where + parseJSON = withArray "SimpleMixEnsembleModel" $ \arr -> + SimpleMixEnsembleModel <$> mapM parseJSON (V.toList arr) + +instance ToJSON SimpleMixEnsembleModel where + toJSON (SimpleMixEnsembleModel rules) = toJSON rules + +formatRegexp (Regex _ regexp) = BSU.toString regexp + +toRegexp = (fromRight undefined) . ((flip compileM) []) . BSU.fromString diff --git a/src/GEval/ModelTraining.hs b/src/GEval/ModelTraining.hs new file mode 100644 index 0000000..31dc1c7 --- /dev/null +++ b/src/GEval/ModelTraining.hs @@ -0,0 +1,78 @@ +{-# LANGUAGE ScopedTypeVariables #-} + +module GEval.ModelTraining( + trainModel, + infer) + where + +import GEval.Core +import GEval.Model +import GEval.Common + +import System.FilePath +import Data.Either +import Data.Maybe + +import Data.List (maximumBy, transpose) + +import Data.Conduit.Header (processHeader) +import Text.Regex.PCRE.Heavy +import Text.Regex.PCRE.Light.Base (Regex(..)) +import Data.Conduit +import Data.Conduit.AutoDecompress (autoDecompress) +import qualified Data.Conduit.Text as CT +import qualified Data.Conduit.Combinators as CC +import Control.Monad.Trans.Resource +import qualified Data.Text as T +import qualified Data.Map as M + +import qualified Data.ByteString.Lazy as BSL +import qualified Data.ByteString as BS + +import qualified Data.Yaml as Y + +import Data.Conduit.SmartSource +import GEval.EvaluationScheme + +infer :: FilePath -> GEvalSpecification -> IO () +infer modelPath spec = do + model :: SimpleMixEnsembleModel <- Y.decodeFileThrow modelPath + outs <- fromJust <$> checkMultipleOuts spec + outSpecs' <- mapM (getSmartSourceSpec ((gesOutDirectory spec) (gesTestName spec)) "out.tsv") outs + let outSpecs = rights outSpecs' + mHeader <- readHeaderFileWrapper headerSpec + let bareFilesMap = M.fromList $ zip (map (\(FilePathSpec fp) -> takeFileName fp) outSpecs) [0..] + + let sources = map (createSource mHeader) outSpecs + runResourceT $ runConduit $ (sequenceSources sources .| (CC.map (runModel model bareFilesMap)) .| CC.unlines .| CC.encodeUtf8 .| CC.stdout) + where createSource mHeader spec = (smartSource spec) .| autoDecompress .| CT.decodeUtf8Lenient .| CT.lines .| processHeader mHeader + headerSpec = gesOutHeader spec + +runModel :: SimpleMixEnsembleModel -> M.Map FilePath Int -> [T.Text] -> T.Text +runModel (SimpleMixEnsembleModel rules) bareFilesMap ts = T.intercalate (T.pack " ") $ map applyRule rules + where applyRule rule = applyRegex (simpleMixEnsembleModelRuleRegex rule) (ts !! (bareFilesMap M.! (simpleMixEnsembleModelRuleOut rule))) + applyRegex regex = T.concat . (map fst) . (scan regex) + + +trainModel :: ModelType -> GEvalSpecification -> IO () +trainModel SimpleMixEnsemble spec = do + vals' <- geval spec + let vals = map (\(s, val) -> (s, map getMetricValue val)) vals' + + let byMetric = filter (\(s, _) -> isEligibleForSimpleMix s) $ zip (gesMetrics spec) $ transpose $ map (\(_, scores) -> scores) vals + + let sources = map (\(FilePathSpec fp, _) -> takeFileName fp) vals + + let bestOnes = map ((\(scheme, ix) -> SimpleMixEnsembleModelRule { + simpleMixEnsembleModelRuleRegex = fromJust $ getRegexpMatch scheme, + simpleMixEnsembleModelRuleOut = sources !! ix}) . selectBest) byMetric + + let model = SimpleMixEnsembleModel bestOnes + let modelContent = Y.encode model + BS.putStr modelContent + where selectBest (scheme, scores) = (scheme, bestIx) + where bestIx = fst $ maximumBy (\(_, SimpleRun s) (_, SimpleRun s') -> s `compare` s') $ zip [0..] scores + +isEligibleForSimpleMix :: EvaluationScheme -> Bool +isEligibleForSimpleMix scheme = + evaluationSchemePriority scheme == 2 && isJust (getRegexpMatch scheme) diff --git a/src/GEval/OptionsParser.hs b/src/GEval/OptionsParser.hs index d4be11e..fdddad9 100644 --- a/src/GEval/OptionsParser.hs +++ b/src/GEval/OptionsParser.hs @@ -35,6 +35,8 @@ import GEval.Submit (submit) import GEval.BlackBoxDebugging import GEval.Selector import GEval.Validation +import GEval.Model +import GEval.ModelTraining import Data.List (intercalate) @@ -99,6 +101,16 @@ optionsParser = GEvalOptions (flag' OracleItemBased ( long "oracle-item-based" <> help "Generate the best possible output considering outputs given by --out-file and --alt-out-file options (and peeking into the expected file).")) + <|> + (TrainModel <$> option auto + ( long "train" + <> metavar "MODEL-TYPE" + <> help "Train a model")) + <|> + (Infer <$> strOption + ( long "infer" + <> metavar "MODEL-PATH" + <> help "Infer from a model")) ) <*> ((flag' FirstTheWorst @@ -387,6 +399,12 @@ runGEval''' (Just ListMetrics) _ _ _ _ _ _ = do runGEval''' (Just OracleItemBased) _ _ spec _ _ _ = do runOracleItemBased spec return Nothing +runGEval''' (Just (TrainModel modelType)) _ _ spec _ _ _ = do + trainModel modelType spec + return Nothing +runGEval''' (Just (Infer modelPath)) _ _ spec _ _ _ = do + infer modelPath spec + return Nothing getGraphFilename :: Int -> FilePath -> FilePath getGraphFilename 0 fp = fp diff --git a/stack.yaml b/stack.yaml index 3805ab6..3784d5e 100644 --- a/stack.yaml +++ b/stack.yaml @@ -1,5 +1,5 @@ flags: {} packages: - '.' -extra-deps: [murmur3-1.0.3,naturalcomp-0.0.3,Munkres-0.1,numeric-tools-0.2.0.1,Chart-1.9.1,Chart-cairo-1.9.1,multiset-0.3.4.1,'ordered-containers-0.2.2@sha256:ebf2be3f592d9cf148ea6b8375f8af97148d44f82d8d04476899285e965afdbf,810'] +extra-deps: [murmur3-1.0.3,naturalcomp-0.0.3,Munkres-0.1,numeric-tools-0.2.0.1,Chart-1.9.1,Chart-cairo-1.9.1,multiset-0.3.4.1,'ordered-containers-0.2.2@sha256:ebf2be3f592d9cf148ea6b8375f8af97148d44f82d8d04476899285e965afdbf,810','aeson-yaml-1.0.4.0@sha256:72d91a4a2ade87b8a4bdf73937d2c62bd2c60053df1841f8bf1e6204387959b8,1975'] resolver: lts-12.26