Speed up GLEU (cntd.)
This commit is contained in:
parent
997a6c02ab
commit
3a852ed081
@ -27,14 +27,24 @@ bleuStep refs trans = (prec1, prec2, prec3, prec4, closestLen, len1, len2, len3
|
||||
|
||||
gleuStep :: Ord a => [[a]] -> [a] -> (Int, Int)
|
||||
gleuStep refs trans = maximumBy (\(g1, t1) (g2, t2) -> (g1 /. t1) `compare` (g2 /. t2)) $ map getBetterCounts refs
|
||||
where getBetterCounts ref = let (matched1, expected1, got1) = getCounts (==) (ref, trans1grams)
|
||||
(matched2, expected2, got2) = getCounts (==) (bigrams ref, trans2grams)
|
||||
(matched3, expected3, got3) = getCounts (==) (trigrams ref, trans3grams)
|
||||
(matched4, expected4, got4) = getCounts (==) (tetragrams ref, trans4grams)
|
||||
where getBetterCounts ref = let (matched1, expected1, got1) = getSimpleCounts ref trans1grams
|
||||
(matched2, expected2, got2) = getSimpleCounts (bigrams ref) trans2grams
|
||||
(matched3, expected3, got3) = getSimpleCounts (trigrams ref) trans3grams
|
||||
(matched4, expected4, got4) = getSimpleCounts (tetragrams ref) trans4grams
|
||||
total = max (expected1 + expected2 + expected3 + expected4) (got1 + got2 + got3 + got4)
|
||||
in (matched1 + matched2 + matched3 + matched4, total)
|
||||
(trans1grams, trans2grams, trans3grams, trans4grams) = upToTetragrams trans
|
||||
|
||||
getSimpleCounts :: Ord a => [a] -> [a] -> (Int, Int, Int)
|
||||
getSimpleCounts expected got = (countMatched expected got,
|
||||
length expected,
|
||||
length got)
|
||||
|
||||
countMatched :: Ord a => [a] -> [a] -> Int
|
||||
countMatched ref got = sum $ map (lookFor rf) $ MS.toOccurList $ MS.fromList got
|
||||
where rf = MS.fromList ref
|
||||
lookFor rf (e, freq) = min freq (MS.occur e rf)
|
||||
|
||||
precisionCount :: Ord a => [[a]] -> [a] -> Int
|
||||
precisionCount refs = sum . map (lookFor refs) . MS.toOccurList . MS.fromList
|
||||
where lookFor refs (e, freq) = minimumOrZero $ filter (> 0) $ map (findE e freq) $ map MS.fromList refs
|
||||
|
@ -58,8 +58,8 @@ countFolder (a1, a2, a3) (b1, b2, b3) = (a1+b1, a2+b2, a3+b3)
|
||||
|
||||
getCounts :: (a -> b -> Bool) -> ([a], [b]) -> (Int, Int, Int)
|
||||
getCounts matchingFun (expected, got) = (maxMatch matchingFun expected got,
|
||||
length expected,
|
||||
length got)
|
||||
length expected,
|
||||
length got)
|
||||
|
||||
-- | Calculates both precision and recall.
|
||||
--
|
||||
|
Loading…
Reference in New Issue
Block a user