Add Pearson and Spearman correlation measures

This commit is contained in:
Filip Gralinski 2018-09-27 21:52:02 +02:00
parent b3800bc1d9
commit 5dc6e13191
6 changed files with 40 additions and 2 deletions

View File

@ -85,6 +85,10 @@ import Data.Conduit.AutoDecompress
import Text.Tokenizer import Text.Tokenizer
import qualified Data.HashMap.Strict as M import qualified Data.HashMap.Strict as M
import qualified Data.Vector as V
import qualified Data.Vector.Generic as VG
import Statistics.Correlation
import Data.Proxy import Data.Proxy
@ -98,7 +102,8 @@ defaultLogLossHashedSize :: Word32
defaultLogLossHashedSize = 10 defaultLogLossHashedSize = 10
-- | evaluation metric -- | evaluation metric
data Metric = RMSE | MSE | BLEU | GLEU | WER | Accuracy | ClippEU | FMeasure Double | MacroFMeasure Double | NMI data Metric = RMSE | MSE | Pearson | Spearman | BLEU | GLEU | WER | Accuracy | ClippEU
| FMeasure Double | MacroFMeasure Double | NMI
| LogLossHashed Word32 | CharMatch | MAP | LogLoss | Likelihood | LogLossHashed Word32 | CharMatch | MAP | LogLoss | Likelihood
| BIOF1 | BIOF1Labels | LikelihoodHashed Word32 | MAE | MultiLabelFMeasure Double | BIOF1 | BIOF1Labels | LikelihoodHashed Word32 | MAE | MultiLabelFMeasure Double
| MultiLabelLogLoss | MultiLabelLikelihood | MultiLabelLogLoss | MultiLabelLikelihood
@ -107,6 +112,8 @@ data Metric = RMSE | MSE | BLEU | GLEU | WER | Accuracy | ClippEU | FMeasure Dou
instance Show Metric where instance Show Metric where
show RMSE = "RMSE" show RMSE = "RMSE"
show MSE = "MSE" show MSE = "MSE"
show Pearson = "Pearson"
show Spearman = "Spearman"
show BLEU = "BLEU" show BLEU = "BLEU"
show GLEU = "GLEU" show GLEU = "GLEU"
show WER = "WER" show WER = "WER"
@ -141,6 +148,8 @@ instance Show Metric where
instance Read Metric where instance Read Metric where
readsPrec _ ('R':'M':'S':'E':theRest) = [(RMSE, theRest)] readsPrec _ ('R':'M':'S':'E':theRest) = [(RMSE, theRest)]
readsPrec _ ('M':'S':'E':theRest) = [(MSE, theRest)] readsPrec _ ('M':'S':'E':theRest) = [(MSE, theRest)]
readsPrec _ ('P':'e':'a':'r':'s':'o':'n':theRest) = [(Pearson, theRest)]
readsPrec _ ('S':'p':'e':'a':'r':'m':'a':'n':theRest) = [(Spearman, theRest)]
readsPrec _ ('B':'L':'E':'U':theRest) = [(BLEU, theRest)] readsPrec _ ('B':'L':'E':'U':theRest) = [(BLEU, theRest)]
readsPrec _ ('G':'L':'E':'U':theRest) = [(GLEU, theRest)] readsPrec _ ('G':'L':'E':'U':theRest) = [(GLEU, theRest)]
readsPrec _ ('W':'E':'R':theRest) = [(WER, theRest)] readsPrec _ ('W':'E':'R':theRest) = [(WER, theRest)]
@ -180,6 +189,8 @@ data MetricOrdering = TheLowerTheBetter | TheHigherTheBetter
getMetricOrdering :: Metric -> MetricOrdering getMetricOrdering :: Metric -> MetricOrdering
getMetricOrdering RMSE = TheLowerTheBetter getMetricOrdering RMSE = TheLowerTheBetter
getMetricOrdering MSE = TheLowerTheBetter getMetricOrdering MSE = TheLowerTheBetter
getMetricOrdering Pearson = TheHigherTheBetter
getMetricOrdering Spearman = TheHigherTheBetter
getMetricOrdering BLEU = TheHigherTheBetter getMetricOrdering BLEU = TheHigherTheBetter
getMetricOrdering GLEU = TheHigherTheBetter getMetricOrdering GLEU = TheHigherTheBetter
getMetricOrdering WER = TheLowerTheBetter getMetricOrdering WER = TheLowerTheBetter
@ -514,6 +525,8 @@ gevalCore' MSE _ = gevalCoreWithoutInput outParser outParser itemSquaredError av
gevalCore' MAE _ = gevalCoreWithoutInput outParser outParser itemAbsoluteError averageC id gevalCore' MAE _ = gevalCoreWithoutInput outParser outParser itemAbsoluteError averageC id
where outParser = getValue . TR.double where outParser = getValue . TR.double
gevalCore' Pearson _ = gevalCoreByCorrelationMeasure pearson
gevalCore' Spearman _ = gevalCoreByCorrelationMeasure spearman
gevalCore' LogLoss _ = gevalCoreWithoutInput outParser outParser itemLogLossError averageC id gevalCore' LogLoss _ = gevalCoreWithoutInput outParser outParser itemLogLossError averageC id
where outParser = getValue . TR.double where outParser = getValue . TR.double
@ -670,6 +683,17 @@ gevalCore' MultiLabelLogLoss _ = gevalCoreWithoutInput intoWords
countAgg :: Monad m => ConduitM (Int, Int, Int) o m (Int, Int, Int) countAgg :: Monad m => ConduitM (Int, Int, Int) o m (Int, Int, Int)
countAgg = CC.foldl countFolder (0, 0, 0) countAgg = CC.foldl countFolder (0, 0, 0)
gevalCoreByCorrelationMeasure :: (MonadUnliftIO m, MonadThrow m, MonadIO m) =>
(V.Vector (Double, Double) -> Double) -> -- ^ correlation function
LineSource (ResourceT m) -> -- ^ source to read the expected output
LineSource (ResourceT m) -> -- ^ source to read the output
m (MetricValue) -- ^ metric values for the output against the expected output
gevalCoreByCorrelationMeasure correlationFunction =
gevalCoreWithoutInput outParser outParser id correlationC finalStep
where outParser = getValue . TR.double
correlationC = CC.foldl (flip (:)) []
finalStep pairs = correlationFunction $ V.fromList pairs
parseDistributionWrapper :: Word32 -> Word32 -> Text -> HashedDistribution parseDistributionWrapper :: Word32 -> Word32 -> Text -> HashedDistribution
parseDistributionWrapper nbOfBits seed distroSpec = case parseDistribution nbOfBits seed distroSpec of parseDistributionWrapper nbOfBits seed distroSpec = case parseDistribution nbOfBits seed distroSpec of
Right distro -> distro Right distro -> distro

View File

@ -169,7 +169,7 @@ metricReader = many $ option auto -- actually `some` should be used inst
( long "metric" -- --metric might be in the config.txt file... ( long "metric" -- --metric might be in the config.txt file...
<> short 'm' <> short 'm'
<> metavar "METRIC" <> metavar "METRIC"
<> help "Metric to be used - RMSE, MSE, Accuracy, LogLoss, Likelihood, F-measure (specify as F1, F2, F0.25, etc.), macro F-measure (specify as Macro-F1, Macro-F2, Macro-F0.25, etc.), multi-label F-measure (specify as MultiLabel-F1, MultiLabel-F2, MultiLabel-F0.25, etc.), MAP, BLEU, GLEU (\"Google GLEU\" not the grammar correction metric), WER, NMI, ClippEU, LogLossHashed, LikelihoodHashed, BIO-F1, BIO-F1-Labels or CharMatch" ) <> help "Metric to be used - RMSE, MSE, Pearson, Spearman, Accuracy, LogLoss, Likelihood, F-measure (specify as F1, F2, F0.25, etc.), macro F-measure (specify as Macro-F1, Macro-F2, Macro-F0.25, etc.), multi-label F-measure (specify as MultiLabel-F1, MultiLabel-F2, MultiLabel-F0.25, etc.), MAP, BLEU, GLEU (\"Google GLEU\" not the grammar correction metric), WER, NMI, ClippEU, LogLossHashed, LikelihoodHashed, BIO-F1, BIO-F1-Labels or CharMatch" )
altMetricReader :: Parser (Maybe Metric) altMetricReader :: Parser (Maybe Metric)
altMetricReader = optional $ option auto altMetricReader = optional $ option auto

View File

@ -70,6 +70,9 @@ main = hspec $ do
describe "mean absolute error" $ do describe "mean absolute error" $ do
it "simple test with arguments" $ it "simple test with arguments" $
runGEvalTest "mae-simple" `shouldReturnAlmost` 1.5 runGEvalTest "mae-simple" `shouldReturnAlmost` 1.5
describe "Spearman's rank correlation coefficient" $ do
it "simple test" $ do
runGEvalTest "spearman-simple" `shouldReturnAlmost` (- 0.5735)
describe "BLEU" $ do describe "BLEU" $ do
it "trivial example from Wikipedia" $ it "trivial example from Wikipedia" $
runGEvalTest "bleu-trivial" `shouldReturnAlmost` 0.0 runGEvalTest "bleu-trivial" `shouldReturnAlmost` 0.0

View File

@ -0,0 +1,5 @@
1.2
1
2.3
1
18
1 1.2
2 1
3 2.3
4 1
5 18

View File

@ -0,0 +1 @@
--metric Spearman

View File

@ -0,0 +1,5 @@
1.1
1.57
0.51
1.1
1.1
1 1.1
2 1.57
3 0.51
4 1.1
5 1.1