From f053bffefe1f3ef857358cccae6e891f13d1fda3 Mon Sep 17 00:00:00 2001 From: Marek Kaluba Date: Tue, 15 Nov 2022 18:53:03 +0100 Subject: [PATCH] rewrite reconstruct! with a better architecture --- src/PropertyT.jl | 2 +- src/reconstruct.jl | 69 ++++++++++++++++++++++++++++++++++++++++++++++ src/sos_sdps.jl | 54 ------------------------------------ 3 files changed, 70 insertions(+), 55 deletions(-) create mode 100644 src/reconstruct.jl diff --git a/src/PropertyT.jl b/src/PropertyT.jl index 50495a5..ac32da9 100644 --- a/src/PropertyT.jl +++ b/src/PropertyT.jl @@ -1,4 +1,3 @@ -__precompile__() module PropertyT using LinearAlgebra @@ -16,6 +15,7 @@ import SymbolicWedderburn.PermutationGroups include("constraint_matrix.jl") include("sos_sdps.jl") include("solve.jl") +include("reconstruct.jl") include("certify.jl") include("sqadjop.jl") diff --git a/src/reconstruct.jl b/src/reconstruct.jl new file mode 100644 index 0000000..75451ee --- /dev/null +++ b/src/reconstruct.jl @@ -0,0 +1,69 @@ +__outer_dim(wd::WedderburnDecomposition) = size(first(direct_summands(wd)), 2) + +function __group_of(wd::WedderburnDecomposition) + # this is veeeery hacky... ;) + return parent(first(keys(wd.hom.cache))) +end + +function reconstruct( + Ms::AbstractVector{<:AbstractMatrix}, + wbdec::WedderburnDecomposition; + atol=eps(real(eltype(wbdec))) * 10__outer_dim(wbdec) +) + n = __outer_dim(wbdec) + res = sum(zip(Ms, SymbolicWedderburn.direct_summands(wbdec))) do (M, ds) + res = similar(M, n, n) + reconstruct!(res, M, ds, __group_of(wbdec), wbdec.hom, atol=atol) + end + return res +end + +function reconstruct!( + res::AbstractMatrix, + M::AbstractMatrix, + ds::SymbolicWedderburn.DirectSummand, + G, + hom; + atol=eps(real(eltype(ds))) * 10max(size(res)...) +) + res .= zero(eltype(res)) + U = SymbolicWedderburn.image_basis(ds) + d = SymbolicWedderburn.degree(ds) + tmp = (U' * M * U) .* d + + res = average!(res, tmp, G, hom) + if eltype(res) <: AbstractFloat + __droptol!(res, atol) # TODO: is this really necessary?! + end + return res +end + +function __droptol!(M::AbstractMatrix, tol) + for i in eachindex(M) + if abs(M[i]) < tol + M[i] = zero(M[i]) + end + end + return M +end + +# implement average! for other actions when needed +function average!( + res::AbstractMatrix, + M::AbstractMatrix, + G::Groups.Group, + hom::SymbolicWedderburn.InducedActionHomomorphism{<:SymbolicWedderburn.ByPermutations} +) + @assert size(M) == size(res) + for g in G + gext = SymbolicWedderburn.induce(hom, g) + Threads.@threads for c in axes(res, 2) + for r in axes(res, 1) + res[r, c] += M[r^gext, c^gext] + end + end + end + o = Groups.order(Int, G) + res ./= o + return res +end diff --git a/src/sos_sdps.jl b/src/sos_sdps.jl index 795ffac..f6135db 100644 --- a/src/sos_sdps.jl +++ b/src/sos_sdps.jl @@ -272,57 +272,3 @@ function sos_problem_primal( ProgressMeter.finish!(prog) return model, P end - -function reconstruct(Ps, wd::WedderburnDecomposition) - N = size(first(direct_summands(wd)), 2) - P = zeros(eltype(wd), N, N) - return reconstruct!(P, Ps, wd) -end - -function group_of(wd::WedderburnDecomposition) - # this is veeeery hacky... ;) - return parent(first(keys(wd.hom.cache))) -end - -# TODO: move to SymbolicWedderburn -SymbolicWedderburn.action(wd::WedderburnDecomposition) = - SymbolicWedderburn.action(wd.hom) - -function reconstruct!( - res::AbstractMatrix, - Ps, - wedderburn::WedderburnDecomposition, -) - G = group_of(wedderburn) - - act = SymbolicWedderburn.action(wedderburn) - - @assert act isa SymbolicWedderburn.ByPermutations - - for (π, ds) in pairs(direct_summands(wedderburn)) - Uπ = SymbolicWedderburn.image_basis(ds) - - # LinearAlgebra.mul!(tmp, Uπ', P[π]) - # LinearAlgebra.mul!(tmp2, tmp, Uπ) - tmp2 = Uπ' * Ps[π] * Uπ - if eltype(res) <: AbstractFloat - SymbolicWedderburn.zerotol!(tmp2, atol=1e-12) - end - tmp2 .*= SymbolicWedderburn.degree(ds) - - @assert size(tmp2) == size(res) - - for g in G - p = SymbolicWedderburn.induce(wedderburn.hom, g) - for c in axes(res, 2) - for r in axes(res, 1) - res[r, c] += tmp2[r^p, c^p] - end - end - end - end - res ./= Groups.order(Int, G) - - return res -end -