diff --git a/src/sos_sdps.jl b/src/sos_sdps.jl index 04a8916..e256bd0 100644 --- a/src/sos_sdps.jl +++ b/src/sos_sdps.jl @@ -122,11 +122,12 @@ sos_problem_primal( function __fast_recursive_dot!( res::JuMP.AffExpr, Ps::AbstractArray{<:AbstractMatrix{<:JuMP.VariableRef}}, - Ms::AbstractArray{<:AbstractSparseMatrix}; + Ms::AbstractArray{<:AbstractSparseMatrix}, + weights, ) @assert length(Ps) == length(Ms) - for (A, P) in zip(Ms, Ps) + for (w, A, P) in zip(weights, Ms, Ps) iszero(Ms) && continue rows = rowvals(A) vals = nonzeros(A) @@ -134,13 +135,21 @@ function __fast_recursive_dot!( for i in nzrange(A, cidx) ridx = rows[i] val = vals[i] - JuMP.add_to_expression!(res, P[ridx, cidx], val) + JuMP.add_to_expression!(res, P[ridx, cidx], w * val) end end end return res end +function _dot( + Ps::AbstractArray{<:AbstractMatrix{<:JuMP.VariableRef}}, + Ms::AbstractArray{<:AbstractMatrix{T}}, + weights = Iterators.repeated(one(T), length(Ms)), +) where {T} + return __fast_recursive_dot!(JuMP.AffExpr(), Ps, Ms, weights) +end + import ProgressMeter __show_itrs(n, total) = () -> [(Symbol("constraint"), "$n/$total")] @@ -217,15 +226,9 @@ function sos_problem_primal( # @info [nnz(m) / length(m) for m in Ms] if feasibility_problem - JuMP.@constraint( - model, - x == __fast_recursive_dot!(JuMP.AffExpr(), P, Ms) - ) + JuMP.@constraint(model, x == _dot(P, Ms)) else - JuMP.@constraint( - model, - x - λ * u == __fast_recursive_dot!(JuMP.AffExpr(), P, Ms) - ) + JuMP.@constraint(model, x - λ * u == _dot(P, Ms)) end end ProgressMeter.finish!(prog)