Speed up GLEU

This commit is contained in:
Filip Gralinski 2018-09-25 07:10:17 +02:00
parent 680bc80f40
commit 4f09a1802f

View File

@ -1,3 +1,5 @@
{-# LANGUAGE BangPatterns #-}
module GEval.BLEU
(precisionCount, bleuStep, gleuStep)
where
@ -27,10 +29,13 @@ bleuStep refs trans = (prec1, prec2, prec3, prec4, closestLen, len1, len2, len3
gleuStep :: Ord a => [[a]] -> [a] -> (Int, Int)
gleuStep refs trans = maximumBy (\(g1, t1) (g2, t2) -> (g1 /. t1) `compare` (g2 /. t2)) $ map getBetterCounts refs
where getBetterCounts ref = let (matched, expected, got) = getCounts (==) (upToTetragrams ref, transNgrams)
total = max expected got
in (matched, total)
transNgrams = upToTetragrams trans
where getBetterCounts ref = let (matched1, expected1, got1) = getCounts (==) (ref, trans1grams)
(matched2, expected2, got2) = getCounts (==) (bigrams ref, trans2grams)
(matched3, expected3, got3) = getCounts (==) (trigrams ref, trans3grams)
(matched4, expected4, got4) = getCounts (==) (tetragrams ref, trans4grams)
total = max (expected1 + expected2 + expected3 + expected4) (got1 + got2 + got3 + got4)
in (matched1 + matched2 + matched3 + matched4, total)
(trans1grams, trans2grams, trans3grams, trans4grams) = upToTetragrams trans
precisionCount :: Ord a => [[a]] -> [a] -> Int
precisionCount refs = sum . map (lookFor refs) . MS.toOccurList . MS.fromList
@ -39,14 +44,8 @@ precisionCount refs = sum . map (lookFor refs) . MS.toOccurList . MS.fromList
minimumOrZero [] = 0
minimumOrZero l = minimum l
data Ngram a = Unigram a | Bigram (a, a) | Trigram (a, a, a) | Tetragram (a, a, a, a)
deriving (Eq, Show)
upToTetragrams :: [a] -> [Ngram a]
upToTetragrams l = (map Unigram l)
++ (map Bigram $ bigrams l)
++ (map Trigram $ trigrams l)
++ (map Tetragram $ tetragrams l)
upToTetragrams :: [a] -> ([a], [(a, a)], [(a, a, a)], [(a, a, a, a)])
upToTetragrams l = (l, bigrams l, trigrams l, tetragrams l)
bigrams :: [a] -> [(a, a)]
bigrams [] = []