Training a simple ensemble

This commit is contained in:
Filip Gralinski 2020-05-13 16:45:15 +02:00
parent 8da6dac871
commit 0ed06729ba
8 changed files with 165 additions and 5 deletions

View File

@ -55,6 +55,8 @@ library
, GEval.Formatting , GEval.Formatting
, Data.Conduit.Header , Data.Conduit.Header
, GEval.DataSource , GEval.DataSource
, GEval.Model
, GEval.ModelTraining
, Paths_geval , Paths_geval
build-depends: base >= 4.7 && < 5 build-depends: base >= 4.7 && < 5
, cond , cond
@ -111,6 +113,7 @@ library
, ordered-containers , ordered-containers
, random , random
, rainbow , rainbow
, yaml
default-language: Haskell2010 default-language: Haskell2010
executable geval executable geval

View File

@ -63,6 +63,7 @@ import Data.Singletons.TH
import GEval.Metric import GEval.Metric
import GEval.MetricsMechanics import GEval.MetricsMechanics
import GEval.EvaluationScheme import GEval.EvaluationScheme
import GEval.Model (ModelType)
import Data.Conduit import Data.Conduit
import Data.Conduit.Combinators as CC import Data.Conduit.Combinators as CC
@ -227,6 +228,8 @@ data GEvalSpecialCommand = Init
| PrintVersion | JustTokenize | Submit | PrintVersion | JustTokenize | Submit
| Validate | ListMetrics | Validate | ListMetrics
| OracleItemBased | OracleItemBased
| TrainModel ModelType
| Infer FilePath
data ResultOrdering = KeepTheOriginalOrder | FirstTheWorst | FirstTheBest data ResultOrdering = KeepTheOriginalOrder | FirstTheWorst | FirstTheBest

View File

@ -4,6 +4,7 @@ module GEval.EvaluationScheme
applyPreprocessingOperations, applyPreprocessingOperations,
evaluationSchemeName, evaluationSchemeName,
evaluationSchemePriority, evaluationSchemePriority,
getRegexpMatch,
PreprocessingOperation(..)) PreprocessingOperation(..))
where where
@ -59,6 +60,12 @@ readOps ('s':theRest) = handleParametrizedBinaryOp (\a b -> RegexpSubstition (fr
readOps ('f':theRest) = handleParametrizedOp (FeatureFilter . pack) theRest readOps ('f':theRest) = handleParametrizedOp (FeatureFilter . pack) theRest
readOps s = ([], s) readOps s = ([], s)
getRegexpMatch :: EvaluationScheme -> Maybe Regex
getRegexpMatch (EvaluationScheme _ ops) = getRegexpMatch' ops
where getRegexpMatch' [] = Nothing
getRegexpMatch' ((RegexpMatch regex):_) = Just regex
getRegexpMatch' (_:ops) = getRegexpMatch' ops
handleParametrizedOp :: (String -> PreprocessingOperation) -> String -> ([PreprocessingOperation], String) handleParametrizedOp :: (String -> PreprocessingOperation) -> String -> ([PreprocessingOperation], String)
handleParametrizedOp constructor theRest = handleParametrizedOp constructor theRest =
case parseParameter theRest of case parseParameter theRest of

View File

@ -21,7 +21,8 @@ module GEval.LineByLine
ResultOrdering(..), ResultOrdering(..),
justTokenize, justTokenize,
worstFeaturesPipeline, worstFeaturesPipeline,
runOracleItemBased runOracleItemBased,
runMultiOutputGeneralizedForEvaluationScheme
) where ) where
import GEval.Core import GEval.Core
@ -516,7 +517,11 @@ runOracleItemBased spec = runMultiOutputGeneralized spec consum
metric = gesMainMetric spec metric = gesMainMetric spec
runMultiOutputGeneralized :: GEvalSpecification -> ConduitT [LineRecord] Void (ResourceT IO) () -> IO () runMultiOutputGeneralized :: GEvalSpecification -> ConduitT [LineRecord] Void (ResourceT IO) () -> IO ()
runMultiOutputGeneralized spec consum = do runMultiOutputGeneralized spec consum = runMultiOutputGeneralizedForEvaluationScheme spec mainScheme consum
where mainScheme = gesMainScheme spec
runMultiOutputGeneralizedForEvaluationScheme :: GEvalSpecification -> EvaluationScheme -> ConduitT [LineRecord] Void (ResourceT IO) () -> IO ()
runMultiOutputGeneralizedForEvaluationScheme spec scheme@(EvaluationScheme metric _) consum = do
dataSource' <- checkAndGetDataSource True spec dataSource' <- checkAndGetDataSource True spec
let dataSource = addSchemeSpecifics scheme dataSource' let dataSource = addSchemeSpecifics scheme dataSource'
let (Just altOuts) = gesAltOutFiles spec let (Just altOuts) = gesAltOutFiles spec
@ -530,8 +535,6 @@ runMultiOutputGeneralized spec consum = do
dataSourceOut = s}) sourceSpecs dataSourceOut = s}) sourceSpecs
runResourceT $ runConduit $ runResourceT $ runConduit $
(sequenceSources sources .| consum) (sequenceSources sources .| consum)
where metric = gesMainMetric spec
scheme = gesMainScheme spec
runMostWorseningFeatures :: ResultOrdering -> FilePath -> GEvalSpecification -> BlackBoxDebuggingOptions -> IO () runMostWorseningFeatures :: ResultOrdering -> FilePath -> GEvalSpecification -> BlackBoxDebuggingOptions -> IO ()
runMostWorseningFeatures ordering otherOut spec bbdo = do runMostWorseningFeatures ordering otherOut spec bbdo = do

48
src/GEval/Model.hs Normal file
View File

@ -0,0 +1,48 @@
{-# LANGUAGE OverloadedStrings #-}
module GEval.Model(
ModelType(..),
SimpleMixEnsembleModel(..),
SimpleMixEnsembleModelRule(..))
where
import qualified Data.ByteString.UTF8 as BSU
import Text.Regex.PCRE.Heavy
import Text.Regex.PCRE.Light.Base (Regex(..))
import qualified Data.Vector as V
import Data.Either (fromRight)
import Data.Yaml
data ModelType = SimpleMixEnsemble
deriving (Show, Read, Eq)
data SimpleMixEnsembleModel = SimpleMixEnsembleModel [SimpleMixEnsembleModelRule]
deriving (Show)
data SimpleMixEnsembleModelRule = SimpleMixEnsembleModelRule {
simpleMixEnsembleModelRuleRegex :: Regex,
simpleMixEnsembleModelRuleOut :: String }
deriving (Eq, Show)
instance ToJSON SimpleMixEnsembleModelRule where
toJSON rule = object ["regex" .= (formatRegexp $ simpleMixEnsembleModelRuleRegex rule),
"out" .= simpleMixEnsembleModelRuleOut rule]
instance FromJSON SimpleMixEnsembleModelRule where
parseJSON = withObject "SimpleMixEnsembleModelRule" $ \v -> SimpleMixEnsembleModelRule
<$> (toRegexp <$> (v .: "regex"))
<*> v .: "out"
instance FromJSON SimpleMixEnsembleModel where
parseJSON = withArray "SimpleMixEnsembleModel" $ \arr ->
SimpleMixEnsembleModel <$> mapM parseJSON (V.toList arr)
instance ToJSON SimpleMixEnsembleModel where
toJSON (SimpleMixEnsembleModel rules) = toJSON rules
formatRegexp (Regex _ regexp) = BSU.toString regexp
toRegexp = (fromRight undefined) . ((flip compileM) []) . BSU.fromString

View File

@ -0,0 +1,78 @@
{-# LANGUAGE ScopedTypeVariables #-}
module GEval.ModelTraining(
trainModel,
infer)
where
import GEval.Core
import GEval.Model
import GEval.Common
import System.FilePath
import Data.Either
import Data.Maybe
import Data.List (maximumBy, transpose)
import Data.Conduit.Header (processHeader)
import Text.Regex.PCRE.Heavy
import Text.Regex.PCRE.Light.Base (Regex(..))
import Data.Conduit
import Data.Conduit.AutoDecompress (autoDecompress)
import qualified Data.Conduit.Text as CT
import qualified Data.Conduit.Combinators as CC
import Control.Monad.Trans.Resource
import qualified Data.Text as T
import qualified Data.Map as M
import qualified Data.ByteString.Lazy as BSL
import qualified Data.ByteString as BS
import qualified Data.Yaml as Y
import Data.Conduit.SmartSource
import GEval.EvaluationScheme
infer :: FilePath -> GEvalSpecification -> IO ()
infer modelPath spec = do
model :: SimpleMixEnsembleModel <- Y.decodeFileThrow modelPath
outs <- fromJust <$> checkMultipleOuts spec
outSpecs' <- mapM (getSmartSourceSpec ((gesOutDirectory spec) </> (gesTestName spec)) "out.tsv") outs
let outSpecs = rights outSpecs'
mHeader <- readHeaderFileWrapper headerSpec
let bareFilesMap = M.fromList $ zip (map (\(FilePathSpec fp) -> takeFileName fp) outSpecs) [0..]
let sources = map (createSource mHeader) outSpecs
runResourceT $ runConduit $ (sequenceSources sources .| (CC.map (runModel model bareFilesMap)) .| CC.unlines .| CC.encodeUtf8 .| CC.stdout)
where createSource mHeader spec = (smartSource spec) .| autoDecompress .| CT.decodeUtf8Lenient .| CT.lines .| processHeader mHeader
headerSpec = gesOutHeader spec
runModel :: SimpleMixEnsembleModel -> M.Map FilePath Int -> [T.Text] -> T.Text
runModel (SimpleMixEnsembleModel rules) bareFilesMap ts = T.intercalate (T.pack " ") $ map applyRule rules
where applyRule rule = applyRegex (simpleMixEnsembleModelRuleRegex rule) (ts !! (bareFilesMap M.! (simpleMixEnsembleModelRuleOut rule)))
applyRegex regex = T.concat . (map fst) . (scan regex)
trainModel :: ModelType -> GEvalSpecification -> IO ()
trainModel SimpleMixEnsemble spec = do
vals' <- geval spec
let vals = map (\(s, val) -> (s, map getMetricValue val)) vals'
let byMetric = filter (\(s, _) -> isEligibleForSimpleMix s) $ zip (gesMetrics spec) $ transpose $ map (\(_, scores) -> scores) vals
let sources = map (\(FilePathSpec fp, _) -> takeFileName fp) vals
let bestOnes = map ((\(scheme, ix) -> SimpleMixEnsembleModelRule {
simpleMixEnsembleModelRuleRegex = fromJust $ getRegexpMatch scheme,
simpleMixEnsembleModelRuleOut = sources !! ix}) . selectBest) byMetric
let model = SimpleMixEnsembleModel bestOnes
let modelContent = Y.encode model
BS.putStr modelContent
where selectBest (scheme, scores) = (scheme, bestIx)
where bestIx = fst $ maximumBy (\(_, SimpleRun s) (_, SimpleRun s') -> s `compare` s') $ zip [0..] scores
isEligibleForSimpleMix :: EvaluationScheme -> Bool
isEligibleForSimpleMix scheme =
evaluationSchemePriority scheme == 2 && isJust (getRegexpMatch scheme)

View File

@ -35,6 +35,8 @@ import GEval.Submit (submit)
import GEval.BlackBoxDebugging import GEval.BlackBoxDebugging
import GEval.Selector import GEval.Selector
import GEval.Validation import GEval.Validation
import GEval.Model
import GEval.ModelTraining
import Data.List (intercalate) import Data.List (intercalate)
@ -99,6 +101,16 @@ optionsParser = GEvalOptions
(flag' OracleItemBased (flag' OracleItemBased
( long "oracle-item-based" ( long "oracle-item-based"
<> help "Generate the best possible output considering outputs given by --out-file and --alt-out-file options (and peeking into the expected file).")) <> help "Generate the best possible output considering outputs given by --out-file and --alt-out-file options (and peeking into the expected file)."))
<|>
(TrainModel <$> option auto
( long "train"
<> metavar "MODEL-TYPE"
<> help "Train a model"))
<|>
(Infer <$> strOption
( long "infer"
<> metavar "MODEL-PATH"
<> help "Infer from a model"))
) )
<*> ((flag' FirstTheWorst <*> ((flag' FirstTheWorst
@ -387,6 +399,12 @@ runGEval''' (Just ListMetrics) _ _ _ _ _ _ = do
runGEval''' (Just OracleItemBased) _ _ spec _ _ _ = do runGEval''' (Just OracleItemBased) _ _ spec _ _ _ = do
runOracleItemBased spec runOracleItemBased spec
return Nothing return Nothing
runGEval''' (Just (TrainModel modelType)) _ _ spec _ _ _ = do
trainModel modelType spec
return Nothing
runGEval''' (Just (Infer modelPath)) _ _ spec _ _ _ = do
infer modelPath spec
return Nothing
getGraphFilename :: Int -> FilePath -> FilePath getGraphFilename :: Int -> FilePath -> FilePath
getGraphFilename 0 fp = fp getGraphFilename 0 fp = fp

View File

@ -1,5 +1,5 @@
flags: {} flags: {}
packages: packages:
- '.' - '.'
extra-deps: [murmur3-1.0.3,naturalcomp-0.0.3,Munkres-0.1,numeric-tools-0.2.0.1,Chart-1.9.1,Chart-cairo-1.9.1,multiset-0.3.4.1,'ordered-containers-0.2.2@sha256:ebf2be3f592d9cf148ea6b8375f8af97148d44f82d8d04476899285e965afdbf,810'] extra-deps: [murmur3-1.0.3,naturalcomp-0.0.3,Munkres-0.1,numeric-tools-0.2.0.1,Chart-1.9.1,Chart-cairo-1.9.1,multiset-0.3.4.1,'ordered-containers-0.2.2@sha256:ebf2be3f592d9cf148ea6b8375f8af97148d44f82d8d04476899285e965afdbf,810','aeson-yaml-1.0.4.0@sha256:72d91a4a2ade87b8a4bdf73937d2c62bd2c60053df1841f8bf1e6204387959b8,1975']
resolver: lts-12.26 resolver: lts-12.26