PropertyT.jl/src/sos_sdps.jl

227 lines
6.3 KiB
Julia
Raw Normal View History

2018-09-05 10:41:11 +02:00
###############################################################################
#
# Constraints
#
###############################################################################
2017-03-13 14:49:55 +01:00
2018-08-20 03:50:03 +02:00
"""
    constraints(pm::Matrix{I}, total_length=maximum(pm)) where {I<:Integer}

Group the linear indices of `pm` by the value stored there: entry `k` of the
result lists (in column-major order) every linear index `i` with `pm[i] == k`.
The result has `total_length` buckets, some of which may be empty.
"""
function constraints(pm::Matrix{I}, total_length=maximum(pm)) where {I<:Integer}
    buckets = [I[] for _ in 1:total_length]
    # enumerate over a Matrix yields (linear index, value) in column-major order
    for (linear_idx, k) in enumerate(pm)
        push!(buckets[k], linear_idx)
    end
    return buckets
end
2018-09-05 10:37:39 +02:00
"""
    orbit_constraint!(result::SparseMatrixCSC, cnstrs, orbit; val=1.0/length(orbit))

Overwrite `result` with the indicator of `orbit`. Every linear index listed in
`cnstrs[k]`, for every `k in orbit`, is set to `val`. All previously stored
entries are removed from the sparsity structure.

`orbit` may be any iterable of indices into `cnstrs` that supports `length`
(e.g. a `Vector{Int}` or a `Set{Int}`). Returns `result`.
"""
function orbit_constraint!(result::SparseMatrixCSC, cnstrs, orbit; val=1.0/length(orbit))
    # zero the stored values, then drop them from the sparsity structure
    result .= zero(eltype(result))
    dropzeros!(result)
    # iterate the orbit directly instead of materialising `cnstrs[orbit]`
    for k in orbit
        for idx in cnstrs[k]
            result[idx] = val
        end
    end
    return result
end
"""
    orbit_spvector(vect::AbstractVector, orbits)

Compress `vect` into a sparse `Float64` vector with one entry per orbit.
`vect` is expected to be constant on each orbit, and entry `i` of the result
is that common value over `orbits[i]`.

Throws an `ArgumentError` if an orbit is empty or if `vect` is not constant on it.
"""
function orbit_spvector(vect::AbstractVector, orbits)
    orb_vector = spzeros(length(orbits))
    for (i, o) in enumerate(orbits)
        isempty(o) && throw(ArgumentError("orbit $i is empty"))
        val = vect[first(o)]
        # validate without allocating a copy of vect[o]
        all(j -> vect[j] == val, o) ||
            throw(ArgumentError("vector is not constant on orbit $i"))
        orb_vector[i] = val
    end
    return orb_vector
end
2018-09-05 10:41:11 +02:00
###############################################################################
#
# Naive SDP
#
###############################################################################
2017-03-13 14:49:55 +01:00
"""
    SOS_problem(X::GroupRingElem, orderunit::GroupRingElem; upper_bound::Float64=Inf)

Build the (non-symmetrized) sum-of-squares SDP for `X` relative to `orderunit`.

The model contains:
- an `N×N` matrix variable `P` (registered as `:P`), where `N` is the number of
  rows of `parent(X).pm`, constrained to the PSD cone (constraint `:sdp`);
- the constraint that all entries of `P` sum to zero;
- a scalar variable `λ` (registered as `:λ`), bounded above by `upper_bound`
  when `upper_bound < Inf`;
- for each `i`, a constraint `lincnstr[i]` requiring
  `X.coeffs[i] - λ*orderunit.coeffs[i]` to equal the sum of the entries of `P`
  at the linear indices where `parent(X).pm` equals `i`.

The objective maximizes `λ`. Returns the `JuMP.Model`.
"""
function SOS_problem(X::GroupRingElem, orderunit::GroupRingElem; upper_bound::Float64=Inf)
    N = size(parent(X).pm, 1)
    m = JuMP.Model()

    JuMP.@variable(m, P[1:N, 1:N])
    JuMP.@constraint(m, sdp, P in PSDCone())
    JuMP.@constraint(m, sum(P[i] for i in eachindex(P)) == 0)

    λ = if upper_bound < Inf
        JuMP.@variable(m, λ <= upper_bound)
    else
        JuMP.@variable(m, λ)
    end

    cnstrs = constraints(parent(X).pm)
    x = X.coeffs
    u = orderunit.coeffs
    @assert length(cnstrs) == length(x) == length(u)
    JuMP.@constraint(m, lincnstr[i=1:length(cnstrs)],
        x[i] - λ*u[i] == sum(P[cnstrs[i]]))

    JuMP.@objective(m, Max, λ)
    return m
end
2018-09-05 09:18:38 +02:00
###############################################################################
#
# Symmetrized SDP
#
###############################################################################
"""
    SOS_problem(X::GroupRingElem, orderunit::GroupRingElem, data::OrbitData;
                upper_bound::Float64=Inf)

Build the orbit-symmetrized sum-of-squares SDP.

The model has one anonymous square matrix variable per entry of `data.Uπs`.
Block `k` has side `size(data.Uπs[k], 2)` and is constrained to the PSD cone.
It also has a scalar `λ` (registered as `:λ`), bounded above by `upper_bound`
when `upper_bound < Inf`. One linear constraint per orbit in `data.orbits` is
added by `addconstraints!`, and the objective maximizes `λ`.

Returns `(m, Ps)`, where `Ps[k]` is the `k`-th block variable.
"""
function SOS_problem(X::GroupRingElem, orderunit::GroupRingElem, data::OrbitData; upper_bound::Float64=Inf)
    blocksizes = size.(data.Uπs, 2)
    m = JuMP.Model()

    Ps = Vector{Matrix{JuMP.VariableRef}}(undef, length(blocksizes))
    for k in eachindex(blocksizes)
        s = blocksizes[k]
        block = JuMP.@variable(m, [1:s, 1:s])
        JuMP.@constraint(m, block in PSDCone())
        Ps[k] = block
    end

    λ = if upper_bound < Inf
        JuMP.@variable(m, λ <= upper_bound)
    else
        JuMP.@variable(m, λ)
    end

    @info "Adding $(length(data.orbits)) constraints..."
    @time addconstraints!(m, Ps, X, orderunit, data)

    JuMP.@objective(m, Max, λ)
    return m, Ps
end
2018-08-20 03:45:50 +02:00
2018-09-05 09:18:38 +02:00
"""
    constraintLHS!(M, cnstr, Us, Ust, dims, eps=1000*eps(eltype(first(M))))

For every block index `k` of `Us`, replace `M[k]` with
`dims[k] .* PropertyT.clamp_small!(Ust[k] * (cnstr * Us[k]), eps)`.
`Ust[k]` is expected to be the adjoint of `Us[k]`.
NOTE(review): `clamp_small!` presumably zeroes entries smaller than `eps` in
magnitude; confirm against its definition.

Entries of `M` are rebound, not filled in place. Returns `nothing`.
"""
function constraintLHS!(M, cnstr, Us, Ust, dims, eps=1000*eps(eltype(first(M))))
    for k in eachindex(Us)
        projected = Ust[k] * (cnstr * Us[k])
        M[k] = dims[k] .* PropertyT.clamp_small!(projected, eps)
    end
    return nothing
end
"""
    addconstraints!(m::JuMP.Model, P::Vector{Matrix{JuMP.VariableRef}},
                    X::GroupRingElem, orderunit::GroupRingElem, data::OrbitData)

Add to `m` one linear constraint per orbit `t` in `data.orbits`:

    x_t - λ*u_t == Σ_π dot(M_π, P[π])

Here `x_t` and `u_t` are the orbit-constant coefficients of `X` and `orderunit`
on orbit `t`. `M_π` is the orbit's indicator matrix projected through
`data.Uπs[π]` and scaled by `data.dims[π]` (computed by `constraintLHS!`).

`λ` is the variable registered in `m` under the name `:λ`. It must already
exist, e.g. created via `JuMP.@variable(m, λ ...)` as `SOS_problem` does.
Returns `m`.
"""
function addconstraints!(m::JuMP.Model,
    P::Vector{Matrix{JuMP.VariableRef}},
    X::GroupRingElem, orderunit::GroupRingElem, data::OrbitData)

    orderunit_orb = orbit_spvector(orderunit.coeffs, data.orbits)
    X_orb = orbit_spvector(X.coeffs, data.orbits)
    UπsT = [U' for U in data.Uπs]

    cnstrs = constraints(parent(X).pm)
    # reused sparse buffer, overwritten for each orbit by orbit_constraint!
    orb_cnstr = spzeros(Float64, size(parent(X).pm)...)
    # per-block buffers; constraintLHS! rebinds every entry on each call
    M = [Array{Float64}(undef, n, n) for n in size.(UπsT, 1)]

    # look up the objective variable by its registered name
    λ = m[:λ]

    for (t, orbit) in enumerate(data.orbits)
        orbit_constraint!(orb_cnstr, cnstrs, orbit)
        constraintLHS!(M, orb_cnstr, data.Uπs, UπsT, data.dims)

        x, u = X_orb[t], orderunit_orb[t]
        JuMP.@constraint(m,
            x - λ*u == sum(dot(M[π], P[π]) for π in eachindex(data.Uπs)))
    end
    return m
end
"""
    reconstruct(Ps::Vector{Matrix{F}}, data::OrbitData) where F

Convenience method: reconstruct from the blocks `Ps` using `data.preps`,
`data.Uπs` and `data.dims`.
"""
reconstruct(Ps::Vector{Matrix{F}}, data::OrbitData) where F =
    reconstruct(Ps, data.preps, data.Uπs, data.dims)
2018-09-05 10:41:11 +02:00
"""
    reconstruct(Ps, preps, Uπs, dims)

Map block solutions back to a single matrix and average it over permutations:

1. lift each block as `dims[k] .* Uπs[k] * Ps[k] * Uπs[k]'`;
2. accumulate each lifted block over every permutation in `values(preps)`
   with `perm_avg`, one block per thread;
3. sum the accumulated blocks and divide by `length(preps)`.
"""
function reconstruct(Ps::Vector{M}, preps::Dict{GEl, P}, Uπs::Vector{U},
        dims::Vector{Int}) where {M<:AbstractMatrix, GEl<:GroupElem, P<:perm, U<:AbstractMatrix}

    nblocks = length(Uπs)
    lifted = [dims[k].*Uπs[k]*Ps[k]*Uπs[k]' for k in 1:nblocks]
    acc = [zeros(Float64, size(first(lifted))) for _ in 1:nblocks]
    perms = values(preps)

    # iterations touch disjoint acc[k]; preps is only read
    Threads.@threads for k in 1:nblocks
        acc[k] = perm_avg(acc[k], lifted[k], perms)
    end

    return sum(acc)./length(preps)
end
"""
    perm_avg(result, P, perms)

Accumulate into `result` the conjugates of `P` by every permutation in `perms`.
For each `p` (with image vector `p.d` of length `n`), this adds `P[p.d, p.d]`
to `result[1:n, 1:n]`. Returns `result`, mutated in place. An empty `perms`
leaves `result` unchanged.

Throws `BoundsError` if `result` or `P` does not cover `1:n × 1:n`, and
`DimensionMismatch` if the permutations have different degrees.
NOTE(review): each `p.d` is assumed to be a permutation of `1:n`.
"""
function perm_avg(result, P, perms)
    isempty(perms) && return result
    lp = length(first(perms).d)
    # make the `@inbounds` below sound for both matrices
    checkbounds(result, 1:lp, 1:lp)
    checkbounds(P, 1:lp, 1:lp)
    for p in perms
        d = p.d
        length(d) == lp ||
            throw(DimensionMismatch("all permutations must have degree $lp, got $(length(d))"))
        # same as `result[1:lp, 1:lp] .+= view(P, d, d)`, without allocating
        @inbounds for j in 1:lp
            k = d[j]
            for i in 1:lp
                result[i, j] += P[d[i], k]
            end
        end
    end
    return result
end
###############################################################################
#
# Low-level solve
#
###############################################################################
2019-04-12 23:18:48 +02:00
"""
    setwarmstart_scs!(m::JuMP.Model, warmstart)

Write a warm start into the `sol` buffers of the optimizer attached to `m`.
`warmstart` is destructured as `(primal, dual, slack)`, for instance the
value returned by `getwarmstart_scs`. Returns `m`.

Throws an `ArgumentError` unless `solver_name(m) == "SCS"`.
"""
function setwarmstart_scs!(m::JuMP.Model, warmstart)
    solver_name(m) == "SCS" ||
        throw(ArgumentError("warmstarting defined only for SCS!"))
    primal, dual, slack = warmstart
    # NOTE(review): relies on the internal field layout of the JuMP backend and
    # the SCS optimizer wrapper; fragile across JuMP/SCS versions.
    sol = m.moi_backend.optimizer.model.optimizer.sol
    sol.primal = primal
    sol.dual = dual
    sol.slack = slack
    return m
end
2017-03-13 14:49:55 +01:00
"""
    getwarmstart_scs(m::JuMP.Model)

Return the optimizer's current `(primal=..., dual=..., slack=...)` vectors when
the attached solver is SCS. For any other solver, return empty `Float64` vectors.
"""
function getwarmstart_scs(m::JuMP.Model)
    if solver_name(m) != "SCS"
        return (primal=Float64[], dual=Float64[], slack=Float64[])
    end
    sol = m.moi_backend.optimizer.model.optimizer.sol
    return (primal=sol.primal, dual=sol.dual, slack=sol.slack)
end
2019-04-12 23:18:48 +02:00
"""
    solve(m::JuMP.Model, with_optimizer::JuMP.OptimizerFactory, warmstart=nothing)

Attach the optimizer described by `with_optimizer` to `m` and optimize. When
`warmstart` is given, it is first written into the solver via
`setwarmstart_scs!`, which supports SCS only.

Returns `(status, warmstart)`. `status` is `termination_status(m)`, and
`warmstart` is `getwarmstart_scs(m)` (empty vectors for non-SCS solvers).
"""
function solve(m::JuMP.Model, with_optimizer::JuMP.OptimizerFactory, warmstart=nothing)
    set_optimizer(m, with_optimizer)
    # attach now, before the warm start is written into the backend optimizer
    MOIU.attach_optimizer(m)

    if warmstart !== nothing
        setwarmstart_scs!(m, warmstart)
    end

    optimize!(m)
    status = termination_status(m)
    return status, getwarmstart_scs(m)
end
2018-01-01 23:59:31 +01:00
"""
    solve(solverlog::String, m::JuMP.Model, with_optimizer::JuMP.OptimizerFactory,
          warmstart=nothing)

Run `PropertyT.solve(m, with_optimizer, warmstart)` with `stdout` redirected
into `solverlog`. The file is opened in `"a+"` mode, and its directory is
created if missing. Returns `(status, warmstart)` from the inner solve.
"""
function solve(solverlog::String, m::JuMP.Model, with_optimizer::JuMP.OptimizerFactory, warmstart=nothing)
    logdir = dirname(solverlog)
    isdir(logdir) || mkpath(logdir)

    Base.flush(Base.stdout)
    st, ws = open(solverlog, "a+") do logfile
        redirect_stdout(logfile) do
            outcome = PropertyT.solve(m, with_optimizer, warmstart)
            # flush C-level buffers (solver output) before stdout is restored
            Base.Libc.flush_cstdio()
            outcome
        end
    end
    return st, ws
end