start work on BLEU
This commit is contained in:
parent
1ebe413d82
commit
bf4b91f8f8
@ -18,6 +18,7 @@ library
|
||||
exposed-modules: GEval.Core
|
||||
GEval.CreateChallenge
|
||||
, GEval.OptionsParser
|
||||
, GEval.BLEU
|
||||
build-depends: base >= 4.7 && < 5
|
||||
, cond
|
||||
, conduit
|
||||
@ -26,6 +27,7 @@ library
|
||||
, directory
|
||||
, filepath
|
||||
, here
|
||||
, multiset
|
||||
, optparse-applicative
|
||||
, resourcet
|
||||
, text
|
||||
|
42
src/GEval/BLEU.hs
Normal file
42
src/GEval/BLEU.hs
Normal file
@ -0,0 +1,42 @@
|
||||
module GEval.BLEU
|
||||
(precisionCount)
|
||||
where
|
||||
|
||||
import qualified Data.MultiSet as MS
|
||||
import Data.List (minimumBy, zip, zip3, zip4)
|
||||
|
||||
bleuStep :: Ord a => [[a]] -> [a] -> (Int, Int, Int, Int, Int, Int, Int, Int, Int)
|
||||
bleuStep refs trans = (prec1, prec2, prec3, prec4, closestLen, len1, len2, len3, len4)
|
||||
where prec1 = precisionCountForNgrams id
|
||||
prec2 = precisionCountForNgrams bigrams
|
||||
prec3 = precisionCountForNgrams trigrams
|
||||
prec4 = precisionCountForNgrams tetragrams
|
||||
precisionCountForNgrams fun = precisionCount (map fun refs) (fun trans)
|
||||
closestLen = minimumBy closestCmp $ map length refs
|
||||
closestCmp x y
|
||||
| ((abs (x - len1)) < (abs (y - len1))) = LT
|
||||
| ((abs (x - len1)) > (abs (y - len1))) = GT
|
||||
| ((abs (x - len1)) == (abs (y - len1))) = x `compare` y
|
||||
len1 = length trans
|
||||
len2 = max 0 (len1 - 1)
|
||||
len3 = max 0 (len1 - 2)
|
||||
len4 = max 0 (len1 - 3)
|
||||
bigrams [] = []
|
||||
bigrams [_] = []
|
||||
bigrams u = zip u $ tail u
|
||||
trigrams [] = []
|
||||
trigrams [_] = []
|
||||
trigrams [_, _] = []
|
||||
trigrams u = zip3 u (tail u) (tail $ tail u)
|
||||
tetragrams [] = []
|
||||
tetragrams [_] = []
|
||||
tetragrams [_, _] = []
|
||||
tetragrams [_, _, _] = []
|
||||
tetragrams u = zip4 u (tail u) (tail $ tail u) (tail $ tail $ tail u)
|
||||
|
||||
precisionCount :: Ord a => [[a]] -> [a] -> Int
|
||||
precisionCount refs = sum . map (lookFor refs) . MS.toOccurList . MS.fromList
|
||||
where lookFor refs (e, freq) = minimumOrZero $ filter (> 0) $ map (findE e freq) $ map MS.fromList refs
|
||||
findE e freq m = min freq (MS.occur e m)
|
||||
minimumOrZero [] = 0
|
||||
minimumOrZero l = minimum l
|
10
test/Spec.hs
10
test/Spec.hs
@ -2,6 +2,7 @@ import Test.Hspec
|
||||
|
||||
import GEval.Core
|
||||
import GEval.OptionsParser
|
||||
import GEval.BLEU
|
||||
import Options.Applicative
|
||||
import qualified Test.HUnit as HU
|
||||
|
||||
@ -16,6 +17,15 @@ main = hspec $ do
|
||||
"test/mse-simple/mse-simple",
|
||||
"--out-directory",
|
||||
"test/mse-simple/mse-simple-solution"]) >>= extractVal) `shouldReturnAlmost` 0.4166666666666667
|
||||
describe "precision count" $ do
|
||||
it "simple test" $ do
|
||||
precisionCount [["Alice", "has", "a", "cat" ]] ["Ala", "has", "cat"] `shouldBe` 2
|
||||
it "none found" $ do
|
||||
precisionCount [["Alice", "has", "a", "cat" ]] ["for", "bar", "baz"] `shouldBe` 0
|
||||
it "multiple values" $ do
|
||||
precisionCount [["bar", "bar", "bar", "bar", "foo", "xyz", "foo"]] ["foo", "bar", "foo", "baz", "bar", "foo"] `shouldBe` 4
|
||||
it "multiple refs" $ do
|
||||
precisionCount [["foo", "baz"], ["bar"], ["baz", "xyz"]] ["foo", "bar", "foo"] `shouldBe` 2
|
||||
|
||||
extractVal :: (Either (ParserResult GEvalOptions) (Maybe MetricValue)) -> IO MetricValue
|
||||
extractVal (Right (Just val)) = return val
|
||||
|
Loading…
Reference in New Issue
Block a user