1
0
mirror of https://github.com/kalmarek/PropertyT.jl.git synced 2024-11-29 09:45:27 +01:00

improve __fast_recursive_dot a bit

This commit is contained in:
Marek Kaluba 2023-03-19 20:37:46 +01:00
parent b5fa1ac0ef
commit a1de0ecc85
No known key found for this signature in database
GPG Key ID: 8BF1A3855328FC15

View File

@ -122,11 +122,12 @@ sos_problem_primal(
function __fast_recursive_dot!( function __fast_recursive_dot!(
res::JuMP.AffExpr, res::JuMP.AffExpr,
Ps::AbstractArray{<:AbstractMatrix{<:JuMP.VariableRef}}, Ps::AbstractArray{<:AbstractMatrix{<:JuMP.VariableRef}},
Ms::AbstractArray{<:AbstractSparseMatrix}; Ms::AbstractArray{<:AbstractSparseMatrix},
weights,
) )
@assert length(Ps) == length(Ms) @assert length(Ps) == length(Ms)
for (A, P) in zip(Ms, Ps) for (w, A, P) in zip(weights, Ms, Ps)
iszero(Ms) && continue iszero(Ms) && continue
rows = rowvals(A) rows = rowvals(A)
vals = nonzeros(A) vals = nonzeros(A)
@ -134,13 +135,21 @@ function __fast_recursive_dot!(
for i in nzrange(A, cidx) for i in nzrange(A, cidx)
ridx = rows[i] ridx = rows[i]
val = vals[i] val = vals[i]
JuMP.add_to_expression!(res, P[ridx, cidx], val) JuMP.add_to_expression!(res, P[ridx, cidx], w * val)
end end
end end
end end
return res return res
end end
function _dot(
Ps::AbstractArray{<:AbstractMatrix{<:JuMP.VariableRef}},
Ms::AbstractArray{<:AbstractMatrix{T}},
weights = Iterators.repeated(one(T), length(Ms)),
) where {T}
return __fast_recursive_dot!(JuMP.AffExpr(), Ps, Ms, weights)
end
import ProgressMeter import ProgressMeter
__show_itrs(n, total) = () -> [(Symbol("constraint"), "$n/$total")] __show_itrs(n, total) = () -> [(Symbol("constraint"), "$n/$total")]
@ -217,15 +226,9 @@ function sos_problem_primal(
# @info [nnz(m) / length(m) for m in Ms] # @info [nnz(m) / length(m) for m in Ms]
if feasibility_problem if feasibility_problem
JuMP.@constraint( JuMP.@constraint(model, x == _dot(P, Ms))
model,
x == __fast_recursive_dot!(JuMP.AffExpr(), P, Ms)
)
else else
JuMP.@constraint( JuMP.@constraint(model, x - λ * u == _dot(P, Ms))
model,
x - λ * u == __fast_recursive_dot!(JuMP.AffExpr(), P, Ms)
)
end end
end end
ProgressMeter.finish!(prog) ProgressMeter.finish!(prog)