rewrite reconstruct! with a better architecture

This commit is contained in:
Marek Kaluba 2022-11-15 18:53:03 +01:00
parent 35c5110a37
commit f053bffefe
No known key found for this signature in database
GPG Key ID: 8BF1A3855328FC15
3 changed files with 70 additions and 55 deletions

View File

@ -1,4 +1,3 @@
__precompile__()
module PropertyT
using LinearAlgebra
@ -16,6 +15,7 @@ import SymbolicWedderburn.PermutationGroups
include("constraint_matrix.jl")
include("sos_sdps.jl")
include("solve.jl")
include("reconstruct.jl")
include("certify.jl")
include("sqadjop.jl")

69
src/reconstruct.jl Normal file
View File

@ -0,0 +1,69 @@
__outer_dim(wd::WedderburnDecomposition) = size(first(direct_summands(wd)), 2)
function __group_of(wd::WedderburnDecomposition)
# this is veeeery hacky... ;)
return parent(first(keys(wd.hom.cache)))
end
function reconstruct(
Ms::AbstractVector{<:AbstractMatrix},
wbdec::WedderburnDecomposition;
atol=eps(real(eltype(wbdec))) * 10__outer_dim(wbdec)
)
n = __outer_dim(wbdec)
res = sum(zip(Ms, SymbolicWedderburn.direct_summands(wbdec))) do (M, ds)
res = similar(M, n, n)
reconstruct!(res, M, ds, __group_of(wbdec), wbdec.hom, atol=atol)
end
return res
end
function reconstruct!(
res::AbstractMatrix,
M::AbstractMatrix,
ds::SymbolicWedderburn.DirectSummand,
G,
hom;
atol=eps(real(eltype(ds))) * 10max(size(res)...)
)
res .= zero(eltype(res))
U = SymbolicWedderburn.image_basis(ds)
d = SymbolicWedderburn.degree(ds)
tmp = (U' * M * U) .* d
res = average!(res, tmp, G, hom)
if eltype(res) <: AbstractFloat
__droptol!(res, atol) # TODO: is this really necessary?!
end
return res
end
function __droptol!(M::AbstractMatrix, tol)
for i in eachindex(M)
if abs(M[i]) < tol
M[i] = zero(M[i])
end
end
return M
end
# implement average! for other actions when needed
function average!(
res::AbstractMatrix,
M::AbstractMatrix,
G::Groups.Group,
hom::SymbolicWedderburn.InducedActionHomomorphism{<:SymbolicWedderburn.ByPermutations}
)
@assert size(M) == size(res)
for g in G
gext = SymbolicWedderburn.induce(hom, g)
Threads.@threads for c in axes(res, 2)
for r in axes(res, 1)
res[r, c] += M[r^gext, c^gext]
end
end
end
o = Groups.order(Int, G)
res ./= o
return res
end

View File

@ -272,57 +272,3 @@ function sos_problem_primal(
ProgressMeter.finish!(prog)
return model, P
end
function reconstruct(Ps, wd::WedderburnDecomposition)
N = size(first(direct_summands(wd)), 2)
P = zeros(eltype(wd), N, N)
return reconstruct!(P, Ps, wd)
end
function group_of(wd::WedderburnDecomposition)
# this is veeeery hacky... ;)
return parent(first(keys(wd.hom.cache)))
end
# TODO: move to SymbolicWedderburn
SymbolicWedderburn.action(wd::WedderburnDecomposition) =
SymbolicWedderburn.action(wd.hom)
function reconstruct!(
res::AbstractMatrix,
Ps,
wedderburn::WedderburnDecomposition,
)
G = group_of(wedderburn)
act = SymbolicWedderburn.action(wedderburn)
@assert act isa SymbolicWedderburn.ByPermutations
for (π, ds) in pairs(direct_summands(wedderburn))
= SymbolicWedderburn.image_basis(ds)
# LinearAlgebra.mul!(tmp, Uπ', P[π])
# LinearAlgebra.mul!(tmp2, tmp, Uπ)
tmp2 = ' * Ps[π] *
if eltype(res) <: AbstractFloat
SymbolicWedderburn.zerotol!(tmp2, atol=1e-12)
end
tmp2 .*= SymbolicWedderburn.degree(ds)
@assert size(tmp2) == size(res)
for g in G
p = SymbolicWedderburn.induce(wedderburn.hom, g)
for c in axes(res, 2)
for r in axes(res, 1)
res[r, c] += tmp2[r^p, c^p]
end
end
end
end
res ./= Groups.order(Int, G)
return res
end