refactor
This commit is contained in:
parent
c3a6d94d1c
commit
d118e15353
@ -21,6 +21,7 @@ library
|
|||||||
, GEval.BLEU
|
, GEval.BLEU
|
||||||
, GEval.ClippEU
|
, GEval.ClippEU
|
||||||
, GEval.PrecisionRecall
|
, GEval.PrecisionRecall
|
||||||
|
, GEval.Common
|
||||||
build-depends: base >= 4.7 && < 5
|
build-depends: base >= 4.7 && < 5
|
||||||
, cond
|
, cond
|
||||||
, conduit
|
, conduit
|
||||||
|
10
src/GEval/Common.hs
Normal file
10
src/GEval/Common.hs
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
module GEval.Common
|
||||||
|
where
|
||||||
|
|
||||||
|
(/.) :: (Eq a, Integral a) => a -> a -> Double
|
||||||
|
x /. 0 = 0.0
|
||||||
|
x /. y = (fromIntegral x) / (fromIntegral y)
|
||||||
|
|
||||||
|
safeDoubleDiv :: Double -> Double -> Double
|
||||||
|
safeDoubleDiv _ 0.0 = 0.0
|
||||||
|
safeDoubleDiv x y = x / y
|
@ -36,6 +36,7 @@ import Data.Maybe
|
|||||||
import qualified Data.List.Split as DLS
|
import qualified Data.List.Split as DLS
|
||||||
|
|
||||||
import GEval.BLEU
|
import GEval.BLEU
|
||||||
|
import GEval.Common
|
||||||
|
|
||||||
type MetricValue = Double
|
type MetricValue = Double
|
||||||
|
|
||||||
@ -167,10 +168,6 @@ gevalCore' BLEU = gevalCore'' (Prelude.map Prelude.words . DLS.splitOn "\t" . un
|
|||||||
gevalCore' Accuracy = gevalCore'' strip strip hitOrMiss averageC id
|
gevalCore' Accuracy = gevalCore'' strip strip hitOrMiss averageC id
|
||||||
where hitOrMiss (x,y) = if x == y then 1.0 else 0.0
|
where hitOrMiss (x,y) = if x == y then 1.0 else 0.0
|
||||||
|
|
||||||
(/.) :: Int -> Int -> Double
|
|
||||||
x /. 0 = 1.0
|
|
||||||
x /. y = (fromIntegral x) / (fromIntegral y)
|
|
||||||
|
|
||||||
data SourceItem a = Got a | Done
|
data SourceItem a = Got a | Done
|
||||||
|
|
||||||
gevalCore'' :: (Text -> a) -> (Text -> b) -> ((a, b) -> c) -> (Sink c (ResourceT IO) d) -> (d -> Double) -> String -> String -> IO (MetricValue)
|
gevalCore'' :: (Text -> a) -> (Text -> b) -> ((a, b) -> c) -> (Sink c (ResourceT IO) d) -> (d -> Double) -> String -> String -> IO (MetricValue)
|
||||||
|
@ -5,6 +5,8 @@ module GEval.PrecisionRecall(fMeasure, f1Measure, f2Measure, precision, recall,
|
|||||||
precisionAndRecall, precisionAndRecallFromCounts)
|
precisionAndRecall, precisionAndRecallFromCounts)
|
||||||
where
|
where
|
||||||
|
|
||||||
|
import GEval.Common
|
||||||
|
|
||||||
import Data.Graph.Inductive
|
import Data.Graph.Inductive
|
||||||
import Data.Graph.Inductive.Query.MaxFlow
|
import Data.Graph.Inductive.Query.MaxFlow
|
||||||
|
|
||||||
@ -39,18 +41,7 @@ precisionAndRecall matchFun expected got
|
|||||||
|
|
||||||
precisionAndRecallFromCounts :: (Int, Int, Int) -> (Double, Double)
|
precisionAndRecallFromCounts :: (Int, Int, Int) -> (Double, Double)
|
||||||
precisionAndRecallFromCounts (tp, nbExpected, nbGot) =
|
precisionAndRecallFromCounts (tp, nbExpected, nbGot) =
|
||||||
(tp `safeDiv` nbGot, tp `safeDiv` nbExpected)
|
(tp /. nbGot, tp /. nbExpected)
|
||||||
|
|
||||||
-- division with possible zero
|
|
||||||
safeDiv :: Int -> Int -> Double
|
|
||||||
safeDiv _ 0 = 1.0
|
|
||||||
safeDiv x y = (fromIntegral x) / (fromIntegral y)
|
|
||||||
|
|
||||||
safeDoubleDiv :: Double -> Double -> Double
|
|
||||||
safeDoubleDiv _ 0.0 = 0.0
|
|
||||||
safeDoubleDiv x y = x / y
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
precision :: (a -> b -> Bool) -> [a] -> [b] -> Double
|
precision :: (a -> b -> Bool) -> [a] -> [b] -> Double
|
||||||
precision matchFun expected got = fst $ precisionAndRecall matchFun expected got
|
precision matchFun expected got = fst $ precisionAndRecall matchFun expected got
|
||||||
|
Loading…
Reference in New Issue
Block a user