implement mean absolute error

This commit is contained in:
Filip Gralinski 2018-06-13 12:30:11 +02:00
parent 1073407760
commit 012578f32a
5 changed files with 28 additions and 4 deletions
src/GEval
test
Spec.hs
mae-simple
mae-simple-solution/test-A
mae-simple

View File

@ -5,6 +5,7 @@
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ScopedTypeVariables #-}
module GEval.Core
( geval,
gevalCore,
@ -88,7 +89,7 @@ defaultLogLossHashedSize = 10
-- | evaluation metric
data Metric = RMSE | MSE | BLEU | Accuracy | ClippEU | FMeasure Double | NMI
| LogLossHashed Word32 | CharMatch | MAP | LogLoss | Likelihood
| BIOF1 | BIOF1Labels | LikelihoodHashed Word32
| BIOF1 | BIOF1Labels | LikelihoodHashed Word32 | MAE
deriving (Eq)
instance Show Metric where
@ -117,6 +118,9 @@ instance Show Metric where
show Likelihood = "Likelihood"
show BIOF1 = "BIO-F1"
show BIOF1Labels = "BIO-F1-Labels"
show MAE = "MAE"
instance Read Metric where
readsPrec _ ('R':'M':'S':'E':theRest) = [(RMSE, theRest)]
@ -140,6 +144,7 @@ instance Read Metric where
readsPrec _ ('M':'A':'P':theRest) = [(MAP, theRest)]
readsPrec _ ('B':'I':'O':'-':'F':'1':'-':'L':'a':'b':'e':'l':'s':theRest) = [(BIOF1Labels, theRest)]
readsPrec _ ('B':'I':'O':'-':'F':'1':theRest) = [(BIOF1, theRest)]
readsPrec _ ('M':'A':'E':theRest) = [(MAE, theRest)]
data MetricOrdering = TheLowerTheBetter | TheHigherTheBetter
@ -160,6 +165,7 @@ getMetricOrdering LogLoss = TheLowerTheBetter
getMetricOrdering Likelihood = TheHigherTheBetter
getMetricOrdering BIOF1 = TheHigherTheBetter
getMetricOrdering BIOF1Labels = TheHigherTheBetter
getMetricOrdering MAE = TheLowerTheBetter
isInputNeeded :: Metric -> Bool
isInputNeeded CharMatch = True
@ -403,9 +409,13 @@ gevalCore' :: (MonadIO m, MonadThrow m, MonadUnliftIO m) =>
-> LineSource (ResourceT m) -- ^ source to read the expected output
-> LineSource (ResourceT m) -- ^ source to read the output
-> m (MetricValue) -- ^ metric values for the output against the expected output
gevalCore' MSE _ = gevalCoreWithoutInput outParser outParser itemError averageC id
gevalCore' MSE _ = gevalCoreWithoutInput outParser outParser itemSquaredError averageC id
where outParser = getValue . TR.double
gevalCore' MAE _ = gevalCoreWithoutInput outParser outParser itemAbsoluteError averageC id
where outParser = getValue . TR.double
gevalCore' LogLoss _ = gevalCoreWithoutInput outParser outParser itemLogLossError averageC id
where outParser = getValue . TR.double
@ -660,8 +670,12 @@ items (LineSource lineSource _ _) parser =
where toItem (Right x) = Got x
toItem (Left m) = Wrong m
itemError :: (Double, Double) -> Double
itemError (exp, out) = (exp-out)**2
itemAbsoluteError :: (Double, Double) -> Double
itemAbsoluteError (exp, out) = abs (exp-out)
itemSquaredError :: (Double, Double) -> Double
itemSquaredError (exp, out) = (exp-out)**2
itemLogLossError :: (Double, Double) -> Double
itemLogLossError (exp, out)

View File

@ -51,6 +51,9 @@ main = hspec $ do
describe "mean square error" $ do
it "simple test with arguments" $
runGEvalTest "mse-simple" `shouldReturnAlmost` 0.4166666666666667
describe "mean absolute error" $ do
it "simple test with arguments" $
runGEvalTest "mae-simple" `shouldReturnAlmost` 1.5
describe "BLEU" $ do
it "trivial example from Wikipedia" $
runGEvalTest "bleu-trivial" `shouldReturnAlmost` 0.0

View File

@ -0,0 +1,3 @@
34.0
26.0
18.5
1 34.0
2 26.0
3 18.5

View File

@ -0,0 +1 @@
--metric MAE

View File

@ -0,0 +1,3 @@
30.0
26.5
18.5
1 30.0
2 26.5
3 18.5