Training a simple ensemble
This commit is contained in:
parent
8da6dac871
commit
0ed06729ba
@ -55,6 +55,8 @@ library
|
|||||||
, GEval.Formatting
|
, GEval.Formatting
|
||||||
, Data.Conduit.Header
|
, Data.Conduit.Header
|
||||||
, GEval.DataSource
|
, GEval.DataSource
|
||||||
|
, GEval.Model
|
||||||
|
, GEval.ModelTraining
|
||||||
, Paths_geval
|
, Paths_geval
|
||||||
build-depends: base >= 4.7 && < 5
|
build-depends: base >= 4.7 && < 5
|
||||||
, cond
|
, cond
|
||||||
@ -111,6 +113,7 @@ library
|
|||||||
, ordered-containers
|
, ordered-containers
|
||||||
, random
|
, random
|
||||||
, rainbow
|
, rainbow
|
||||||
|
, yaml
|
||||||
default-language: Haskell2010
|
default-language: Haskell2010
|
||||||
|
|
||||||
executable geval
|
executable geval
|
||||||
|
@ -63,6 +63,7 @@ import Data.Singletons.TH
|
|||||||
import GEval.Metric
|
import GEval.Metric
|
||||||
import GEval.MetricsMechanics
|
import GEval.MetricsMechanics
|
||||||
import GEval.EvaluationScheme
|
import GEval.EvaluationScheme
|
||||||
|
import GEval.Model (ModelType)
|
||||||
|
|
||||||
import Data.Conduit
|
import Data.Conduit
|
||||||
import Data.Conduit.Combinators as CC
|
import Data.Conduit.Combinators as CC
|
||||||
@ -227,6 +228,8 @@ data GEvalSpecialCommand = Init
|
|||||||
| PrintVersion | JustTokenize | Submit
|
| PrintVersion | JustTokenize | Submit
|
||||||
| Validate | ListMetrics
|
| Validate | ListMetrics
|
||||||
| OracleItemBased
|
| OracleItemBased
|
||||||
|
| TrainModel ModelType
|
||||||
|
| Infer FilePath
|
||||||
|
|
||||||
data ResultOrdering = KeepTheOriginalOrder | FirstTheWorst | FirstTheBest
|
data ResultOrdering = KeepTheOriginalOrder | FirstTheWorst | FirstTheBest
|
||||||
|
|
||||||
|
@ -4,6 +4,7 @@ module GEval.EvaluationScheme
|
|||||||
applyPreprocessingOperations,
|
applyPreprocessingOperations,
|
||||||
evaluationSchemeName,
|
evaluationSchemeName,
|
||||||
evaluationSchemePriority,
|
evaluationSchemePriority,
|
||||||
|
getRegexpMatch,
|
||||||
PreprocessingOperation(..))
|
PreprocessingOperation(..))
|
||||||
where
|
where
|
||||||
|
|
||||||
@ -59,6 +60,12 @@ readOps ('s':theRest) = handleParametrizedBinaryOp (\a b -> RegexpSubstition (fr
|
|||||||
readOps ('f':theRest) = handleParametrizedOp (FeatureFilter . pack) theRest
|
readOps ('f':theRest) = handleParametrizedOp (FeatureFilter . pack) theRest
|
||||||
readOps s = ([], s)
|
readOps s = ([], s)
|
||||||
|
|
||||||
|
-- | Returns the regex of the first 'RegexpMatch' preprocessing operation of
-- an evaluation scheme, or 'Nothing' when the scheme has no such operation.
getRegexpMatch :: EvaluationScheme -> Maybe Regex
getRegexpMatch (EvaluationScheme _ operations) =
  case [regex | RegexpMatch regex <- operations] of
    (firstRegex : _) -> Just firstRegex
    []               -> Nothing
|
||||||
|
|
||||||
handleParametrizedOp :: (String -> PreprocessingOperation) -> String -> ([PreprocessingOperation], String)
|
handleParametrizedOp :: (String -> PreprocessingOperation) -> String -> ([PreprocessingOperation], String)
|
||||||
handleParametrizedOp constructor theRest =
|
handleParametrizedOp constructor theRest =
|
||||||
case parseParameter theRest of
|
case parseParameter theRest of
|
||||||
|
@ -21,7 +21,8 @@ module GEval.LineByLine
|
|||||||
ResultOrdering(..),
|
ResultOrdering(..),
|
||||||
justTokenize,
|
justTokenize,
|
||||||
worstFeaturesPipeline,
|
worstFeaturesPipeline,
|
||||||
runOracleItemBased
|
runOracleItemBased,
|
||||||
|
runMultiOutputGeneralizedForEvaluationScheme
|
||||||
) where
|
) where
|
||||||
|
|
||||||
import GEval.Core
|
import GEval.Core
|
||||||
@ -516,7 +517,11 @@ runOracleItemBased spec = runMultiOutputGeneralized spec consum
|
|||||||
metric = gesMainMetric spec
|
metric = gesMainMetric spec
|
||||||
|
|
||||||
runMultiOutputGeneralized :: GEvalSpecification -> ConduitT [LineRecord] Void (ResourceT IO) () -> IO ()
|
runMultiOutputGeneralized :: GEvalSpecification -> ConduitT [LineRecord] Void (ResourceT IO) () -> IO ()
|
||||||
runMultiOutputGeneralized spec consum = do
|
runMultiOutputGeneralized spec consum = runMultiOutputGeneralizedForEvaluationScheme spec mainScheme consum
|
||||||
|
where mainScheme = gesMainScheme spec
|
||||||
|
|
||||||
|
runMultiOutputGeneralizedForEvaluationScheme :: GEvalSpecification -> EvaluationScheme -> ConduitT [LineRecord] Void (ResourceT IO) () -> IO ()
|
||||||
|
runMultiOutputGeneralizedForEvaluationScheme spec scheme@(EvaluationScheme metric _) consum = do
|
||||||
dataSource' <- checkAndGetDataSource True spec
|
dataSource' <- checkAndGetDataSource True spec
|
||||||
let dataSource = addSchemeSpecifics scheme dataSource'
|
let dataSource = addSchemeSpecifics scheme dataSource'
|
||||||
let (Just altOuts) = gesAltOutFiles spec
|
let (Just altOuts) = gesAltOutFiles spec
|
||||||
@ -530,8 +535,6 @@ runMultiOutputGeneralized spec consum = do
|
|||||||
dataSourceOut = s}) sourceSpecs
|
dataSourceOut = s}) sourceSpecs
|
||||||
runResourceT $ runConduit $
|
runResourceT $ runConduit $
|
||||||
(sequenceSources sources .| consum)
|
(sequenceSources sources .| consum)
|
||||||
where metric = gesMainMetric spec
|
|
||||||
scheme = gesMainScheme spec
|
|
||||||
|
|
||||||
runMostWorseningFeatures :: ResultOrdering -> FilePath -> GEvalSpecification -> BlackBoxDebuggingOptions -> IO ()
|
runMostWorseningFeatures :: ResultOrdering -> FilePath -> GEvalSpecification -> BlackBoxDebuggingOptions -> IO ()
|
||||||
runMostWorseningFeatures ordering otherOut spec bbdo = do
|
runMostWorseningFeatures ordering otherOut spec bbdo = do
|
||||||
|
48
src/GEval/Model.hs
Normal file
48
src/GEval/Model.hs
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
|
|
||||||
|
module GEval.Model(
|
||||||
|
ModelType(..),
|
||||||
|
SimpleMixEnsembleModel(..),
|
||||||
|
SimpleMixEnsembleModelRule(..))
|
||||||
|
where
|
||||||
|
|
||||||
|
import qualified Data.ByteString.UTF8 as BSU
|
||||||
|
import Text.Regex.PCRE.Heavy
|
||||||
|
import Text.Regex.PCRE.Light.Base (Regex(..))
|
||||||
|
|
||||||
|
import qualified Data.Vector as V
|
||||||
|
|
||||||
|
import Data.Either (fromRight)
|
||||||
|
|
||||||
|
import Data.Yaml
|
||||||
|
|
||||||
|
-- | Kinds of models that can be trained. The constructor name is also the
-- textual form accepted and produced by the derived 'Read'/'Show' instances.
data ModelType
  = SimpleMixEnsemble
  deriving (Show, Read, Eq)
|
||||||
|
|
||||||
|
-- | A simple-mix ensemble: an ordered list of rules, each pairing a regex
-- with the name of the output file it should be applied to.
data SimpleMixEnsembleModel
  = SimpleMixEnsembleModel [SimpleMixEnsembleModelRule]
  deriving (Show)
|
||||||
|
|
||||||
|
-- | A single rule of a 'SimpleMixEnsembleModel'.
data SimpleMixEnsembleModelRule = SimpleMixEnsembleModelRule
  { simpleMixEnsembleModelRuleRegex :: Regex
    -- ^ regex whose matches are extracted from a line of the chosen output
  , simpleMixEnsembleModelRuleOut :: String
    -- ^ name of the output file the regex is applied to
  }
  deriving (Eq, Show)
|
||||||
|
|
||||||
|
-- | Serialises a rule as a mapping with the keys @regex@ (the regex source
-- text) and @out@ (the output file name).
instance ToJSON SimpleMixEnsembleModelRule where
  toJSON rule =
    object
      [ "regex" .= formatRegexp (simpleMixEnsembleModelRuleRegex rule)
      , "out" .= simpleMixEnsembleModelRuleOut rule
      ]
|
||||||
|
|
||||||
|
-- | Reads a rule from a mapping with the keys @regex@ and @out@.
--
-- The regex is compiled while parsing: a pattern that does not compile is
-- reported as a parse failure (mentioning the pattern and the compiler's
-- message) instead of yielding a value that blows up on first use.
instance FromJSON SimpleMixEnsembleModelRule where
  parseJSON = withObject "SimpleMixEnsembleModelRule" $ \v -> SimpleMixEnsembleModelRule
    <$> (v .: "regex" >>= compileRegexp)
    <*> v .: "out"
    where
      -- Compile with no PCRE options, exactly as the model was written out.
      compileRegexp raw =
        either
          (\err -> fail ("invalid regex " ++ show raw ++ ": " ++ err))
          return
          (compileM (BSU.fromString raw) [])
|
||||||
|
|
||||||
|
-- | Reads a model from an array whose elements are rules, keeping their order.
instance FromJSON SimpleMixEnsembleModel where
  parseJSON =
    withArray "SimpleMixEnsembleModel"
      (fmap SimpleMixEnsembleModel . mapM parseJSON . V.toList)
|
||||||
|
|
||||||
|
-- | Serialises a model as the plain list of its rules.
instance ToJSON SimpleMixEnsembleModel where
  toJSON model = case model of
    SimpleMixEnsembleModel rules -> toJSON rules
|
||||||
|
|
||||||
|
-- | Returns the source text stored inside a compiled regex, decoded as UTF-8.
formatRegexp :: Regex -> String
formatRegexp compiled = case compiled of
  Regex _ source -> BSU.toString source
|
||||||
|
|
||||||
|
-- | Compiles a regex from its source text (UTF-8 encoded, no PCRE options).
--
-- The function is partial: a pattern that does not compile raises an error
-- naming the pattern and the compiler's message.
toRegexp :: String -> Regex
toRegexp source = either invalid id (compileM (BSU.fromString source) [])
  where
    invalid err =
      error ("GEval.Model.toRegexp: cannot compile regexp " ++ show source ++ ": " ++ err)
|
78
src/GEval/ModelTraining.hs
Normal file
78
src/GEval/ModelTraining.hs
Normal file
@ -0,0 +1,78 @@
|
|||||||
|
{-# LANGUAGE ScopedTypeVariables #-}
|
||||||
|
|
||||||
|
module GEval.ModelTraining(
|
||||||
|
trainModel,
|
||||||
|
infer)
|
||||||
|
where
|
||||||
|
|
||||||
|
import GEval.Core
|
||||||
|
import GEval.Model
|
||||||
|
import GEval.Common
|
||||||
|
|
||||||
|
import System.FilePath
|
||||||
|
import Data.Either
|
||||||
|
import Data.Maybe
|
||||||
|
|
||||||
|
import Data.List (maximumBy, transpose)
|
||||||
|
|
||||||
|
import Data.Conduit.Header (processHeader)
|
||||||
|
import Text.Regex.PCRE.Heavy
|
||||||
|
import Text.Regex.PCRE.Light.Base (Regex(..))
|
||||||
|
import Data.Conduit
|
||||||
|
import Data.Conduit.AutoDecompress (autoDecompress)
|
||||||
|
import qualified Data.Conduit.Text as CT
|
||||||
|
import qualified Data.Conduit.Combinators as CC
|
||||||
|
import Control.Monad.Trans.Resource
|
||||||
|
import qualified Data.Text as T
|
||||||
|
import qualified Data.Map as M
|
||||||
|
|
||||||
|
import qualified Data.ByteString.Lazy as BSL
|
||||||
|
import qualified Data.ByteString as BS
|
||||||
|
|
||||||
|
import qualified Data.Yaml as Y
|
||||||
|
|
||||||
|
import Data.Conduit.SmartSource
|
||||||
|
import GEval.EvaluationScheme
|
||||||
|
|
||||||
|
-- | Applies a previously trained simple-mix ensemble model.
--
-- The model is decoded from the YAML file at @modelPath@. The output files
-- reported by 'checkMultipleOuts' are read in parallel, line by line, and for
-- every group of corresponding lines the model produces one line that is
-- written to standard output as UTF-8.
--
-- Raises an 'IOError' when 'checkMultipleOuts' yields no outputs; decoding
-- problems are thrown by 'Y.decodeFileThrow'.
infer :: FilePath -> GEvalSpecification -> IO ()
infer modelPath spec = do
  model :: SimpleMixEnsembleModel <- Y.decodeFileThrow modelPath
  -- Fail with an explicit message rather than an anonymous 'fromJust' crash.
  outs <- maybe (ioError (userError noOutsMessage)) return =<< checkMultipleOuts spec
  outSpecs' <- mapM (getSmartSourceSpec (gesOutDirectory spec </> gesTestName spec) "out.tsv") outs
  -- Outputs whose source spec could not be resolved are left out.
  let outSpecs = rights outSpecs'
  mHeader <- readHeaderFileWrapper headerSpec
  -- Maps the bare file name of each output to its position in the line group.
  let bareFilesMap = M.fromList $ zip (map (\(FilePathSpec fp) -> takeFileName fp) outSpecs) [0..]

  let sources = map (createSource mHeader) outSpecs
  runResourceT $ runConduit $
    sequenceSources sources
      .| CC.map (runModel model bareFilesMap)
      .| CC.unlines
      .| CC.encodeUtf8
      .| CC.stdout
  where
    createSource mHeader sourceSpec =
      smartSource sourceSpec .| autoDecompress .| CT.decodeUtf8Lenient .| CT.lines .| processHeader mHeader
    headerSpec = gesOutHeader spec
    noOutsMessage = "GEval.ModelTraining.infer: no set of multiple outputs is available for inference"
|
||||||
|
|
||||||
|
-- | Builds one output line from a group of corresponding input lines.
--
-- For every rule of the model, the line at the position that @bareFilesMap@
-- assigns to the rule's output file is scanned with the rule's regex and all
-- matches are concatenated; the per-rule results are joined with spaces.
--
-- Raises an error naming the file when a rule refers to an output file that
-- is absent from @bareFilesMap@.
runModel :: SimpleMixEnsembleModel -> M.Map FilePath Int -> [T.Text] -> T.Text
runModel (SimpleMixEnsembleModel rules) bareFilesMap ts =
  T.intercalate (T.pack " ") $ map applyRule rules
  where
    applyRule rule =
      applyRegex (simpleMixEnsembleModelRuleRegex rule) (ts !! positionOf (simpleMixEnsembleModelRuleOut rule))
    -- Explicit lookup so that a model/outputs mismatch is reported by name.
    positionOf out = case M.lookup out bareFilesMap of
      Just ix -> ix
      Nothing ->
        error ("GEval.ModelTraining.runModel: the model refers to output file "
               ++ show out ++ ", which is not among the given outputs")
    applyRegex regex = T.concat . (map fst) . (scan regex)
|
||||||
|
|
||||||
|
|
||||||
|
-- | Trains a model of the given type and prints it as YAML on standard output.
--
-- For 'SimpleMixEnsemble': the specification is evaluated with 'geval', and
-- for every evaluation scheme accepted by 'isEligibleForSimpleMix' the output
-- with the greatest score is chosen; the resulting rule pairs the scheme's
-- match regex with the bare file name of that output.
--
-- NOTE(review): the greatest score always wins; presumably this is wrong for
-- metrics where lower is better -- confirm against the metric ordering.
trainModel :: ModelType -> GEvalSpecification -> IO ()
trainModel SimpleMixEnsemble spec = do
  evaluated <- geval spec
  let perSource = map (\(src, outputs) -> (src, map getMetricValue outputs)) evaluated
      -- Bare file names of the evaluated outputs, in evaluation order.
      sourceNames = map (\(FilePathSpec fp, _) -> takeFileName fp) perSource
      -- One list of scores (a score per output) for each evaluation scheme.
      perScheme = transpose (map snd perSource)
      eligible = filter (isEligibleForSimpleMix . fst) (zip (gesMetrics spec) perScheme)
      ruleFor (scheme, ix) =
        SimpleMixEnsembleModelRule
          { simpleMixEnsembleModelRuleRegex = fromJust (getRegexpMatch scheme)
          , simpleMixEnsembleModelRuleOut = sourceNames !! ix
          }
      rules = map (ruleFor . pickBest) eligible
  BS.putStr (Y.encode (SimpleMixEnsembleModel rules))
  where
    -- Index of the output with the greatest score for the given scheme.
    pickBest (scheme, scores) =
      ( scheme
      , fst (maximumBy (\(_, SimpleRun s) (_, SimpleRun s') -> s `compare` s') (zip [0..] scores))
      )
|
||||||
|
|
||||||
|
-- | Tells whether an evaluation scheme can contribute a rule to a simple-mix
-- ensemble: its priority must be 2 and it must carry a regexp-match operation.
isEligibleForSimpleMix :: EvaluationScheme -> Bool
isEligibleForSimpleMix scheme
  | evaluationSchemePriority scheme /= 2 = False
  | otherwise = isJust (getRegexpMatch scheme)
|
@ -35,6 +35,8 @@ import GEval.Submit (submit)
|
|||||||
import GEval.BlackBoxDebugging
|
import GEval.BlackBoxDebugging
|
||||||
import GEval.Selector
|
import GEval.Selector
|
||||||
import GEval.Validation
|
import GEval.Validation
|
||||||
|
import GEval.Model
|
||||||
|
import GEval.ModelTraining
|
||||||
|
|
||||||
import Data.List (intercalate)
|
import Data.List (intercalate)
|
||||||
|
|
||||||
@ -99,6 +101,16 @@ optionsParser = GEvalOptions
|
|||||||
(flag' OracleItemBased
|
(flag' OracleItemBased
|
||||||
( long "oracle-item-based"
|
( long "oracle-item-based"
|
||||||
<> help "Generate the best possible output considering outputs given by --out-file and --alt-out-file options (and peeking into the expected file)."))
|
<> help "Generate the best possible output considering outputs given by --out-file and --alt-out-file options (and peeking into the expected file)."))
|
||||||
|
<|>
|
||||||
|
(TrainModel <$> option auto
|
||||||
|
( long "train"
|
||||||
|
<> metavar "MODEL-TYPE"
|
||||||
|
<> help "Train a model"))
|
||||||
|
<|>
|
||||||
|
(Infer <$> strOption
|
||||||
|
( long "infer"
|
||||||
|
<> metavar "MODEL-PATH"
|
||||||
|
<> help "Infer from a model"))
|
||||||
)
|
)
|
||||||
|
|
||||||
<*> ((flag' FirstTheWorst
|
<*> ((flag' FirstTheWorst
|
||||||
@ -387,6 +399,12 @@ runGEval''' (Just ListMetrics) _ _ _ _ _ _ = do
|
|||||||
runGEval''' (Just OracleItemBased) _ _ spec _ _ _ = do
|
runGEval''' (Just OracleItemBased) _ _ spec _ _ _ = do
|
||||||
runOracleItemBased spec
|
runOracleItemBased spec
|
||||||
return Nothing
|
return Nothing
|
||||||
|
runGEval''' (Just (TrainModel modelType)) _ _ spec _ _ _ = do
|
||||||
|
trainModel modelType spec
|
||||||
|
return Nothing
|
||||||
|
runGEval''' (Just (Infer modelPath)) _ _ spec _ _ _ = do
|
||||||
|
infer modelPath spec
|
||||||
|
return Nothing
|
||||||
|
|
||||||
getGraphFilename :: Int -> FilePath -> FilePath
|
getGraphFilename :: Int -> FilePath -> FilePath
|
||||||
getGraphFilename 0 fp = fp
|
getGraphFilename 0 fp = fp
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
flags: {}
|
flags: {}
|
||||||
packages:
|
packages:
|
||||||
- '.'
|
- '.'
|
||||||
extra-deps: [murmur3-1.0.3,naturalcomp-0.0.3,Munkres-0.1,numeric-tools-0.2.0.1,Chart-1.9.1,Chart-cairo-1.9.1,multiset-0.3.4.1,'ordered-containers-0.2.2@sha256:ebf2be3f592d9cf148ea6b8375f8af97148d44f82d8d04476899285e965afdbf,810']
|
extra-deps: [murmur3-1.0.3,naturalcomp-0.0.3,Munkres-0.1,numeric-tools-0.2.0.1,Chart-1.9.1,Chart-cairo-1.9.1,multiset-0.3.4.1,'ordered-containers-0.2.2@sha256:ebf2be3f592d9cf148ea6b8375f8af97148d44f82d8d04476899285e965afdbf,810','aeson-yaml-1.0.4.0@sha256:72d91a4a2ade87b8a4bdf73937d2c62bd2c60053df1841f8bf1e6204387959b8,1975']
|
||||||
resolver: lts-12.26
|
resolver: lts-12.26
|
||||||
|
Loading…
Reference in New Issue
Block a user