diff --git a/src/Data/Statistics/Calibration.hs b/src/Data/Statistics/Calibration.hs index 76b353e..3736b73 100644 --- a/src/Data/Statistics/Calibration.hs +++ b/src/Data/Statistics/Calibration.hs @@ -1,9 +1,9 @@ module Data.Statistics.Calibration (calibration, softCalibration) where -import Data.Statistics.Loess(loess) +import Data.Statistics.Loess (clippedLoess) import Numeric.Integration.TanhSinh -import Data.List(minimum, maximum) +import Data.List (minimum, maximum) import qualified Data.Vector.Unboxed as DVU minBand :: Double @@ -34,7 +34,7 @@ softCalibration [] _ = error "too few booleans in calibration" softCalibration _ [] = error "too few probabilities in calibration" softCalibration results probs | band probs < minBand = handleNarrowBand results probs - | otherwise = 1.0 - (min 1.0 (2.0 * (integrate (lowest, highest) (\x -> abs ((loess (DVU.fromList probs) (DVU.fromList results) x) - x))) / (highest - lowest))) + | otherwise = 1.0 - (min 1.0 (2.0 * (integrate (lowest, highest) (\x -> abs ((clippedLoess (DVU.fromList probs) (DVU.fromList results) x) - x))) / (highest - lowest))) where lowest = (minimum probs) + epsilon -- integrating loess gets crazy at edges highest = (maximum probs) - epsilon epsilon = 0.0001 diff --git a/src/Data/Statistics/Loess.hs b/src/Data/Statistics/Loess.hs index 1700179..39ea01a 100644 --- a/src/Data/Statistics/Loess.hs +++ b/src/Data/Statistics/Loess.hs @@ -1,23 +1,32 @@ module Data.Statistics.Loess - (loess) where + (loess, clippedLoess) where import qualified Statistics.Matrix.Types as SMT import Statistics.Regression (ols) import Data.Vector.Unboxed((!), zipWith, length, (++), map) import Statistics.Matrix(transpose) +import Statistics.Distribution.Normal (standard) +import Statistics.Distribution (density) + lambda :: Double -lambda = 2.0 +lambda = 8.0 triCube :: Double -> Double triCube d = (1.0 - (abs d) ** 3) ** 3 +gaussian :: Double -> Double +gaussian = density standard + +clippedLoess :: SMT.Vector -> SMT.Vector -> Double -> Double +clippedLoess inputs outputs x = min 1.0 $ max 0.0 $ loess inputs outputs x + loess :: SMT.Vector -> SMT.Vector -> Double -> Double loess inputs outputs x = a * x + b where a = params ! 1 b = params ! 0 params = ols inputMatrix scaledOutputs - weights = Data.Vector.Unboxed.map (\v -> lambda * triCube (lambda * (x - v))) inputs + weights = Data.Vector.Unboxed.map (\v -> lambda * gaussian (lambda * (x - v))) inputs scaledOutputs = Data.Vector.Unboxed.zipWith (*) outputs weights scaledInputs = Data.Vector.Unboxed.zipWith (*) inputs weights inputMatrix = transpose (SMT.Matrix 2 (Data.Vector.Unboxed.length inputs) 1000 (weights Data.Vector.Unboxed.++ scaledInputs)) diff --git a/src/GEval/Core.hs b/src/GEval/Core.hs index ab1d255..e60a004 100644 --- a/src/GEval/Core.hs +++ b/src/GEval/Core.hs @@ -102,8 +102,8 @@ import qualified Data.Vector.Unboxed as DVU import Statistics.Correlation -import Data.Statistics.Calibration(softCalibration) -import Data.Statistics.Loess(loess) +import Data.Statistics.Calibration (softCalibration) +import Data.Statistics.Loess (clippedLoess) import Data.Proxy @@ -755,7 +755,7 @@ gevalCore' (ProbabilisticSoftFMeasure beta) _ = gevalCoreWithoutInput parseAnnot probabilisticSoftAgg = CC.foldl probabilisticSoftFolder ([], [], fromInteger 0, 0) probabilisticSoftFolder (r1, p1, g1, e1) (r2, p2, g2, e2) = (r1 ++ r2, p1 ++ p2, g1 + g2, e1 + e2) loessGraph :: ([Double], [Double], Double, Int) -> Maybe GraphSeries - loessGraph (results, probs, _, _) = Just $ GraphSeries $ Prelude.map (\x -> (x, loess probs' results' x)) $ Prelude.filter (\p -> p > lowest && p < highest) $ Prelude.map (\d -> 0.01 * (fromIntegral d)) [1..99] + loessGraph (results, probs, _, _) = Just $ GraphSeries $ Prelude.map (\x -> (x, clippedLoess probs' results' x)) $ Prelude.filter (\p -> p > lowest && p < highest) $ Prelude.map (\d -> 0.01 * (fromIntegral d)) [1..99] where results' = DVU.fromList results probs' = DVU.fromList probs lowest = Data.List.minimum probs diff --git a/src/GEval/OptionsParser.hs b/src/GEval/OptionsParser.hs index 746f490..c06c550 100644 --- a/src/GEval/OptionsParser.hs +++ b/src/GEval/OptionsParser.hs @@ -9,8 +9,6 @@ module GEval.OptionsParser precisionArgParser ) where -import Debug.Trace - import Paths_geval (version) import Data.Version (showVersion)