diff --git a/src/sos_sdps.jl b/src/sos_sdps.jl index d2387ed..8f086c7 100644 --- a/src/sos_sdps.jl +++ b/src/sos_sdps.jl @@ -160,6 +160,28 @@ sos_problem_primal( kwargs... ) = sos_problem_primal(elt, zero(elt), wedderburn; kwargs...) +function __fast_recursive_dot!( + res::JuMP.AffExpr, + Ps::AbstractArray{<:AbstractMatrix{<:JuMP.VariableRef}}, + Ms::AbstractArray{<:AbstractSparseMatrix}; +) + @assert length(Ps) == length(Ms) + + for (A, P) in zip(Ms, Ps) + iszero(Ms) && continue + rows = rowvals(A) + vals = nonzeros(A) + for cidx in axes(A, 2) + for i in nzrange(A, cidx) + ridx = rows[i] + val = vals[i] + JuMP.add_to_expression!(res, P[ridx, cidx], val) + end + end + end + return res +end + function sos_problem_primal( elt::StarAlgebras.AlgebraElement, orderunit::StarAlgebras.AlgebraElement, @@ -225,9 +247,9 @@ function sos_problem_primal( M_dot_P = sum(dot(M[π], P[π]) for π in eachindex(M) if !iszero(M[π])) if feasibility_problem - JuMP.@constraint(model, x == M_dot_P) + JuMP.@constraint(model, x == __fast_recursive_dot!(JuMP.AffExpr(), P, M)) else - JuMP.@constraint(model, x - λ * u == M_dot_P) + JuMP.@constraint(model, x - λ * u == __fast_recursive_dot!(JuMP.AffExpr(), P, M)) end end return model, P