Add --worst-features: list unigram features together with a Mann-Whitney
U based probability telling how strongly each of them co-occurs with badly
scored records.

Review notes for this revision of the patch:
 * the size of the "records without the feature" sample is now derived from
   the actual number of records instead of the hard-coded 2942;
 * a feature present in every record no longer produces 0/0;
 * dead formatOutput/formatScore bindings are dropped from runWorstFeatures;
 * GEval.FeatureExtractor uses explicit/qualified imports;
 * line breaks and indentation were reconstructed from a whitespace-collapsed
   copy of the patch, so context lines may need `git apply --ignore-whitespace`;
 * the `index` lines of the two files whose post-image changed are omitted.

diff --git a/geval.cabal b/geval.cabal
index bd6bbb5..b77f8b8 100644
--- a/geval.cabal
+++ b/geval.cabal
@@ -31,6 +31,7 @@ library
                      , Data.Conduit.AutoDecompress
                      , Data.Conduit.SmartSource
                      , Data.Conduit.Rank
+                     , GEval.FeatureExtractor
                      , Paths_geval
   build-depends:       base >= 4.7 && < 5
                      , cond
@@ -65,6 +66,7 @@ library
                      , Glob
                      , naturalcomp
                      , containers
+                     , statistics
   default-language:    Haskell2010
 
 executable geval
diff --git a/src/GEval/Core.hs b/src/GEval/Core.hs
index 4fa49c4..6f6bef2 100644
--- a/src/GEval/Core.hs
+++ b/src/GEval/Core.hs
@@ -213,7 +213,7 @@ getExpectedDirectory :: GEvalSpecification -> FilePath
 getExpectedDirectory spec = fromMaybe outDirectory $ gesExpectedDirectory spec
   where outDirectory = gesOutDirectory spec
 
-data GEvalSpecialCommand = Init | LineByLine | Diff FilePath | PrintVersion
+data GEvalSpecialCommand = Init | LineByLine | WorstFeatures | Diff FilePath | PrintVersion
 
 data ResultOrdering = KeepTheOriginalOrder | FirstTheWorst | FirstTheBest
 
diff --git a/src/GEval/FeatureExtractor.hs b/src/GEval/FeatureExtractor.hs
new file mode 100644
--- /dev/null
+++ b/src/GEval/FeatureExtractor.hs
@@ -0,0 +1,31 @@
+{-# LANGUAGE OverloadedStrings #-}
+
+-- | Turns raw text records into namespaced unigram (single-token) features.
+module GEval.FeatureExtractor
+  (extractUnigramFeatures,
+   extractUnigramFeaturesFromTabbed)
+  where
+
+import Data.List (nub)
+import Data.Monoid ((<>))
+import Data.Text (Text)
+import qualified Data.Text as T
+
+-- | Distinct tokens of @record@, in order of first occurrence, each
+-- prefixed with @namespace@ and a colon.
+extractUnigramFeatures :: Text -> Text -> [Text]
+extractUnigramFeatures namespace record = map (prefix <>) $ nub $ tokenize record
+  where prefix = namespace <> ":"
+
+-- | Splits on spaces, tabs and colons; empty tokens are dropped.
+tokenize :: Text -> [Text]
+tokenize = filter (not . T.null) . T.split isSeparator
+  where isSeparator c = c == ' ' || c == '\t' || c == ':'
+
+-- | Like 'extractUnigramFeatures', applied separately to every tab-separated
+-- field of @record@; the 1-based field number, in angle brackets, is appended
+-- to the namespace (namespace "in" yields "in\<2\>:token" for the second field).
+extractUnigramFeaturesFromTabbed :: Text -> Text -> [Text]
+extractUnigramFeaturesFromTabbed namespace record =
+  concat $ zipWith fieldFeatures [1 :: Int ..] (T.splitOn "\t" record)
+  where fieldFeatures n = extractUnigramFeatures (namespace <> "<" <> T.pack (show n) <> ">")
diff --git a/src/GEval/LineByLine.hs b/src/GEval/LineByLine.hs
--- a/src/GEval/LineByLine.hs
+++ b/src/GEval/LineByLine.hs
@@ -9,6 +9,7 @@
 
 module GEval.LineByLine
        (runLineByLine,
+        runWorstFeatures,
         runLineByLineGeneralized,
         runDiff,
         runDiffGeneralized,
@@ -25,12 +26,17 @@ import qualified Data.Conduit.List as CL
 import qualified Data.Conduit.Combinators as CC
 import Data.Text
 import Data.Text.Encoding
+import Data.Conduit.Rank
 
-import Data.List (sortBy, sort)
+import Data.List (sortBy, sort, concat)
 
 import Control.Monad.IO.Class
 import Control.Monad.Trans.Resource
 
+import Data.Monoid ((<>))
+
+import GEval.FeatureExtractor
+
 import Data.Word
 
 import Text.Printf
@@ -39,6 +45,11 @@ import Data.Conduit.SmartSource
 
 import System.FilePath
 
+import Statistics.Distribution (cumulative)
+import Statistics.Distribution.Normal (normalDistr)
+
+import qualified Data.Map.Strict as M
+
 data LineRecord = LineRecord Text Text Text Word32 MetricValue
                   deriving (Eq, Show)
 
@@ -54,6 +65,84 @@ runLineByLine ordering spec = runLineByLineGeneralized ordering spec consum
         formatScore :: MetricValue -> Text
         formatScore = Data.Text.pack . printf "%f"
 
+-- | Prints one line per unigram feature: the feature, the number of records
+-- it occurs in and the probability computed by 'uScoresCounter'.
+-- NOTE(review): presumably 'rank' gives the lowest ranks to the worst-scoring
+-- records (see 'lessByMetric') - confirm against "Data.Conduit.Rank".
+runWorstFeatures :: ResultOrdering -> GEvalSpecification -> IO ()
+runWorstFeatures ordering spec = runLineByLineGeneralized ordering spec consum
+   where consum :: ConduitT LineRecord Void (ResourceT IO) ()
+         consum = (rank (lessByMetric $ gesMainMetric spec)
+                  .| featureExtractor
+                  .| uScoresCounter
+                  .| CL.map (encodeUtf8 . formatFeatureWithZScore)
+                  .| CC.unlinesAscii
+                  .| CC.stdout)
+
+-- | A feature paired with the rank of one record that contains it.
+data RankedFeature = RankedFeature Text Double
+                     deriving (Show)
+
+-- | Feature, probability produced by 'uScoresCounter', number of records
+-- containing the feature. Despite the name, the 'Double' is the standard
+-- normal CDF taken at the z-score, not the z-score itself.
+data FeatureWithZScore = FeatureWithZScore Text Double Int
+                         deriving (Show)
+
+-- | Renders @feature count probability@ separated by single spaces; the
+-- probability is printed with 20 decimal places.
+formatFeatureWithZScore :: FeatureWithZScore -> Text
+formatFeatureWithZScore (FeatureWithZScore f z c) =
+  f <> " " <> (pack $ show c) <> " " <> (pack $ printf "%0.20f" z)
+
+-- | For every ranked record emits the list of its features, each tagged with
+-- that record's rank. Keeping one list per record (instead of a flat stream)
+-- lets 'uScoresCounter' know how many records there were.
+featureExtractor :: Monad m => ConduitT (Double, LineRecord) [RankedFeature] m ()
+featureExtractor = CC.map extract
+  where extract (r, LineRecord inLine expLine outLine _ _) =
+          Prelude.map (\f -> RankedFeature f r)
+          $ Data.List.concat [
+              extractUnigramFeatures "exp" expLine,
+              extractUnigramFeaturesFromTabbed "in" inLine,
+              extractUnigramFeatures "out" outLine]
+
+-- | Consumes the whole stream, then yields one result per feature: the
+-- Mann-Whitney U statistic of "records with the feature" against "records
+-- without it", mapped to a probability through the normal approximation.
+-- Results are yielded in ascending order of feature name ('M.toList').
+uScoresCounter :: Monad m => ConduitT [RankedFeature] FeatureWithZScore m ()
+uScoresCounter = do
+  perRecord <- CC.sinkList
+  CC.yieldMany $ scoreAll perRecord
+  where scoreAll perRecord =
+          Prelude.map (score $ Prelude.length perRecord)
+          $ M.toList
+          $ M.fromListWith (\(r1, c1) (r2, c2) -> ((r1 + r2), (c1 + c2)))
+          $ [(f, (r, 1)) | RankedFeature f r <- Data.List.concat perRecord]
+        -- The second sample is every record lacking the feature, so its size
+        -- comes from the actual record count (it used to be a fixed 2942 - c).
+        score total (f, (rankSum, c)) =
+          FeatureWithZScore f (pValue (rankSum - minusR c) c (total - c)) c
+        minusR c = (c' * (c' + 1)) / 2.0
+          where c' = fromIntegral c
+        pValue u n1 n2
+          | sigma == 0 = 0.5 -- an empty sample: no evidence either way
+          | otherwise = cumulative (normalDistr 0.0 1.0) ((u - mean) / sigma)
+          where n1' = fromIntegral n1
+                n2' = fromIntegral n2
+                mean = n1' * n2' / 2
+                sigma = sqrt $ n1' * n2' * (n1' + n2' + 1) / 12
+
+-- | Ordering handed to 'rank': a record is "less" when its score is worse
+-- under the metric's ordering.
+lessByMetric :: Metric -> (LineRecord -> LineRecord -> Bool)
+lessByMetric metric = lessByMetric' (getMetricOrdering metric)
+  where lessByMetric' TheHigherTheBetter = (\(LineRecord _ _ _ _ scoreA) (LineRecord _ _ _ _ scoreB) ->
+                                              scoreA < scoreB)
+        lessByMetric' TheLowerTheBetter = (\(LineRecord _ _ _ _ scoreA) (LineRecord _ _ _ _ scoreB) ->
+                                              scoreA > scoreB)
+
 runLineByLineGeneralized :: ResultOrdering -> GEvalSpecification -> ConduitT LineRecord Void (ResourceT IO) a -> IO a
 runLineByLineGeneralized ordering spec consum = do
   (inputFilePath, expectedFilePath, outFilePath) <- checkAndGetFilesSingleOut True spec
diff --git a/src/GEval/OptionsParser.hs b/src/GEval/OptionsParser.hs
index 6647612..2ec66c0 100644
--- a/src/GEval/OptionsParser.hs
+++ b/src/GEval/OptionsParser.hs
@@ -46,6 +46,11 @@ optionsParser = GEvalOptions
                     <> short 'l'
                     <> help "Give scores for each line rather than the whole test set" ))
                  <|>
+                 (flag' WorstFeatures
+                   ( long "worst-features"
+                     <> short 'w'
+                     <> help "Print a ranking of worst features, i.e. features that worsen the score significantly" ))
+                 <|>
                  (Diff <$> strOption
                     ( long "diff"
                       <> short 'd'
@@ -194,6 +199,9 @@ runGEval''' (Just PrintVersion) _ _ = do
 runGEval''' (Just LineByLine) ordering spec = do
   runLineByLine ordering spec
   return Nothing
+runGEval''' (Just WorstFeatures) ordering spec = do
+  runWorstFeatures ordering spec
+  return Nothing
 runGEval''' (Just (Diff otherOut)) ordering spec = do
   runDiff ordering otherOut spec
   return Nothing