From bcce338754e5ce8935a2d81147fb06725b31ea16 Mon Sep 17 00:00:00 2001 From: Marek Kaluba Date: Sat, 17 Jul 2021 20:23:41 +0200 Subject: [PATCH] refactor wlmetric_ball --- src/wl_ball.jl | 52 +++++++++++++++++--------------------------------- 1 file changed, 18 insertions(+), 34 deletions(-) diff --git a/src/wl_ball.jl b/src/wl_ball.jl index f7f4e12..ebb370c 100644 --- a/src/wl_ball.jl +++ b/src/wl_ball.jl @@ -6,50 +6,34 @@ word-length metric on the group generated by `S`. The ball is centered at `cente (by default: the identity element). `radius` and `op` keywords specify the radius and multiplication operation to be used. """ -function wlmetric_ball_serial(S::AbstractVector{T}; radius = 2, op = *) where {T} - @assert radius > 0 - old = unique!([one(first(S)), S...]) - sizes = [1, length(old)] - for i in 2:radius - new = collect(op(o, s) for o in @view(old[sizes[end-1]:end]) for s in S) - append!(old, new) - resize!(new, 0) - old = unique!(old) - push!(sizes, length(old)) - end - return old, sizes[2:end] +function wlmetric_ball_serial(S::AbstractVector{T}, center::T=one(first(S)); radius = 2, op = *) where {T} + @assert radius >= 1 + old = unique!([center, [center*s for s in S]...]) + return _wlmetric_ball(S, old, radius, op, collect, unique!) end -function wlmetric_ball_thr(S::AbstractVector{T}; radius = 2, op = *) where {T} - @assert radius > 0 - old = unique!([one(first(S)), S...]) +function wlmetric_ball_thr(S::AbstractVector{T}, center::T=one(first(S)); radius = 2, op = *) where {T} + @assert radius >= 1 + old = unique!([center, [center*s for s in S]...]) + return _wlmetric_ball(S, old, radius, op, ThreadsX.collect, ThreadsX.unique) +end + +function _wlmetric_ball(S, old, radius, op, collect, unique) sizes = [1, length(old)] for r in 2:radius - begin - new = - ThreadsX.collect(op(o, s) for o in @view(old[sizes[end-1]:end]) for s in S) - ThreadsX.foreach(hash, new) + old = let old = old, S = S, + new = collect( + (g = op(o, s); hash(g); g) + for o in @view(old[sizes[end-1]:end]) for s in S + ) + append!(old, new) + unique(old) end - append!(old, new) - resize!(new, 0) - old = ThreadsX.unique(old) push!(sizes, length(old)) end return old, sizes[2:end] end -function wlmetric_ball_serial(S::AbstractVector{T}, center::T; radius = 2, op = *) where {T} - E, sizes = wlmetric_ball_serial(S, radius = radius, op = op) - isone(center) && return E, sizes - return c .* E, sizes -end - -function wlmetric_ball_thr(S::AbstractVector{T}, center::T; radius = 2, op = *) where {T} - E, sizes = wlmetric_ball_thr(S, radius = radius, op = op) - isone(center) && return E, sizes - return c .* E, sizes -end - function wlmetric_ball( S::AbstractVector{T}, center::T = one(first(S));