Training a simple ensemble
This commit is contained in:
parent
8da6dac871
commit
0ed06729ba
@ -55,6 +55,8 @@ library
|
||||
, GEval.Formatting
|
||||
, Data.Conduit.Header
|
||||
, GEval.DataSource
|
||||
, GEval.Model
|
||||
, GEval.ModelTraining
|
||||
, Paths_geval
|
||||
build-depends: base >= 4.7 && < 5
|
||||
, cond
|
||||
@ -111,6 +113,7 @@ library
|
||||
, ordered-containers
|
||||
, random
|
||||
, rainbow
|
||||
, yaml
|
||||
default-language: Haskell2010
|
||||
|
||||
executable geval
|
||||
|
@ -63,6 +63,7 @@ import Data.Singletons.TH
|
||||
import GEval.Metric
|
||||
import GEval.MetricsMechanics
|
||||
import GEval.EvaluationScheme
|
||||
import GEval.Model (ModelType)
|
||||
|
||||
import Data.Conduit
|
||||
import Data.Conduit.Combinators as CC
|
||||
@ -227,6 +228,8 @@ data GEvalSpecialCommand = Init
|
||||
| PrintVersion | JustTokenize | Submit
|
||||
| Validate | ListMetrics
|
||||
| OracleItemBased
|
||||
| TrainModel ModelType
|
||||
| Infer FilePath
|
||||
|
||||
data ResultOrdering = KeepTheOriginalOrder | FirstTheWorst | FirstTheBest
|
||||
|
||||
|
@ -4,6 +4,7 @@ module GEval.EvaluationScheme
|
||||
applyPreprocessingOperations,
|
||||
evaluationSchemeName,
|
||||
evaluationSchemePriority,
|
||||
getRegexpMatch,
|
||||
PreprocessingOperation(..))
|
||||
where
|
||||
|
||||
@ -59,6 +60,12 @@ readOps ('s':theRest) = handleParametrizedBinaryOp (\a b -> RegexpSubstition (fr
|
||||
readOps ('f':theRest) = handleParametrizedOp (FeatureFilter . pack) theRest
|
||||
readOps s = ([], s)
|
||||
|
||||
-- | Return the regular expression carried by the first 'RegexpMatch'
-- operation of an evaluation scheme, or 'Nothing' when the scheme has no
-- such operation.
getRegexpMatch :: EvaluationScheme -> Maybe Regex
getRegexpMatch (EvaluationScheme _ operations) =
  -- The comprehension is lazy, so only the first match is ever demanded.
  case [regex | RegexpMatch regex <- operations] of
    (firstRegex : _) -> Just firstRegex
    []               -> Nothing
|
||||
|
||||
handleParametrizedOp :: (String -> PreprocessingOperation) -> String -> ([PreprocessingOperation], String)
|
||||
handleParametrizedOp constructor theRest =
|
||||
case parseParameter theRest of
|
||||
|
@ -21,7 +21,8 @@ module GEval.LineByLine
|
||||
ResultOrdering(..),
|
||||
justTokenize,
|
||||
worstFeaturesPipeline,
|
||||
runOracleItemBased
|
||||
runOracleItemBased,
|
||||
runMultiOutputGeneralizedForEvaluationScheme
|
||||
) where
|
||||
|
||||
import GEval.Core
|
||||
@ -516,7 +517,11 @@ runOracleItemBased spec = runMultiOutputGeneralized spec consum
|
||||
metric = gesMainMetric spec
|
||||
|
||||
runMultiOutputGeneralized :: GEvalSpecification -> ConduitT [LineRecord] Void (ResourceT IO) () -> IO ()
|
||||
runMultiOutputGeneralized spec consum = do
|
||||
runMultiOutputGeneralized spec consum = runMultiOutputGeneralizedForEvaluationScheme spec mainScheme consum
|
||||
where mainScheme = gesMainScheme spec
|
||||
|
||||
runMultiOutputGeneralizedForEvaluationScheme :: GEvalSpecification -> EvaluationScheme -> ConduitT [LineRecord] Void (ResourceT IO) () -> IO ()
|
||||
runMultiOutputGeneralizedForEvaluationScheme spec scheme@(EvaluationScheme metric _) consum = do
|
||||
dataSource' <- checkAndGetDataSource True spec
|
||||
let dataSource = addSchemeSpecifics scheme dataSource'
|
||||
let (Just altOuts) = gesAltOutFiles spec
|
||||
@ -530,8 +535,6 @@ runMultiOutputGeneralized spec consum = do
|
||||
dataSourceOut = s}) sourceSpecs
|
||||
runResourceT $ runConduit $
|
||||
(sequenceSources sources .| consum)
|
||||
where metric = gesMainMetric spec
|
||||
scheme = gesMainScheme spec
|
||||
|
||||
runMostWorseningFeatures :: ResultOrdering -> FilePath -> GEvalSpecification -> BlackBoxDebuggingOptions -> IO ()
|
||||
runMostWorseningFeatures ordering otherOut spec bbdo = do
|
||||
|
48
src/GEval/Model.hs
Normal file
48
src/GEval/Model.hs
Normal file
@ -0,0 +1,48 @@
|
||||
{-# LANGUAGE OverloadedStrings #-}
|
||||
|
||||
module GEval.Model(
|
||||
ModelType(..),
|
||||
SimpleMixEnsembleModel(..),
|
||||
SimpleMixEnsembleModelRule(..))
|
||||
where
|
||||
|
||||
import qualified Data.ByteString.UTF8 as BSU
|
||||
import Text.Regex.PCRE.Heavy
|
||||
import Text.Regex.PCRE.Light.Base (Regex(..))
|
||||
|
||||
import qualified Data.Vector as V
|
||||
|
||||
import Data.Either (fromRight)
|
||||
|
||||
import Data.Yaml
|
||||
|
||||
-- | Kinds of model that can be trained. 'SimpleMixEnsemble' is the only one
-- defined here.
data ModelType
  = SimpleMixEnsemble
  deriving (Show, Read, Eq)
|
||||
|
||||
-- | A simple-mix ensemble model: a list of rules, kept in the order in which
-- they were given.
data SimpleMixEnsembleModel
  = SimpleMixEnsembleModel [SimpleMixEnsembleModelRule]
  deriving (Show)
|
||||
|
||||
-- | A single rule of a simple-mix ensemble: a regular expression paired with
-- the name of an output.
data SimpleMixEnsembleModelRule = SimpleMixEnsembleModelRule
  { simpleMixEnsembleModelRuleRegex :: Regex
    -- ^ Compiled regular expression of the rule.
  , simpleMixEnsembleModelRuleOut :: String
    -- ^ Name of the output the rule refers to.
  }
  deriving (Eq, Show)
|
||||
|
||||
-- | A rule is serialised as an object with two keys: @regex@ (the textual
-- pattern obtained with 'formatRegexp') and @out@ (the output name).
instance ToJSON SimpleMixEnsembleModelRule where
  toJSON rule = object ["regex" .= regexSource, "out" .= outName]
    where
      regexSource = formatRegexp (simpleMixEnsembleModelRuleRegex rule)
      outName = simpleMixEnsembleModelRuleOut rule
|
||||
|
||||
-- | A rule is read from an object with the keys @regex@ and @out@.
--
-- The pattern under @regex@ is compiled (with no PCRE options) while
-- parsing. A pattern that does not compile makes the parser fail with a
-- message naming the pattern and the compiler's complaint, instead of
-- producing a rule that crashes later when its regex is first used.
instance FromJSON SimpleMixEnsembleModelRule where
  parseJSON = withObject "SimpleMixEnsembleModelRule" $ \v -> do
    rawRegex <- v .: "regex"
    out <- v .: "out"
    case compileM (BSU.fromString rawRegex) [] of
      Left err ->
        fail $ "invalid regex " ++ show rawRegex ++ ": " ++ err
      Right regex ->
        return $ SimpleMixEnsembleModelRule regex out
|
||||
|
||||
-- | A model is read from an array whose elements are parsed, in order, as
-- rules.
instance FromJSON SimpleMixEnsembleModel where
  parseJSON = withArray "SimpleMixEnsembleModel" parseRules
    where
      parseRules rawRules = do
        rules <- mapM parseJSON (V.toList rawRules)
        return (SimpleMixEnsembleModel rules)
|
||||
|
||||
-- | A model is serialised as the plain list of its rules, without a
-- wrapping object.
instance ToJSON SimpleMixEnsembleModel where
  toJSON model =
    case model of
      SimpleMixEnsembleModel rules -> toJSON rules
|
||||
|
||||
-- | Turn a compiled regular expression back into a 'String', by decoding
-- the byte string held in the second field of the 'Regex' constructor.
formatRegexp :: Regex -> String
formatRegexp compiled =
  case compiled of
    Regex _ patternBytes -> BSU.toString patternBytes
|
||||
|
||||
-- | Compile a pattern given as a 'String' into a 'Regex', using no PCRE
-- options.
--
-- The function is partial: when the pattern is not a valid regular
-- expression it calls 'error', reporting the offending pattern together
-- with the message produced by the regex compiler.
toRegexp :: String -> Regex
toRegexp source =
  case compileM (BSU.fromString source) [] of
    Right regex -> regex
    Left err ->
      error $ "GEval.Model.toRegexp: cannot compile regex " ++ show source ++ ": " ++ err
|
78
src/GEval/ModelTraining.hs
Normal file
78
src/GEval/ModelTraining.hs
Normal file
@ -0,0 +1,78 @@
|
||||
{-# LANGUAGE ScopedTypeVariables #-}
|
||||
|
||||
module GEval.ModelTraining(
|
||||
trainModel,
|
||||
infer)
|
||||
where
|
||||
|
||||
import GEval.Core
|
||||
import GEval.Model
|
||||
import GEval.Common
|
||||
|
||||
import System.FilePath
|
||||
import Data.Either
|
||||
import Data.Maybe
|
||||
|
||||
import Data.List (maximumBy, transpose)
|
||||
|
||||
import Data.Conduit.Header (processHeader)
|
||||
import Text.Regex.PCRE.Heavy
|
||||
import Text.Regex.PCRE.Light.Base (Regex(..))
|
||||
import Data.Conduit
|
||||
import Data.Conduit.AutoDecompress (autoDecompress)
|
||||
import qualified Data.Conduit.Text as CT
|
||||
import qualified Data.Conduit.Combinators as CC
|
||||
import Control.Monad.Trans.Resource
|
||||
import qualified Data.Text as T
|
||||
import qualified Data.Map as M
|
||||
|
||||
import qualified Data.ByteString.Lazy as BSL
|
||||
import qualified Data.ByteString as BS
|
||||
|
||||
import qualified Data.Yaml as Y
|
||||
|
||||
import Data.Conduit.SmartSource
|
||||
import GEval.EvaluationScheme
|
||||
|
||||
-- | Apply a previously saved simple-mix ensemble model.
--
-- The model is decoded from the YAML file at @modelPath@. The outputs
-- reported by 'checkMultipleOuts' are read in parallel, line by line; each
-- group of lines is combined with 'runModel' and the results are written to
-- standard output, one per line, encoded as UTF-8.
--
-- Throws an 'IOError' when 'checkMultipleOuts' yields 'Nothing'; decoding
-- problems are thrown by 'Y.decodeFileThrow'.
infer :: FilePath -> GEvalSpecification -> IO ()
infer modelPath spec = do
  model :: SimpleMixEnsembleModel <- Y.decodeFileThrow modelPath
  mOuts <- checkMultipleOuts spec
  -- Fail with a readable message rather than with "Maybe.fromJust: Nothing".
  outs <- case mOuts of
    Just found -> return found
    Nothing -> ioError $ userError "infer: checkMultipleOuts returned no outputs to combine"
  outSpecs' <- mapM (getSmartSourceSpec (gesOutDirectory spec </> gesTestName spec) "out.tsv") outs
  -- NOTE(review): 'rights' silently drops every output whose source spec
  -- could not be obtained; a rule naming such an output will only fail later,
  -- inside 'runModel'. Confirm whether this should be reported here.
  let outSpecs = rights outSpecs'
  mHeader <- readHeaderFileWrapper headerSpec
  -- Maps the bare file name of every output to its position among the sources.
  -- NOTE(review): the lambda only matches 'FilePathSpec'; any other kind of
  -- source spec would raise a pattern-match failure - confirm this is intended.
  let bareFilesMap = M.fromList $ zip (map (\(FilePathSpec fp) -> takeFileName fp) outSpecs) [0..]

  let sources = map (createSource mHeader) outSpecs
  runResourceT $ runConduit $
    (sequenceSources sources .| CC.map (runModel model bareFilesMap) .| CC.unlines .| CC.encodeUtf8 .| CC.stdout)
  where
    -- One stream of decoded text lines per output, with the header handled
    -- by 'processHeader'.
    createSource mHeader sourceSpec =
      smartSource sourceSpec .| autoDecompress .| CT.decodeUtf8Lenient .| CT.lines .| processHeader mHeader
    headerSpec = gesOutHeader spec
|
||||
|
||||
-- | Combine one group of parallel output lines according to the model.
--
-- For every rule, in order, the line belonging to the output named by the
-- rule is looked up (through @bareFilesMap@, which maps output names to
-- positions in @ts@), all matches of the rule's regular expression in that
-- line are concatenated, and the per-rule results are joined with single
-- spaces.
--
-- Calls 'error' with a message naming the output when a rule refers to an
-- output that is missing from @bareFilesMap@ or has no corresponding line
-- in @ts@.
runModel :: SimpleMixEnsembleModel -> M.Map FilePath Int -> [T.Text] -> T.Text
runModel (SimpleMixEnsembleModel rules) bareFilesMap ts =
  T.intercalate (T.pack " ") $ map applyRule rules
  where
    applyRule rule =
      applyRegex (simpleMixEnsembleModelRuleRegex rule) (lineFor (simpleMixEnsembleModelRuleOut rule))
    -- Safe replacement for the former @ts !! (bareFilesMap M.! out)@.
    lineFor out =
      case M.lookup out bareFilesMap of
        Nothing ->
          error $ "runModel: the model refers to output `" ++ out ++ "`, which was not supplied"
        Just ix ->
          case drop ix ts of
            (line : _) -> line
            [] -> error $ "runModel: no line available for output `" ++ out ++ "`"
    -- Keep only the matched fragments; captured groups are ignored.
    applyRegex regex = T.concat . map fst . scan regex
|
||||
|
||||
|
||||
-- | Train a model of the requested type and print it to standard output as
-- YAML.
--
-- For 'SimpleMixEnsemble': every output is evaluated with 'geval'; for each
-- eligible evaluation scheme (see 'isEligibleForSimpleMix') the output with
-- the greatest score is selected and a rule pairing the scheme's regular
-- expression with that output's bare file name is emitted.
trainModel :: ModelType -> GEvalSpecification -> IO ()
trainModel SimpleMixEnsemble spec = do
  evaluated <- geval spec
  -- One row of scores per output, one column per evaluation scheme.
  let scoreRows = map (map getMetricValue . snd) evaluated
  let outNames = map (\(FilePathSpec path, _) -> takeFileName path) evaluated
  -- Transposing gives, for every scheme, the scores of all outputs.
  let perScheme = zip (gesMetrics spec) (transpose scoreRows)
  let eligible = [candidate | candidate@(scheme, _) <- perScheme, isEligibleForSimpleMix scheme]
  BS.putStr $ Y.encode $ SimpleMixEnsembleModel $ map (toRule outNames) eligible
  where
    -- 'fromJust' cannot fail here: only schemes accepted by
    -- 'isEligibleForSimpleMix' (which checks for a regexp match) reach it.
    toRule outNames (scheme, scores) =
      SimpleMixEnsembleModelRule
        { simpleMixEnsembleModelRuleRegex = fromJust (getRegexpMatch scheme)
        , simpleMixEnsembleModelRuleOut = outNames !! bestIndex scores
        }
    -- NOTE(review): picks the greatest score, i.e. presumes that higher is
    -- better for every eligible metric - TODO confirm for metrics where a
    -- lower value is better. Fails on an empty list of scores.
    bestIndex scores = fst $ maximumBy byScore $ zip [0 ..] scores
    -- Only 'SimpleRun' values are handled; any other constructor of the
    -- score type (if there is one) raises a pattern-match failure.
    byScore (_, SimpleRun a) (_, SimpleRun b) = a `compare` b
|
||||
|
||||
-- | A scheme can take part in the simple-mix ensemble when its priority is
-- exactly 2 and 'getRegexpMatch' finds a regular expression in it.
isEligibleForSimpleMix :: EvaluationScheme -> Bool
isEligibleForSimpleMix scheme
  | evaluationSchemePriority scheme /= 2 = False
  | otherwise = isJust (getRegexpMatch scheme)
|
@ -35,6 +35,8 @@ import GEval.Submit (submit)
|
||||
import GEval.BlackBoxDebugging
|
||||
import GEval.Selector
|
||||
import GEval.Validation
|
||||
import GEval.Model
|
||||
import GEval.ModelTraining
|
||||
|
||||
import Data.List (intercalate)
|
||||
|
||||
@ -99,6 +101,16 @@ optionsParser = GEvalOptions
|
||||
(flag' OracleItemBased
|
||||
( long "oracle-item-based"
|
||||
<> help "Generate the best possible output considering outputs given by --out-file and --alt-out-file options (and peeking into the expected file)."))
|
||||
<|>
|
||||
(TrainModel <$> option auto
|
||||
( long "train"
|
||||
<> metavar "MODEL-TYPE"
|
||||
<> help "Train a model"))
|
||||
<|>
|
||||
(Infer <$> strOption
|
||||
( long "infer"
|
||||
<> metavar "MODEL-PATH"
|
||||
<> help "Infer from a model"))
|
||||
)
|
||||
|
||||
<*> ((flag' FirstTheWorst
|
||||
@ -387,6 +399,12 @@ runGEval''' (Just ListMetrics) _ _ _ _ _ _ = do
|
||||
runGEval''' (Just OracleItemBased) _ _ spec _ _ _ = do
|
||||
runOracleItemBased spec
|
||||
return Nothing
|
||||
runGEval''' (Just (TrainModel modelType)) _ _ spec _ _ _ = do
|
||||
trainModel modelType spec
|
||||
return Nothing
|
||||
runGEval''' (Just (Infer modelPath)) _ _ spec _ _ _ = do
|
||||
infer modelPath spec
|
||||
return Nothing
|
||||
|
||||
getGraphFilename :: Int -> FilePath -> FilePath
|
||||
getGraphFilename 0 fp = fp
|
||||
|
@ -1,5 +1,5 @@
|
||||
flags: {}
|
||||
packages:
|
||||
- '.'
|
||||
extra-deps: [murmur3-1.0.3,naturalcomp-0.0.3,Munkres-0.1,numeric-tools-0.2.0.1,Chart-1.9.1,Chart-cairo-1.9.1,multiset-0.3.4.1,'ordered-containers-0.2.2@sha256:ebf2be3f592d9cf148ea6b8375f8af97148d44f82d8d04476899285e965afdbf,810']
|
||||
extra-deps: [murmur3-1.0.3,naturalcomp-0.0.3,Munkres-0.1,numeric-tools-0.2.0.1,Chart-1.9.1,Chart-cairo-1.9.1,multiset-0.3.4.1,'ordered-containers-0.2.2@sha256:ebf2be3f592d9cf148ea6b8375f8af97148d44f82d8d04476899285e965afdbf,810','aeson-yaml-1.0.4.0@sha256:72d91a4a2ade87b8a4bdf73937d2c62bd2c60053df1841f8bf1e6204387959b8,1975']
|
||||
resolver: lts-12.26
|
||||
|
Loading…
Reference in New Issue
Block a user