add __fast_recursive_dot to speed up constraints

This commit is contained in:
Marek Kaluba 2022-11-14 19:47:38 +01:00
parent 227e82d551
commit 2f89538eb0
No known key found for this signature in database
GPG Key ID: 8BF1A3855328FC15
1 changed files with 24 additions and 2 deletions

View File

@ -160,6 +160,28 @@ sos_problem_primal(
kwargs...
) = sos_problem_primal(elt, zero(elt), wedderburn; kwargs...)
function __fast_recursive_dot!(
res::JuMP.AffExpr,
Ps::AbstractArray{<:AbstractMatrix{<:JuMP.VariableRef}},
Ms::AbstractArray{<:AbstractSparseMatrix};
)
@assert length(Ps) == length(Ms)
for (A, P) in zip(Ms, Ps)
iszero(Ms) && continue
rows = rowvals(A)
vals = nonzeros(A)
for cidx in axes(A, 2)
for i in nzrange(A, cidx)
ridx = rows[i]
val = vals[i]
JuMP.add_to_expression!(res, P[ridx, cidx], val)
end
end
end
return res
end
function sos_problem_primal(
elt::StarAlgebras.AlgebraElement,
orderunit::StarAlgebras.AlgebraElement,
@ -225,9 +247,9 @@ function sos_problem_primal(
M_dot_P = sum(dot(M[π], P[π]) for π in eachindex(M) if !iszero(M[π]))
if feasibility_problem
JuMP.@constraint(model, x == M_dot_P)
JuMP.@constraint(model, x == __fast_recursive_dot!(JuMP.AffExpr(), P, M))
else
JuMP.@constraint(model, x - λ * u == M_dot_P)
JuMP.@constraint(model, x - λ * u == __fast_recursive_dot!(JuMP.AffExpr(), P, M))
end
end
return model, P