# Groups.jl/src/wl_ball.jl

"""
    wlmetric_ball(S::AbstractVector{<:GroupElem}
        [, center=one(first(S)); radius=2, op=*, threading=true])

Compute metric ball as a list of elements of non-decreasing length, given the
word-length metric on the group generated by `S`. The ball is centered at `center`
(by default: the identity element). `radius` and `op` keywords specify the
radius and multiplication operation to be used.

# Keywords
- `radius = 2`: radius of the ball; must be at least `1`.
- `op = *`: the multiplication operation to be used; elements of the ball are
  multiplied by the generators in `S` on the right.
- `threading = true`: if `true`, use the `Folds`-based implementation
  (`wlmetric_ball_thr`); otherwise use the serial one (`wlmetric_ball_serial`).

# Returns
A tuple `(ball, sizes)`. `ball` is a vector of the elements of the ball, starting
with `center`. `sizes[k]` is the number of elements within distance `k` of
`center`, i.e. `ball[1:sizes[k]]` is the ball of radius `k`, for
`k = 1, …, radius`.
"""
function wlmetric_ball(
    S::AbstractVector{T},
    center::T = one(first(S));
    radius = 2,
    op = *,
    threading = true,
) where {T}
    threading && return wlmetric_ball_thr(S, center; radius = radius, op = op)
    return wlmetric_ball_serial(S, center; radius = radius, op = op)
end
"""
    wlmetric_ball_serial(S, center=one(first(S)); radius=2, op=*)

Serial implementation of `wlmetric_ball`. Returns `(ball, sizes)`.
- `ball` lists the elements in the order they were discovered, starting with
  `center`.
- `sizes[k]` is the number of elements within distance `k` of `center`.

Requires `radius >= 1` (checked with `@assert`).
"""
function wlmetric_ball_serial(
    S::AbstractVector{T},
    center::T = one(first(S));
    radius = 2,
    op = *,
) where {T}
    @assert radius >= 1
    # `ball` keeps elements in discovery order, so each sphere occupies a
    # contiguous tail of it. `seen` provides hash-based membership tests
    # (same `isequal`/`hash` semantics as an ordered set).
    ball = [center]
    seen = Set(ball)
    sizes = Int[]
    sphere_start = 1 # first index of the most recently discovered sphere
    for _ in 1:radius
        sphere_end = length(ball)
        # Only the latest sphere can produce new elements. Products of elements
        # from earlier spheres were already added when those spheres were
        # expanded. `op` is used for the very first sphere as well, so a
        # custom `op` is honored at every radius.
        for i in sphere_start:sphere_end
            g = ball[i]
            for s in S
                h = op(g, s)
                n = length(seen)
                push!(seen, h)
                # `seen` grew iff `h` is new; this hashes `h` only once
                length(seen) > n && push!(ball, h)
            end
        end
        push!(sizes, length(ball))
        sphere_start = sphere_end + 1
    end
    return ball, sizes
end
"""
    wlmetric_ball_thr(S, center=one(first(S)); radius=2, op=*)

`Folds`-based implementation of `wlmetric_ball`. Returns `(ball, sizes)` with
the same meaning as for `wlmetric_ball`. Requires `radius >= 1` (checked with
`@assert`).
"""
function wlmetric_ball_thr(
    S::AbstractVector{T},
    center::T = one(first(S));
    radius = 2,
    op = *,
) where {T}
    @assert radius >= 1
    # Ball of radius 1, built with `op` (not a hard-coded `*`) so that a custom
    # `op` is honored consistently at every radius. `union!` on a vector keeps
    # first occurrences in order, with `center` first.
    old = union!([center], [op(center, s) for s in S])
    return _wlmetric_ball(S, old, radius, op, Folds.collect, Folds.unique)
end
"""
    _wlmetric_ball(S, old, radius, op, collect, unique)

Grow the ball `old` (of radius 1, starting with its center) up to `radius`.
At each step the elements of the most recent sphere are multiplied on the right
by every `s ∈ S`.

`collect` materializes the new elements and `unique` deduplicates the extended
list. `unique` must keep first occurrences in order: the spheres are located by
index ranges into `old`.

Returns `(ball, sizes)`, where `sizes[k]` is the length of the ball of radius `k`.
"""
function _wlmetric_ball(S, old, radius, op, collect, unique)
    sizes = [1, length(old)]
    for _ in 2:radius
        # `let` rebinds the captured variables so the generator closure does
        # not capture the reassigned `old`.
        old = let old = old, S = S
            # Only the latest sphere, old[sizes[end-1]+1:end], can produce new
            # elements. Earlier elements were already expanded in previous steps.
            sphere = @view old[sizes[end-1]+1:end]
            new = collect(
                (g = op(o, s);
                # normalize and hash each product inside the (possibly
                # parallel) `collect`. NOTE(review): presumably this fills
                # lazily computed/cached state before `unique` runs; confirm.
                normalform!(g);
                hash(g);
                g) for o in sphere for s in S
            )
            append!(old, new)
            unique(old)
        end
        push!(sizes, length(old))
    end
    return old, sizes[2:end]
end