improve __fast_recursive_dot a bit

This commit is contained in:
Marek Kaluba 2023-03-19 20:37:46 +01:00
parent b5fa1ac0ef
commit a1de0ecc85
No known key found for this signature in database
GPG Key ID: 8BF1A3855328FC15
1 changed files with 14 additions and 11 deletions

View File

@ -122,11 +122,12 @@ sos_problem_primal(
function __fast_recursive_dot!(
res::JuMP.AffExpr,
Ps::AbstractArray{<:AbstractMatrix{<:JuMP.VariableRef}},
Ms::AbstractArray{<:AbstractSparseMatrix};
Ms::AbstractArray{<:AbstractSparseMatrix},
weights,
)
@assert length(Ps) == length(Ms)
for (A, P) in zip(Ms, Ps)
for (w, A, P) in zip(weights, Ms, Ps)
iszero(Ms) && continue
rows = rowvals(A)
vals = nonzeros(A)
@ -134,13 +135,21 @@ function __fast_recursive_dot!(
for i in nzrange(A, cidx)
ridx = rows[i]
val = vals[i]
JuMP.add_to_expression!(res, P[ridx, cidx], val)
JuMP.add_to_expression!(res, P[ridx, cidx], w * val)
end
end
end
return res
end
function _dot(
Ps::AbstractArray{<:AbstractMatrix{<:JuMP.VariableRef}},
Ms::AbstractArray{<:AbstractMatrix{T}},
weights = Iterators.repeated(one(T), length(Ms)),
) where {T}
return __fast_recursive_dot!(JuMP.AffExpr(), Ps, Ms, weights)
end
import ProgressMeter
__show_itrs(n, total) = () -> [(Symbol("constraint"), "$n/$total")]
@ -217,15 +226,9 @@ function sos_problem_primal(
# @info [nnz(m) / length(m) for m in Ms]
if feasibility_problem
JuMP.@constraint(
model,
x == __fast_recursive_dot!(JuMP.AffExpr(), P, Ms)
)
JuMP.@constraint(model, x == _dot(P, Ms))
else
JuMP.@constraint(
model,
x - λ * u == __fast_recursive_dot!(JuMP.AffExpr(), P, Ms)
)
JuMP.@constraint(model, x - λ * u == _dot(P, Ms))
end
end
ProgressMeter.finish!(prog)