From 3a852ed081997ef2149f91300176f02c06570c81 Mon Sep 17 00:00:00 2001 From: Filip Gralinski Date: Wed, 26 Sep 2018 22:27:59 +0200 Subject: [PATCH] Speed up GLEU (cntd.) --- src/GEval/BLEU.hs | 18 ++++++++++++++---- src/GEval/PrecisionRecall.hs | 4 ++-- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/src/GEval/BLEU.hs b/src/GEval/BLEU.hs index 3d7650f..b9afc1b 100644 --- a/src/GEval/BLEU.hs +++ b/src/GEval/BLEU.hs @@ -27,14 +27,24 @@ bleuStep refs trans = (prec1, prec2, prec3, prec4, closestLen, len1, len2, len3 gleuStep :: Ord a => [[a]] -> [a] -> (Int, Int) gleuStep refs trans = maximumBy (\(g1, t1) (g2, t2) -> (g1 /. t1) `compare` (g2 /. t2)) $ map getBetterCounts refs - where getBetterCounts ref = let (matched1, expected1, got1) = getCounts (==) (ref, trans1grams) - (matched2, expected2, got2) = getCounts (==) (bigrams ref, trans2grams) - (matched3, expected3, got3) = getCounts (==) (trigrams ref, trans3grams) - (matched4, expected4, got4) = getCounts (==) (tetragrams ref, trans4grams) + where getBetterCounts ref = let (matched1, expected1, got1) = getSimpleCounts ref trans1grams + (matched2, expected2, got2) = getSimpleCounts (bigrams ref) trans2grams + (matched3, expected3, got3) = getSimpleCounts (trigrams ref) trans3grams + (matched4, expected4, got4) = getSimpleCounts (tetragrams ref) trans4grams total = max (expected1 + expected2 + expected3 + expected4) (got1 + got2 + got3 + got4) in (matched1 + matched2 + matched3 + matched4, total) (trans1grams, trans2grams, trans3grams, trans4grams) = upToTetragrams trans +getSimpleCounts :: Ord a => [a] -> [a] -> (Int, Int, Int) +getSimpleCounts expected got = (countMatched expected got, + length expected, + length got) + +countMatched :: Ord a => [a] -> [a] -> Int +countMatched ref got = sum $ map (lookFor rf) $ MS.toOccurList $ MS.fromList got + where rf = MS.fromList ref + lookFor rf (e, freq) = min freq (MS.occur e rf) + precisionCount :: Ord a => [[a]] -> [a] -> Int precisionCount refs = sum . map (lookFor refs) . MS.toOccurList . MS.fromList where lookFor refs (e, freq) = minimumOrZero $ filter (> 0) $ map (findE e freq) $ map MS.fromList refs diff --git a/src/GEval/PrecisionRecall.hs b/src/GEval/PrecisionRecall.hs index bcaad4b..ff38ca5 100644 --- a/src/GEval/PrecisionRecall.hs +++ b/src/GEval/PrecisionRecall.hs @@ -58,8 +58,8 @@ countFolder (a1, a2, a3) (b1, b2, b3) = (a1+b1, a2+b2, a3+b3) getCounts :: (a -> b -> Bool) -> ([a], [b]) -> (Int, Int, Int) getCounts matchingFun (expected, got) = (maxMatch matchingFun expected got, - length expected, - length got) + length expected, + length got) -- | Calculates both precision and recall. --