1
0
mirror of https://github.com/kalmarek/PropertyT.jl.git synced 2024-11-14 22:20:28 +01:00
PropertyT.jl/src/sos_sdps.jl

321 lines
9.6 KiB
Julia
Raw Normal View History

2018-09-05 10:41:11 +02:00
###############################################################################
#
# Constraints
#
###############################################################################
2017-03-13 14:49:55 +01:00
2018-08-20 03:50:03 +02:00
"""
    constraints(pm::Matrix{I}, total_length=maximum(pm)) where {I<:Integer}

Bucket the linear indices of `pm` by the value stored at them. Entry `k` of the
returned `Vector{Vector{I}}` lists, in increasing order, every linear index `i`
with `pm[i] == k`. There are `total_length` buckets (default `maximum(pm)`); a
value of `pm` outside `1:total_length` raises a `BoundsError`.
"""
function constraints(pm::Matrix{I}, total_length=maximum(pm)) where {I<:Integer}
    buckets = [I[] for _ in 1:total_length]
    # `enumerate` over a Matrix walks it in column-major (linear-index) order
    for (lin, k) in enumerate(pm)
        push!(buckets[k], lin)
    end
    return buckets
end
2018-09-05 10:37:39 +02:00
"""
    orbit_constraint!(result::SparseMatrixCSC, cnstrs, orbit; val=1.0/length(orbit))

Overwrite `result` with the indicator matrix of `orbit`.

All stored entries of `result` are removed. Then every linear index listed in
`cnstrs[o]`, for each `o` in `orbit`, is set to `val` (by default
`1/length(orbit)`). `orbit` is used to index `cnstrs` directly, so it must be a
valid index collection for it (e.g. a `Vector{Int}`). Returns `result`.
"""
function orbit_constraint!(result::SparseMatrixCSC, cnstrs, orbit; val=1.0/length(orbit))
    # Zero only the stored values (O(nnz)), then drop them so the sparsity
    # pattern is empty before refilling.
    fill!(nonzeros(result), zero(eltype(result)))
    dropzeros!(result)

    for constraint in cnstrs[orbit], idx in constraint
        result[idx] = val
    end
    return result
end
"""
    orbit_spvector(vect::AbstractVector, orbits)

Compress `vect` to one value per orbit.

Entry `i` of the returned sparse vector (created by `spzeros`, so `Float64`
storage) is the common value of `vect` on the indices in `orbits[i]`. Each
orbit is `collect`ed before indexing, so any iterable of indices works (e.g. a
`Vector` or a `Set`).

# Throws
- `ArgumentError` if an orbit is empty or `vect` is not constant on it.
"""
function orbit_spvector(vect::AbstractVector, orbits)
    orb_vector = spzeros(length(orbits))
    for (i, o) in enumerate(orbits)
        k = vect[collect(o)]
        isempty(k) && throw(ArgumentError("orbit $i is empty"))
        val = first(k)
        # A single value per orbit only represents `vect` faithfully when
        # `vect` is constant on the orbit; reject anything else loudly.
        all(x -> x == val, k) ||
            throw(ArgumentError("vector is not constant on orbit $i: $(collect(o))"))
        orb_vector[i] = val
    end
    return orb_vector
end
2018-09-05 10:41:11 +02:00
###############################################################################
#
# Naive SDP
#
###############################################################################
2017-03-13 14:49:55 +01:00
"""
    SOS_problem(X::GroupRingElem, orderunit::GroupRingElem; upper_bound=Inf)

Build the naive (non-symmetrized) JuMP model that maximises `λ`.

With `pm = parent(X).pm` and `N = size(pm, 1)`, the model has:
- an `N × N` positive-semidefinite matrix variable `P` whose entries sum to zero;
- a scalar variable `λ`, with `λ ≤ upper_bound` added when `upper_bound < Inf`;
- for every index `k` (zipped with `X.coeffs` and `orderunit.coeffs`), the
  constraint that the entries `P[i]` with `pm[i] == k` sum to
  `X.coeffs[k] - λ*orderunit.coeffs[k]`.

Returns `(m, λ, P)`.
"""
function SOS_problem(X::GroupRingElem, orderunit::GroupRingElem; upper_bound=Inf)
    N = size(parent(X).pm, 1)
    m = JuMP.Model()

    # NOTE(review): warm starts in `solve` reuse raw solver vectors, which
    # presumably depend on the order variables/constraints are created in;
    # keep this order stable.
    JuMP.@variable(m, P[1:N, 1:N])
    JuMP.@SDconstraint(m, P >= 0)
    JuMP.@constraint(m, sum(P[i] for i in eachindex(P)) == 0)

    JuMP.@variable(m, λ)
    if upper_bound < Inf
        JuMP.@constraint(m, λ <= upper_bound)
    end

    cnstrs = constraints(parent(X).pm)
    for (constraint, x, u) in zip(cnstrs, X.coeffs, orderunit.coeffs)
        JuMP.@constraint(m, sum(P[constraint]) == x - λ*u)
    end

    JuMP.@objective(m, Max, λ)
    return m, λ, P
end
2018-09-05 09:18:38 +02:00
###############################################################################
#
# Symmetrized SDP
#
###############################################################################
"""
    SOS_problem(X::GroupRingElem, orderunit::GroupRingElem, data::OrbitData; upper_bound=Inf)

Build the symmetrized (orbit-reduced) JuMP model that maximises `λ`.

The model has:
- one square positive-semidefinite block variable `P[k]` per matrix in
  `data.Uπs`, with side length `size(data.Uπs[k], 2)`;
- a scalar variable `λ`, with `λ ≤ upper_bound` added when `upper_bound < Inf`;
- one linear equality per orbit in `data.orbits`, added by `addconstraints!`.

The `addconstraints!` call is wrapped in `@time` and preceded by an `@info`
message. Returns `(m, λ, P)` with `P::Vector{Matrix{JuMP.Variable}}`.
"""
function SOS_problem(X::GroupRingElem, orderunit::GroupRingElem, data::OrbitData; upper_bound=Inf)
    Ns = size.(data.Uπs, 2)
    m = JuMP.Model()

    # Blocks are created and constrained one at a time. NOTE(review): warm
    # starts presumably depend on this creation order; keep it stable.
    P = Vector{Matrix{JuMP.Variable}}(undef, length(Ns))
    for (k, s) in enumerate(Ns)
        P[k] = JuMP.@variable(m, [i=1:s, j=1:s])
        JuMP.@SDconstraint(m, P[k] >= 0.0)
    end

    λ = JuMP.@variable(m, λ)
    if upper_bound < Inf
        JuMP.@constraint(m, λ <= upper_bound)
    end

    @info("Adding $(length(data.orbits)) constraints... ")
    @time addconstraints!(m, P, λ, X, orderunit, data)

    JuMP.@objective(m, Max, λ)
    return m, λ, P
end
2018-08-20 03:45:50 +02:00
2018-09-05 09:18:38 +02:00
"""
    constraintLHS!(M, cnstr, Us, Ust, dims, tol=1000*eps(eltype(first(M))))

For every `π in 1:lastindex(Us)`, replace `M[π]` with

    dims[π] .* PropertyT.clamp_small!(Ust[π] * cnstr * Us[π], tol)

i.e. the matrix `cnstr` sandwiched between `Ust[π]` and `Us[π]`, then passed
through `PropertyT.clamp_small!` with threshold `tol` (presumably zeroing
entries of magnitude below `tol` — confirm against its definition), then scaled
by `dims[π]`. Callers in this file pass `Ust[π] == Us[π]'`.

`M[π]` is rebound to a newly computed matrix; the matrix previously stored in
`M[π]` is not written into. Returns `nothing`.
"""
function constraintLHS!(M, cnstr, Us, Ust, dims, tol=1000*eps(eltype(first(M))))
    # `tol` (formerly `eps`) no longer shadows `Base.eps`, which the default uses
    for π in 1:lastindex(Us)
        M[π] = dims[π] .* PropertyT.clamp_small!(Ust[π] * cnstr * Us[π], tol)
    end
    return nothing
end
"""
    addconstraints!(m::JuMP.Model, P::Vector{Matrix{JuMP.Variable}}, λ::JuMP.Variable,
                    X::GroupRingElem, orderunit::GroupRingElem, data::OrbitData)

Add one linear equality to `m` per orbit in `data.orbits`.

For orbit `t`:
1. the orbit's indicator matrix is built with `orbit_constraint!`;
2. it is projected onto every block with `constraintLHS!` (using `data.Uπs`,
   their adjoints and `data.dims`);
3. the constraint `Σ_π dot(M[π], P[π]) == X_orb[t] - λ*orderunit_orb[t]` is
   added.

Here `X_orb` and `orderunit_orb` are `X.coeffs` and `orderunit.coeffs` reduced
to one value per orbit by `orbit_spvector`. Returns `nothing`.
"""
function addconstraints!(m::JuMP.Model,
    P::Vector{Matrix{JuMP.Variable}}, λ::JuMP.Variable,
    X::GroupRingElem, orderunit::GroupRingElem, data::OrbitData)

    orderunit_orb = orbit_spvector(orderunit.coeffs, data.orbits)
    X_orb = orbit_spvector(X.coeffs, data.orbits)
    UπsT = [U' for U in data.Uπs]

    cnstrs = constraints(parent(X).pm)
    orb_cnstr = spzeros(Float64, size(parent(X).pm)...)
    # one square slot per block; constraintLHS! rebinds each slot per orbit
    M = [Array{Float64}(undef, n, n) for n in size.(UπsT, 1)]

    for (t, orbit) in enumerate(data.orbits)
        orbit_constraint!(orb_cnstr, cnstrs, orbit)
        constraintLHS!(M, orb_cnstr, data.Uπs, UπsT, data.dims)

        # qualified like every other JuMP macro in this file, so it does not
        # rely on `using JuMP` being in effect
        lhs = JuMP.@expression(m, sum(dot(M[π], P[π]) for π in eachindex(data.Uπs)))
        x, u = X_orb[t], orderunit_orb[t]
        JuMP.@constraint(m, lhs == x - λ*u)
    end
    return nothing
end
"""
    reconstruct(Ps::Vector{Matrix{F}}, data::OrbitData) where F

Convenience method: unpack `data.preps`, `data.Uπs` and `data.dims` and forward
to the four-argument-data `reconstruct(Ps, preps, Uπs, dims)`.
"""
reconstruct(Ps::Vector{Matrix{F}}, data::OrbitData) where F =
    reconstruct(Ps, data.preps, data.Uπs, data.dims)
2018-09-05 10:41:11 +02:00
"""
    reconstruct(Ps::Vector{<:AbstractMatrix}, preps::Dict{<:GroupElem,<:perm},
                Uπs::Vector{<:AbstractMatrix}, dims::Vector{Int})

Assemble one full matrix from the block solutions `Ps`.

1. Block `π` is lifted as `dims[π] .* Uπs[π] * Ps[π] * Uπs[π]'`.
2. Each lifted matrix is accumulated over every permutation in `values(preps)`
   with `perm_avg`, into a fresh `Float64` zero matrix. Blocks are processed in
   parallel with `Threads.@threads`; each block writes only its own
   accumulator.
3. The accumulators are summed and divided by `length(preps)`.

`Uπs` must be non-empty (the accumulator size is taken from the first lifted
block).
"""
function reconstruct(Ps::Vector{M},
    preps::Dict{GEl, P}, Uπs::Vector{U}, dims::Vector{Int}) where
    {M<:AbstractMatrix, GEl<:GroupElem, P<:perm, U<:AbstractMatrix}

    lU = length(Uπs)
    transfP = [dims[π].*Uπs[π]*Ps[π]*Uπs[π]' for π in 1:lU]
    # one accumulator per block, so the threaded loop has no shared writes
    tmp = [zeros(Float64, size(first(transfP))) for _ in 1:lU]

    Threads.@threads for π in 1:lU
        tmp[π] = perm_avg(tmp[π], transfP[π], values(preps))
    end

    recP = sum(tmp)./length(preps)
    return recP
end
"""
    perm_avg(result, P, perms)

Accumulate into `result` the matrix `P` with rows and columns permuted by each
permutation in `perms`. For every `p`:

    result[i, j] += P[p.d[i], p.d[j]]    for 1 ≤ i, j ≤ lp

where `lp = length(first(perms).d)`. Returns `result`, mutated in place; an
empty `perms` leaves `result` unchanged.

Because the inner loop runs under `@inbounds`, the following are validated once
up front:
- every permutation has the same degree `lp` (otherwise `DimensionMismatch`);
- both `result` and `P` admit indices `1:lp` in both dimensions (otherwise
  `BoundsError`).

The entries of each `p.d` are assumed to lie in `1:lp`, as they do for a valid
permutation, and are not re-checked.
"""
function perm_avg(result, P, perms)
    isempty(perms) && return result
    lp = length(first(perms).d)
    if lp > 0
        checkbounds(result, lp, lp)
        checkbounds(P, lp, lp)
    end
    for p in perms
        d = p.d
        length(d) == lp || throw(DimensionMismatch(
            "all permutations must have degree $lp, got one of degree $(length(d))"))
        @inbounds for j in 1:lp
            k = d[j]
            # inner loop over the first index: column-major friendly
            for i in 1:lp
                result[i, j] += P[d[i], k]
            end
        end
    end
    return result
end
###############################################################################
#
# Low-level solve
#
###############################################################################
using MathProgBase
"""
    solve(m::JuMP.Model, varλ::JuMP.Variable, varP, warmstart=nothing)

Build `m`'s internal MathProgBase model, optionally warm-start it, optimize,
and read the solution back into `m`.

`warmstart`, when not `nothing`, must be a tuple `(primal, dual, slack)`. It is
passed to `MathProgBase.SolverInterface.setwarmstart!` as the primal solution
with `dual_sol`/`slack` keywords.

Returns `(status, (λ, P, warmstart))`:
- `status` is `MathProgBase.status` of the internal model;
- `λ` and `P` are `JuMP.getvalue` of `varλ` and `varP`, after
  `fillfrominternal!` has copied the internal solution into `m`;
- `warmstart` is `(primal_sol, dual_sol, slack)` read from the internal model,
  suitable for a later call.
"""
function solve(m::JuMP.Model, varλ::JuMP.Variable, varP, warmstart=nothing)
    traits = JuMP.ProblemTraits(m, relaxation=false)
    JuMP.build(m, traits=traits)

    if warmstart !== nothing
        p_sol, d_sol, s = warmstart
        MathProgBase.SolverInterface.setwarmstart!(m.internalModel, p_sol;
            dual_sol=d_sol, slack=s)
    end

    MathProgBase.optimize!(m.internalModel)
    status = MathProgBase.status(m.internalModel)

    # NOTE(review): these fields are not part of the generic MathProgBase API;
    # presumably they are those of the SCS internal model — confirm before
    # using another solver.
    new_warmstart = (m.internalModel.primal_sol, m.internalModel.dual_sol,
        m.internalModel.slack)

    fillfrominternal!(m, traits)
    P = JuMP.getvalue(varP)
    λ = JuMP.getvalue(varλ)

    return status, (λ, P, new_warmstart)
end
2018-01-01 23:59:31 +01:00
"""
    solve(solverlog::String, model::JuMP.Model, varλ::JuMP.Variable, varP, warmstart=nothing)

Run `PropertyT.solve(model, varλ, varP, warmstart)` with `stdout` redirected,
appending to the file `solverlog`. The file's directory is created if it is
missing.

Returns the same `(status, (λ, P, warmstart))` as the four-argument `solve`.
"""
function solve(solverlog::String, model::JuMP.Model, varλ::JuMP.Variable, varP, warmstart=nothing)
    logdir = dirname(solverlog)
    isdir(logdir) || mkpath(logdir)

    # flush Julia's buffered stdout first so pending output is not
    # redirected into the solver log
    Base.flush(Base.stdout)

    # The closures only read `warmstart` and return their result. They do not
    # reassign outer locals, which would force them to be boxed.
    status, (λ, P, new_warmstart) = open(solverlog, "a+") do logfile
        redirect_stdout(logfile) do
            result = PropertyT.solve(model, varλ, varP, warmstart)
            # flush C-level stdio buffers while stdout still points at the log
            # (presumably the solver writes through C stdio)
            Base.Libc.flush_cstdio()
            result
        end
    end
    return status, (λ, P, new_warmstart)
end
###############################################################################
#
# Copied from JuMP/src/solvers.jl:178
#
###############################################################################
2018-01-01 23:59:31 +01:00
"""
    fillfrominternal!(m::JuMP.Model, traits)

Copy the results of an already-run internal MathProgBase solve back into the
JuMP model `m`, and return the solver status symbol. Adapted from JuMP's
`src/solvers.jl`. NOTE(review): presumably copied so that `solve` can drive
`m.internalModel` directly (e.g. for warm starts) and still populate `m`.

Fields written on `m`:
- `objBound`, `objVal` and `colVal` start as `NaN`, and `linconstrDuals` starts
  empty.
- When the status is `:Optimal`:
  - discrete problems get `NaN` reduced costs and constraint duals;
  - non-conic problems get them from the solver, falling back to `NaN`;
  - conic problems without quadratic parts get them from
    `JuMP.fillConicDuals`.
- For linear problems:
  - an `:Infeasible` status stores the infeasibility ray in `linconstrDuals`;
  - an `:Unbounded` status stores the unbounded ray in `colVal`.
  Each falls back to `NaN` with a warning.
- Conic infeasible problems get their conic duals via `JuMP.fillConicDuals`.
- Unless the status is `:Infeasible` or `:Unbounded`, the objective bound,
  objective value and column values are read, each with a warning on failure.
  Off-diagonal SDP entries are divided by `sqrt(2)`.
- For conic problems with `:Max` sense, `objBound` and `objVal` are negated
  about the objective constant. NOTE(review): this presumably undoes a sign
  flip from passing a maximisation to the conic solver as a minimisation.
"""
function fillfrominternal!(m::JuMP.Model, traits)
    stat::Symbol = MathProgBase.status(m.internalModel)

    numRows, numCols = length(m.linconstr), m.numCols
    m.objBound = NaN
    m.objVal = NaN
    m.colVal = fill(NaN, numCols)
    m.linconstrDuals = Float64[]

    discrete = (traits.int || traits.sos)

    if stat == :Optimal
        # If we think dual information might be available, try to get it
        # If not, return an array of the correct length
        if discrete
            m.redCosts = fill(NaN, numCols)
            m.linconstrDuals = fill(NaN, numRows)
        else
            if !traits.conic
                m.redCosts = try
                    MathProgBase.getreducedcosts(m.internalModel)[1:numCols]
                catch
                    fill(NaN, numCols)
                end

                m.linconstrDuals = try
                    MathProgBase.getconstrduals(m.internalModel)[1:numRows]
                catch
                    fill(NaN, numRows)
                end
            elseif !traits.qp && !traits.qc
                JuMP.fillConicDuals(m)
            end
        end
    else
        # Problem was not solved to optimality, attempt to extract useful
        # information anyway
        if traits.lin
            if stat == :Infeasible
                m.linconstrDuals = try
                    infray = MathProgBase.getinfeasibilityray(m.internalModel)
                    @assert length(infray) == numRows
                    infray
                catch
                    @warn("Infeasibility ray (Farkas proof) not available")
                    fill(NaN, numRows)
                end
            elseif stat == :Unbounded
                m.colVal = try
                    unbdray = MathProgBase.getunboundedray(m.internalModel)
                    @assert length(unbdray) == numCols
                    unbdray
                catch
                    @warn("Unbounded ray not available")
                    fill(NaN, numCols)
                end
            end
        end
        # conic duals (currently, SOC and SDP only)
        if !discrete && traits.conic && !traits.qp && !traits.qc
            if stat == :Infeasible
                JuMP.fillConicDuals(m)
            end
        end
    end

    # If the problem was solved, or if it terminated prematurely, try
    # to extract a solution anyway. This commonly occurs when a time
    # limit or tolerance is set (:UserLimit)
    if !(stat == :Infeasible || stat == :Unbounded)
        try
            # Do a separate try since getobjval could work while getobjbound does not and vice versa
            objBound = MathProgBase.getobjbound(m.internalModel) + m.obj.aff.constant
            m.objBound = objBound
        catch
            @warn("objBound could not be obtained")
        end

        try
            objVal = MathProgBase.getobjval(m.internalModel) + m.obj.aff.constant
            colVal = MathProgBase.getsolution(m.internalModel)[1:numCols]
            # Rescale off-diagonal terms of SDP variables
            if traits.sdp
                offdiagvars = JuMP.offdiagsdpvars(m)
                colVal[offdiagvars] /= sqrt(2)
            end
            # Don't corrupt the answers if one of the above two calls fails
            m.objVal = objVal
            m.colVal = colVal
        catch
            @warn("objVal/colVal could not be obtained")
        end
    end

    if traits.conic && m.objSense == :Max
        m.objBound = -1 * (m.objBound - m.obj.aff.constant) + m.obj.aff.constant
        m.objVal = -1 * (m.objVal - m.obj.aff.constant) + m.obj.aff.constant
    end

    return stat
end