start work on BLEU
This commit is contained in:
parent
1ebe413d82
commit
bf4b91f8f8
@ -18,6 +18,7 @@ library
|
|||||||
exposed-modules: GEval.Core
|
exposed-modules: GEval.Core
|
||||||
GEval.CreateChallenge
|
GEval.CreateChallenge
|
||||||
, GEval.OptionsParser
|
, GEval.OptionsParser
|
||||||
|
, GEval.BLEU
|
||||||
build-depends: base >= 4.7 && < 5
|
build-depends: base >= 4.7 && < 5
|
||||||
, cond
|
, cond
|
||||||
, conduit
|
, conduit
|
||||||
@ -26,6 +27,7 @@ library
|
|||||||
, directory
|
, directory
|
||||||
, filepath
|
, filepath
|
||||||
, here
|
, here
|
||||||
|
, multiset
|
||||||
, optparse-applicative
|
, optparse-applicative
|
||||||
, resourcet
|
, resourcet
|
||||||
, text
|
, text
|
||||||
|
42
src/GEval/BLEU.hs
Normal file
42
src/GEval/BLEU.hs
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
module GEval.BLEU
|
||||||
|
(precisionCount)
|
||||||
|
where
|
||||||
|
|
||||||
|
import qualified Data.MultiSet as MS
|
||||||
|
import Data.List (minimumBy, zip, zip3, zip4)
|
||||||
|
|
||||||
|
bleuStep :: Ord a => [[a]] -> [a] -> (Int, Int, Int, Int, Int, Int, Int, Int, Int)
|
||||||
|
bleuStep refs trans = (prec1, prec2, prec3, prec4, closestLen, len1, len2, len3, len4)
|
||||||
|
where prec1 = precisionCountForNgrams id
|
||||||
|
prec2 = precisionCountForNgrams bigrams
|
||||||
|
prec3 = precisionCountForNgrams trigrams
|
||||||
|
prec4 = precisionCountForNgrams tetragrams
|
||||||
|
precisionCountForNgrams fun = precisionCount (map fun refs) (fun trans)
|
||||||
|
closestLen = minimumBy closestCmp $ map length refs
|
||||||
|
closestCmp x y
|
||||||
|
| ((abs (x - len1)) < (abs (y - len1))) = LT
|
||||||
|
| ((abs (x - len1)) > (abs (y - len1))) = GT
|
||||||
|
| ((abs (x - len1)) == (abs (y - len1))) = x `compare` y
|
||||||
|
len1 = length trans
|
||||||
|
len2 = max 0 (len1 - 1)
|
||||||
|
len3 = max 0 (len1 - 2)
|
||||||
|
len4 = max 0 (len1 - 3)
|
||||||
|
bigrams [] = []
|
||||||
|
bigrams [_] = []
|
||||||
|
bigrams u = zip u $ tail u
|
||||||
|
trigrams [] = []
|
||||||
|
trigrams [_] = []
|
||||||
|
trigrams [_, _] = []
|
||||||
|
trigrams u = zip3 u (tail u) (tail $ tail u)
|
||||||
|
tetragrams [] = []
|
||||||
|
tetragrams [_] = []
|
||||||
|
tetragrams [_, _] = []
|
||||||
|
tetragrams [_, _, _] = []
|
||||||
|
tetragrams u = zip4 u (tail u) (tail $ tail u) (tail $ tail $ tail u)
|
||||||
|
|
||||||
|
precisionCount :: Ord a => [[a]] -> [a] -> Int
|
||||||
|
precisionCount refs = sum . map (lookFor refs) . MS.toOccurList . MS.fromList
|
||||||
|
where lookFor refs (e, freq) = minimumOrZero $ filter (> 0) $ map (findE e freq) $ map MS.fromList refs
|
||||||
|
findE e freq m = min freq (MS.occur e m)
|
||||||
|
minimumOrZero [] = 0
|
||||||
|
minimumOrZero l = minimum l
|
10
test/Spec.hs
10
test/Spec.hs
@ -2,6 +2,7 @@ import Test.Hspec
|
|||||||
|
|
||||||
import GEval.Core
|
import GEval.Core
|
||||||
import GEval.OptionsParser
|
import GEval.OptionsParser
|
||||||
|
import GEval.BLEU
|
||||||
import Options.Applicative
|
import Options.Applicative
|
||||||
import qualified Test.HUnit as HU
|
import qualified Test.HUnit as HU
|
||||||
|
|
||||||
@ -16,6 +17,15 @@ main = hspec $ do
|
|||||||
"test/mse-simple/mse-simple",
|
"test/mse-simple/mse-simple",
|
||||||
"--out-directory",
|
"--out-directory",
|
||||||
"test/mse-simple/mse-simple-solution"]) >>= extractVal) `shouldReturnAlmost` 0.4166666666666667
|
"test/mse-simple/mse-simple-solution"]) >>= extractVal) `shouldReturnAlmost` 0.4166666666666667
|
||||||
|
describe "precision count" $ do
|
||||||
|
it "simple test" $ do
|
||||||
|
precisionCount [["Alice", "has", "a", "cat" ]] ["Ala", "has", "cat"] `shouldBe` 2
|
||||||
|
it "none found" $ do
|
||||||
|
precisionCount [["Alice", "has", "a", "cat" ]] ["for", "bar", "baz"] `shouldBe` 0
|
||||||
|
it "multiple values" $ do
|
||||||
|
precisionCount [["bar", "bar", "bar", "bar", "foo", "xyz", "foo"]] ["foo", "bar", "foo", "baz", "bar", "foo"] `shouldBe` 4
|
||||||
|
it "multiple refs" $ do
|
||||||
|
precisionCount [["foo", "baz"], ["bar"], ["baz", "xyz"]] ["foo", "bar", "foo"] `shouldBe` 2
|
||||||
|
|
||||||
extractVal :: (Either (ParserResult GEvalOptions) (Maybe MetricValue)) -> IO MetricValue
|
extractVal :: (Either (ParserResult GEvalOptions) (Maybe MetricValue)) -> IO MetricValue
|
||||||
extractVal (Right (Just val)) = return val
|
extractVal (Right (Just val)) = return val
|
||||||
|
Loading…
Reference in New Issue
Block a user