Speed up GLEU (cntd.)

This commit is contained in:
Filip Gralinski 2018-09-26 22:27:59 +02:00
parent 997a6c02ab
commit 3a852ed081
2 changed files with 16 additions and 6 deletions

View File

@ -27,14 +27,24 @@ bleuStep refs trans = (prec1, prec2, prec3, prec4, closestLen, len1, len2, len3
gleuStep :: Ord a => [[a]] -> [a] -> (Int, Int)
gleuStep refs trans = maximumBy (\(g1, t1) (g2, t2) -> (g1 /. t1) `compare` (g2 /. t2)) $ map getBetterCounts refs
where getBetterCounts ref = let (matched1, expected1, got1) = getCounts (==) (ref, trans1grams)
(matched2, expected2, got2) = getCounts (==) (bigrams ref, trans2grams)
(matched3, expected3, got3) = getCounts (==) (trigrams ref, trans3grams)
(matched4, expected4, got4) = getCounts (==) (tetragrams ref, trans4grams)
where getBetterCounts ref = let (matched1, expected1, got1) = getSimpleCounts ref trans1grams
(matched2, expected2, got2) = getSimpleCounts (bigrams ref) trans2grams
(matched3, expected3, got3) = getSimpleCounts (trigrams ref) trans3grams
(matched4, expected4, got4) = getSimpleCounts (tetragrams ref) trans4grams
total = max (expected1 + expected2 + expected3 + expected4) (got1 + got2 + got3 + got4)
in (matched1 + matched2 + matched3 + matched4, total)
(trans1grams, trans2grams, trans3grams, trans4grams) = upToTetragrams trans
getSimpleCounts :: Ord a => [a] -> [a] -> (Int, Int, Int)
getSimpleCounts expected got = (countMatched expected got,
length expected,
length got)
countMatched :: Ord a => [a] -> [a] -> Int
countMatched ref got = sum $ map (lookFor rf) $ MS.toOccurList $ MS.fromList got
where rf = MS.fromList ref
lookFor rf (e, freq) = min freq (MS.occur e rf)
precisionCount :: Ord a => [[a]] -> [a] -> Int
precisionCount refs = sum . map (lookFor refs) . MS.toOccurList . MS.fromList
where lookFor refs (e, freq) = minimumOrZero $ filter (> 0) $ map (findE e freq) $ map MS.fromList refs

View File

@ -58,8 +58,8 @@ countFolder (a1, a2, a3) (b1, b2, b3) = (a1+b1, a2+b2, a3+b3)
getCounts :: (a -> b -> Bool) -> ([a], [b]) -> (Int, Int, Int)
getCounts matchingFun (expected, got) = (maxMatch matchingFun expected got,
length expected,
length got)
length expected,
length got)
-- | Calculates both precision and recall.
--