in reconstruct: average the sum, not sum the averages!

This commit is contained in:
Marek Kaluba 2023-04-04 19:58:51 +02:00
parent f9f852439f
commit 1a43a1b1be
No known key found for this signature in database
GPG Key ID: 8BF1A3855328FC15
1 changed files with 13 additions and 14 deletions

View File

@ -12,9 +12,10 @@ function reconstruct(
n = __outer_dim(wbdec) n = __outer_dim(wbdec)
res = sum(zip(Ms, SymbolicWedderburn.direct_summands(wbdec))) do (M, ds) res = sum(zip(Ms, SymbolicWedderburn.direct_summands(wbdec))) do (M, ds)
res = similar(M, n, n) res = similar(M, n, n)
res = _reconstruct!(res, M, ds, __group_of(wbdec), wbdec.hom) res = _reconstruct!(res, M, ds)
return res return res
end end
res = average!(zero(res), res, __group_of(wbdec), wbdec.hom)
return res return res
end end
@ -22,16 +23,14 @@ function _reconstruct!(
res::AbstractMatrix, res::AbstractMatrix,
M::AbstractMatrix, M::AbstractMatrix,
ds::SymbolicWedderburn.DirectSummand, ds::SymbolicWedderburn.DirectSummand,
G,
hom,
) )
U = SymbolicWedderburn.image_basis(ds)
d = SymbolicWedderburn.degree(ds)
Θπ = (U' * M * U) .* d
res .= zero(eltype(res)) res .= zero(eltype(res))
Θπ = average!(res, Θπ, G, hom) if !iszero(M)
return Θπ U = SymbolicWedderburn.image_basis(ds)
d = SymbolicWedderburn.degree(ds)
res = (U' * M * U) .* d
end
return res
end end
function __droptol!(M::AbstractMatrix, tol) function __droptol!(M::AbstractMatrix, tol)
@ -52,18 +51,18 @@ function average!(
<:SymbolicWedderburn.ByPermutations, <:SymbolicWedderburn.ByPermutations,
}, },
) )
res .= zero(eltype(res))
@assert size(M) == size(res) @assert size(M) == size(res)
o = Groups.order(Int, G)
for g in G for g in G
p = SymbolicWedderburn.induce(hom, g) p = SymbolicWedderburn.induce(hom, g)
Threads.@threads for c in axes(res, 2) Threads.@threads for c in axes(res, 2)
# note: p is a permutation,
# so accesses below are guaranteed to be disjoint
for r in axes(res, 1) for r in axes(res, 1)
res[r^p, c^p] += M[r, c] if !iszero(M[r, c])
res[r^p, c^p] += M[r, c] / o
end
end end
end end
end end
o = Groups.order(Int, G)
res ./= o
return res return res
end end