diff --git a/geval.cabal b/geval.cabal index 51e9b37..b23c12e 100644 --- a/geval.cabal +++ b/geval.cabal @@ -18,6 +18,7 @@ library exposed-modules: GEval.Core GEval.CreateChallenge , GEval.OptionsParser + , GEval.BLEU build-depends: base >= 4.7 && < 5 , cond , conduit @@ -26,6 +27,7 @@ library , directory , filepath , here + , multiset , optparse-applicative , resourcet , text diff --git a/src/GEval/BLEU.hs b/src/GEval/BLEU.hs new file mode 100644 index 0000000..0fa3bbd --- /dev/null +++ b/src/GEval/BLEU.hs @@ -0,0 +1,42 @@ +module GEval.BLEU + (precisionCount) + where + +import qualified Data.MultiSet as MS +import Data.List (minimumBy, zip, zip3, zip4) + +bleuStep :: Ord a => [[a]] -> [a] -> (Int, Int, Int, Int, Int, Int, Int, Int, Int) +bleuStep refs trans = (prec1, prec2, prec3, prec4, closestLen, len1, len2, len3, len4) + where prec1 = precisionCountForNgrams id + prec2 = precisionCountForNgrams bigrams + prec3 = precisionCountForNgrams trigrams + prec4 = precisionCountForNgrams tetragrams + precisionCountForNgrams fun = precisionCount (map fun refs) (fun trans) + closestLen = minimumBy closestCmp $ map length refs + closestCmp x y + | ((abs (x - len1)) < (abs (y - len1))) = LT + | ((abs (x - len1)) > (abs (y - len1))) = GT + | ((abs (x - len1)) == (abs (y - len1))) = x `compare` y + len1 = length trans + len2 = max 0 (len1 - 1) + len3 = max 0 (len1 - 2) + len4 = max 0 (len1 - 3) + bigrams [] = [] + bigrams [_] = [] + bigrams u = zip u $ tail u + trigrams [] = [] + trigrams [_] = [] + trigrams [_, _] = [] + trigrams u = zip3 u (tail u) (tail $ tail u) + tetragrams [] = [] + tetragrams [_] = [] + tetragrams [_, _] = [] + tetragrams [_, _, _] = [] + tetragrams u = zip4 u (tail u) (tail $ tail u) (tail $ tail $ tail u) + +precisionCount :: Ord a => [[a]] -> [a] -> Int +precisionCount refs = sum . map (lookFor refs) . MS.toOccurList . MS.fromList + where lookFor refs (e, freq) = minimumOrZero $ filter (> 0) $ map (findE e freq) $ map MS.fromList refs + findE e freq m = min freq (MS.occur e m) + minimumOrZero [] = 0 + minimumOrZero l = minimum l diff --git a/test/Spec.hs b/test/Spec.hs index 5181e59..e12b9e3 100644 --- a/test/Spec.hs +++ b/test/Spec.hs @@ -2,6 +2,7 @@ import Test.Hspec import GEval.Core import GEval.OptionsParser +import GEval.BLEU import Options.Applicative import qualified Test.HUnit as HU @@ -16,6 +17,15 @@ main = hspec $ do "test/mse-simple/mse-simple", "--out-directory", "test/mse-simple/mse-simple-solution"]) >>= extractVal) `shouldReturnAlmost` 0.4166666666666667 + describe "precision count" $ do + it "simple test" $ do + precisionCount [["Alice", "has", "a", "cat" ]] ["Ala", "has", "cat"] `shouldBe` 2 + it "none found" $ do + precisionCount [["Alice", "has", "a", "cat" ]] ["for", "bar", "baz"] `shouldBe` 0 + it "multiple values" $ do + precisionCount [["bar", "bar", "bar", "bar", "foo", "xyz", "foo"]] ["foo", "bar", "foo", "baz", "bar", "foo"] `shouldBe` 4 + it "multiple refs" $ do + precisionCount [["foo", "baz"], ["bar"], ["baz", "xyz"]] ["foo", "bar", "foo"] `shouldBe` 2 extractVal :: (Either (ParserResult GEvalOptions) (Maybe MetricValue)) -> IO MetricValue extractVal (Right (Just val)) = return val