Add --oracle-item-based option

This commit is contained in:
Filip Graliński 2019-12-16 11:17:22 +01:00
parent 1e74174c0b
commit 9a3a28a813
12 changed files with 75 additions and 4 deletions

View File

@ -160,6 +160,7 @@ data GEvalSpecification = GEvalSpecification
gesTestName :: String,
gesSelector :: Maybe Selector,
gesOutFile :: String,
gesAltOutFiles :: Maybe [String],
gesExpectedFile :: String,
gesInputFile :: String,
gesMetrics :: [EvaluationScheme],
@ -190,6 +191,7 @@ data GEvalSpecialCommand = Init
| Diff FilePath | MostWorseningFeatures FilePath
| PrintVersion | JustTokenize | Submit
| Validate | ListMetrics
| OracleItemBased
data ResultOrdering = KeepTheOriginalOrder | FirstTheWorst | FirstTheBest
@ -249,6 +251,7 @@ defaultGEvalSpecification = GEvalSpecification {
gesTestName = defaultTestName,
gesSelector = Nothing,
gesOutFile = defaultOutFile,
gesAltOutFiles = Nothing,
gesExpectedFile = defaultExpectedFile,
gesInputFile = defaultInputFile,
gesMetrics = [EvaluationScheme defaultMetric []],

View File

@ -17,7 +17,8 @@ module GEval.LineByLine
LineRecord(..),
ResultOrdering(..),
justTokenize,
worstFeaturesPipeline
worstFeaturesPipeline,
runOracleItemBased
) where
import GEval.Core
@ -34,10 +35,11 @@ import Data.Text
import Data.Text.Encoding
import Data.Conduit.Rank
import Data.Maybe (fromMaybe)
import Data.Either (rights)
import qualified Data.Vector as V
import Data.List (sortBy, sortOn, sort, concat)
import Data.List (sortBy, sortOn, sort, concat, maximumBy)
import Control.Monad.IO.Class
import Control.Monad.Trans.Resource
@ -84,7 +86,6 @@ parseReferenceEntry :: Text -> (Integer, Text)
parseReferenceEntry line = (read $ unpack refId, t)
where [refId, t] = splitOn "\t" line
runLineByLine :: ResultOrdering -> Maybe String -> GEvalSpecification -> BlackBoxDebuggingOptions -> IO ()
runLineByLine ordering featureFilter spec bbdo = runLineByLineGeneralized ordering spec consum
where consum :: Maybe References -> ConduitT LineRecord Void (ResourceT IO) ()
@ -392,6 +393,30 @@ runDiff ordering featureFilter otherOut spec bbdo = runDiffGeneralized ordering
formatScoreDiff :: Double -> Text
formatScoreDiff = Data.Text.pack . printf "%f"
runOracleItemBased :: GEvalSpecification -> IO ()
runOracleItemBased spec = runMultiOutputGeneralized spec consum
where consum = CL.map picker .| format
picker = maximumBy (\(LineRecord _ _ _ _ scoreA) (LineRecord _ _ _ _ scoreB) -> metricCompare metric scoreA scoreB)
format = CL.map (encodeUtf8 . formatOutput)
.| CC.unlinesAscii
.| CC.stdout
formatOutput (LineRecord _ _ out _ _) = out
metric = gesMainMetric spec
runMultiOutputGeneralized :: GEvalSpecification -> ConduitT [LineRecord] Void (ResourceT IO) () -> IO ()
runMultiOutputGeneralized spec consum = do
(inputSource, expectedSource, outSource) <- checkAndGetFilesSingleOut True spec
let (Just altOuts) = gesAltOutFiles spec
altSourceSpecs' <- mapM (getSmartSourceSpec ((gesOutDirectory spec) </> (gesTestName spec)) "out.tsv") altOuts
let altSourceSpecs = rights altSourceSpecs'
let sourceSpecs = (outSource:altSourceSpecs)
let sources = Prelude.map (gevalLineByLineSource metric mSelector preprocess inputSource expectedSource) sourceSpecs
runResourceT $ runConduit $
(sequenceSources sources .| consum)
where metric = gesMainMetric spec
preprocess = gesPreprocess spec
mSelector = gesSelector spec
runMostWorseningFeatures :: ResultOrdering -> FilePath -> GEvalSpecification -> BlackBoxDebuggingOptions -> IO ()
runMostWorseningFeatures ordering otherOut spec bbdo = runDiffGeneralized ordering' otherOut spec consum
where ordering' = forceSomeOrdering ordering

View File

@ -8,7 +8,8 @@ module GEval.Metric
bestPossibleValue,
perfectOutLineFromExpectedLine,
fixedNumberOfColumnsInExpected,
fixedNumberOfColumnsInInput)
fixedNumberOfColumnsInInput,
metricCompare)
where
import Data.Word
@ -173,6 +174,11 @@ getMetricOrdering MultiLabelLogLoss = TheLowerTheBetter
getMetricOrdering MultiLabelLikelihood = TheHigherTheBetter
getMetricOrdering (Mean metric) = getMetricOrdering metric
metricCompare :: Metric -> MetricValue -> MetricValue -> Ordering
metricCompare metric a b = metricCompare' (getMetricOrdering metric) a b
where metricCompare' TheHigherTheBetter a b = a `compare` b
metricCompare' TheLowerTheBetter a b = b `compare` a
bestPossibleValue :: Metric -> MetricValue
bestPossibleValue metric = case getMetricOrdering metric of
TheLowerTheBetter -> 0.0

View File

@ -95,6 +95,10 @@ optionsParser = GEvalOptions
(flag' ListMetrics
( long "list-metrics"
<> help "List all metrics with their descriptions"))
<|>
(flag' OracleItemBased
( long "oracle-item-based"
<> help "Generate the best possible output considering outputs given by --out-file and --alt-out-file options (and peeking into the expected file)."))
)
<*> ((flag' FirstTheWorst
@ -152,6 +156,10 @@ specParser = GEvalSpecification
<> showDefault
<> metavar "OUT"
<> help "The name of the file to be evaluated" )
<*> (optional $ some $ strOption
( long "alt-out-file"
<> metavar "OUT"
<> help "Alternative output file, makes sense only for some options, e.g. --oracle-item-based"))
<*> strOption
( long "expected-file"
<> short 'e'
@ -355,6 +363,9 @@ runGEval''' (Just Validate) _ _ spec _ _ = do
runGEval''' (Just ListMetrics) _ _ _ _ _ = do
listMetrics
return Nothing
runGEval''' (Just OracleItemBased) _ _ spec _ _ = do
runOracleItemBased spec
return Nothing
getGraphFilename :: Int -> FilePath -> FilePath
getGraphFilename 0 fp = fp

View File

@ -445,6 +445,7 @@ main = hspec $ do
gesTestName = "test-A",
gesSelector = Nothing,
gesOutFile = "out.tsv",
gesAltOutFiles = Nothing,
gesExpectedFile = "expected.tsv",
gesInputFile = "in.tsv",
gesMetrics = [EvaluationScheme Likelihood []],

View File

@ -0,0 +1,4 @@
A
C
D
D
1 A
2 C
3 D
4 D

View File

@ -0,0 +1,4 @@
D
C
B
A
1 D
2 C
3 B
4 A

View File

@ -0,0 +1,4 @@
B
A
C
A
1 B
2 A
3 C
4 A

View File

@ -0,0 +1 @@
--metric Accuracy

View File

@ -0,0 +1,4 @@
A
B
C
D
1 A
2 B
3 C
4 D

View File

@ -0,0 +1,4 @@
1
2
3
4
1 1
2 2
3 3
4 4

View File

@ -0,0 +1,4 @@
A
C
C
D
1 A
2 C
3 C
4 D