From f3d813cb6731b5656907f755f775f2b4eccdefc8 Mon Sep 17 00:00:00 2001
From: kalmarek
Date: Sun, 27 Aug 2017 18:37:09 +0200
Subject: [PATCH] reworked reconstruct_sol with axpy!

---
Review notes (free text between the "---" separator and the diffstat is
not part of the commit message and is not part of the diff):

* The hunk replaces the body of reconstruct_sol. The removed lines read
  `mreps`, `Us`, `Ps` and `dims`, none of which are parameters of the
  visible signature (preps, aUs, aPs, adims). The added lines derive
  Us/Ps/dims from the a*-prefixed arguments and read `preps` instead.
* `idx` keeps only the positions π for which size(aUs[π], 2) != 0, so
  those entries of aUs/aPs/adims are dropped before any product is formed.
* Each product dims[π].*Us[π]*Ps[π]*Us[π]' is now computed once and stored
  in `transfP`. For every key p of `preps`, the view of transfP[π] indexed
  by (preps[p].d, preps[p].d) is accumulated into tmp[π] with
  BLAS.axpy!(1.0, ...). This replaces the removed per-element products
  mreps[g]*...*mreps[invg].
* recP is sum(tmp) scaled by 1/length(perms); entries whose absolute value
  is below eps(eltype(recP)) are then set to zero(eltype(recP)).
* NOTE(review): first(transfP) presumably throws when `idx` is empty
  (every aUs[π] has zero columns) -- confirm callers never pass that.
* NOTE(review): this patch was recovered from a copy whose line breaks had
  been collapsed. Line breaks and blank lines below were reconstructed so
  that the hunk matches its header (-166,17 +166,24; 15 insertions,
  8 deletions). The position of the blank context line after the loop's
  closing `end`s, and the split of the signature across two lines, are
  assumptions -- verify against blob 22031c9 before applying.

 src/OrbitDecomposition.jl | 23 +++++++++++++++--------
 1 file changed, 15 insertions(+), 8 deletions(-)

diff --git a/src/OrbitDecomposition.jl b/src/OrbitDecomposition.jl
index 22031c9..1c1b1f0 100644
--- a/src/OrbitDecomposition.jl
+++ b/src/OrbitDecomposition.jl
@@ -166,17 +166,24 @@ end
 function reconstruct_sol{T<:GroupElem, S<:Nemo.perm}(preps::Dict{T, S},
     aUs::Vector, aPs::Vector, adims::Vector)
 
-    s = size(first(mreps).second)
-    recP = zeros(Float64, s)
-    tmp = [zeros(Float64, s) for _ in 1:length(Us)]
-    ks = [(g, inv(g)) for g in keys(mreps)]
-    Threads.@threads for π in 1:length(Us)
-        for (g, invg) in ks
-            tmp[π] += dims[π]*mreps[g]*Us[π]*Ps[π]*Us[π]'*mreps[invg]
+    idx = [π for π in 1:length(aUs) if size(aUs[π], 2) != 0]
+    Us = aUs[idx]
+    Ps = aPs[idx]
+    dims = adims[idx];
+
+    l = length(Us)
+    transfP = [dims[π].*Us[π]*Ps[π]*Us[π]' for π in 1:l]
+    tmp = [zeros(Float64, size(first(transfP))) for _ in 1:l]
+    perms = collect(keys(preps))
+
+    @inbounds Threads.@threads for π in 1:l
+        for p in perms
+            BLAS.axpy!(1.0, view(transfP[π], preps[p].d, preps[p].d), tmp[π])
         end
     end
 
-    recP += 1/length(keys(mreps)) .* sum(tmp)
+
+    recP = 1/length(perms) .* sum(tmp)
     recP[abs.(recP) .< eps(eltype(recP))] = zero(eltype(recP))
     return recP
 end