From 7b009b048a88eff24920d9781bf2f89d9c2b5bc1 Mon Sep 17 00:00:00 2001 From: Filip Gralinski Date: Mon, 24 Aug 2015 23:40:40 +0200 Subject: [PATCH] BLEU cntd. --- geval.cabal | 1 + src/GEval/BLEU.hs | 2 +- src/GEval/Core.hs | 22 ++++++++++++++++++---- test/Spec.hs | 6 ++++++ 4 files changed, 26 insertions(+), 5 deletions(-) diff --git a/geval.cabal b/geval.cabal index b23c12e..2a521f5 100644 --- a/geval.cabal +++ b/geval.cabal @@ -30,6 +30,7 @@ library , multiset , optparse-applicative , resourcet + , split , text default-language: Haskell2010 diff --git a/src/GEval/BLEU.hs b/src/GEval/BLEU.hs index 0fa3bbd..6179d90 100644 --- a/src/GEval/BLEU.hs +++ b/src/GEval/BLEU.hs @@ -1,5 +1,5 @@ module GEval.BLEU - (precisionCount) + (precisionCount, bleuStep) where import qualified Data.MultiSet as MS diff --git a/src/GEval/Core.hs b/src/GEval/Core.hs index 8adbc2a..78000fe 100644 --- a/src/GEval/Core.hs +++ b/src/GEval/Core.hs @@ -32,6 +32,10 @@ import qualified System.Directory as D import System.FilePath import Data.Maybe +import qualified Data.List.Split as DLS + +import GEval.BLEU + type MetricValue = Double data Metric = RMSE | MSE | BLEU @@ -124,17 +128,27 @@ gevalCore metric expectedFilePath outFilePath = do gevalCore' metric expectedFilePath outFilePath gevalCore' :: Metric -> String -> String -> IO (MetricValue) -gevalCore' MSE = gevalCore'' outParser outParser itemError averageC +gevalCore' MSE = gevalCore'' outParser outParser itemError averageC id where outParser = getValue . TR.double -gevalCore'' :: (Text -> a) -> (Text -> b) -> ((a, b) -> c) -> (Sink c (ResourceT IO) Double) -> String -> String -> IO (MetricValue) -gevalCore'' expParser outParser itemStep aggregator expectedFilePath outFilePath = - runResourceT $ +gevalCore' BLEU = gevalCore'' (DLS.splitOn "\t" . unpack) unpack bleuCombine bleuAgg bleuFinal + where bleuFinal (p1, p2, p3, p4, cl, l1, l2, l3, l4) = p1 /. l1 + bleuCombine (refs, sen) = bleuStep refs sen + bleuAgg = CC.foldl bleuFuse (0, 0, 0, 0, 0, 0, 0, 0, 0) + bleuFuse (a1, a2, a3, a4, a5, a6, a7, a8, a9) (b1, b2, b3, b4, b5, b6, b7, b8, b9) = (a1+b1, a2+b2, a3+b3, a4+b4, a5+b5, a6+b6, a7+b7, a8+b8, a9+b9) + +(/.) :: Int -> Int -> Double +x /. y = (fromIntegral x) / (fromIntegral y) + +gevalCore'' :: (Text -> a) -> (Text -> b) -> ((a, b) -> c) -> (Sink c (ResourceT IO) d) -> (d -> Double ) -> String -> String -> IO (MetricValue) +gevalCore'' expParser outParser itemStep aggregator finalStep expectedFilePath outFilePath = do + v <- runResourceT $ (getZipSource $ (,) <$> ZipSource (items expectedFilePath expParser) <*> ZipSource (items outFilePath outParser)) $$ (CL.map itemStep =$ aggregator) + return $ finalStep v averageC :: MonadResource m => Sink Double m Double averageC = getZipSink diff --git a/test/Spec.hs b/test/Spec.hs index e12b9e3..0a69ce4 100644 --- a/test/Spec.hs +++ b/test/Spec.hs @@ -17,6 +17,12 @@ main = hspec $ do "test/mse-simple/mse-simple", "--out-directory", "test/mse-simple/mse-simple-solution"]) >>= extractVal) `shouldReturnAlmost` 0.4166666666666667 + describe "BLEU" $ do + it "trivial example from Wikipedia" $ do + ((runGEval ["--expected-directory", + "test/bleu-trivial/bleu-trivial", + "--out-directory", + "test/bleu-trivial/bleu-trivial-solution"]) >>= extractVal) `shouldReturnAlmost` 0.0 describe "precision count" $ do it "simple test" $ do precisionCount [["Alice", "has", "a", "cat" ]] ["Ala", "has", "cat"] `shouldBe` 2