diff --git a/src/GEval/BLEU.hs b/src/GEval/BLEU.hs index ba3cb04..3bb0e42 100644 --- a/src/GEval/BLEU.hs +++ b/src/GEval/BLEU.hs @@ -1,3 +1,5 @@ +{-# LANGUAGE BangPatterns #-} + module GEval.BLEU (precisionCount, bleuStep, gleuStep) where @@ -27,10 +29,13 @@ bleuStep refs trans = (prec1, prec2, prec3, prec4, closestLen, len1, len2, len3 gleuStep :: Ord a => [[a]] -> [a] -> (Int, Int) gleuStep refs trans = maximumBy (\(g1, t1) (g2, t2) -> (g1 /. t1) `compare` (g2 /. t2)) $ map getBetterCounts refs - where getBetterCounts ref = let (matched, expected, got) = getCounts (==) (upToTetragrams ref, transNgrams) - total = max expected got - in (matched, total) - transNgrams = upToTetragrams trans + where getBetterCounts ref = let (matched1, expected1, got1) = getCounts (==) (ref, trans1grams) + (matched2, expected2, got2) = getCounts (==) (bigrams ref, trans2grams) + (matched3, expected3, got3) = getCounts (==) (trigrams ref, trans3grams) + (matched4, expected4, got4) = getCounts (==) (tetragrams ref, trans4grams) + total = max (expected1 + expected2 + expected3 + expected4) (got1 + got2 + got3 + got4) + in (matched1 + matched2 + matched3 + matched4, total) + (trans1grams, trans2grams, trans3grams, trans4grams) = upToTetragrams trans precisionCount :: Ord a => [[a]] -> [a] -> Int precisionCount refs = sum . map (lookFor refs) . MS.toOccurList . MS.fromList @@ -39,14 +44,8 @@ precisionCount refs = sum . map (lookFor refs) . MS.toOccurList . MS.fromList minimumOrZero [] = 0 minimumOrZero l = minimum l -data Ngram a = Unigram a | Bigram (a, a) | Trigram (a, a, a) | Tetragram (a, a, a, a) - deriving (Eq, Show) - -upToTetragrams :: [a] -> [Ngram a] -upToTetragrams l = (map Unigram l) - ++ (map Bigram $ bigrams l) - ++ (map Trigram $ trigrams l) - ++ (map Tetragram $ tetragrams l) +upToTetragrams :: [a] -> ([a], [(a, a)], [(a, a, a)], [(a, a, a, a)]) +upToTetragrams l = (l, bigrams l, trigrams l, tetragrams l) bigrams :: [a] -> [(a, a)] bigrams [] = []