From 8331159baa6b4d11a19aaadb0bc780735d466253 Mon Sep 17 00:00:00 2001 From: kalmarek Date: Thu, 22 Nov 2018 20:01:33 +0100 Subject: [PATCH] add separate average over perms to perm_avg --- src/sos_sdps.jl | 38 +++++++++++++++++++++++--------------- 1 file changed, 23 insertions(+), 15 deletions(-) diff --git a/src/sos_sdps.jl b/src/sos_sdps.jl index 71e06b3..86fa147 100644 --- a/src/sos_sdps.jl +++ b/src/sos_sdps.jl @@ -132,26 +132,34 @@ function reconstruct(Ps::Vector{M}, preps::Dict{GEl, P}, Uπs::Vector{U}, dims::Vector{Int}) where {M<:AbstractMatrix, GEl<:GroupElem, P<:perm, U<:AbstractMatrix} - l = length(Uπs) - transfP = [dims[π].*Uπs[π]*Ps[π]*Uπs[π]' for π in 1:l] - tmp = [zeros(Float64, size(first(transfP))) for _ in 1:l] - perms = collect(keys(preps)) - - Threads.@threads for π in 1:l - for p in perms - BLAS.axpy!(1.0, view(transfP[π], preps[p].d, preps[p].d), tmp[π]) - end + lU = length(Uπs) + transfP = [dims[π].*Uπs[π]*Ps[π]*Uπs[π]' for π in 1:lU] + tmp = [zeros(Float64, size(first(transfP))) for _ in 1:lU] + + + @time Threads.@threads for π in 1:lU + tmp[π] = perm_avg(tmp[π], transfP[π], values(preps)) end - recP = 1/length(perms) .* sum(tmp) - # for i in eachindex(recP) - # if abs(recP[i]) .< eps(eltype(recP))*100 - # recP[i] = zero(eltype(recP)) - # end - # end + @time recP = sum(tmp)./length(preps) + return recP end +function perm_avg(result, P, perms) + lp = length(first(perms).d) + for p in perms + # result .+= view(P, p.d, p.d) + @inbounds for j in 1:lp + k = p[j] + for i in 1:lp + result[i,j] += P[p[i], k] + end + end + end + return result +end + ############################################################################### # # Low-level solve