diff --git a/src/mtables.jl b/src/mtables.jl index ad72c23..dbc362b 100644 --- a/src/mtables.jl +++ b/src/mtables.jl @@ -71,3 +71,83 @@ Base.@propagate_inbounds function Base.getindex(m::MTable, i::Integer, j::Intege return m.table[i, j] end +## CachedMTables + +struct CachedMTable{T,I,B<:Basis{T,I},M,Twisted} <: AbstractMTable{I,Twisted} + basis::B + table::M +end + +function CachedMTable{Tw}( + basis::AbstractBasis{T,I}; + table_size = (l = length(basis); (l, l)), +) where {Tw,T,I} + return CachedMTable{Tw}(basis, zeros(I, table_size)) +end + +function CachedMTable{Tw}(basis::AbstractBasis{T,I}, mt::AbstractMatrix{I}) where {Tw,T,I} + return CachedMTable{T,I,typeof(basis),typeof(mt),Tw}(basis, mt) +end + +function CachedMTable{Tw}( + basis::AbstractVector; + table_size = (l = length(basis); (l, l)), +) where {Tw} + b = Basis{UInt32}(basis) + cmt = zeros(UInt32, table_size) + return CachedMTable{Tw}(b, cmt) +end + +basis(m::CachedMTable) = m.basis + +Base.@propagate_inbounds function Base.getindex(cmt::CachedMTable, i::Integer, j::Integer) + cache!(cmt, i, j) + return cmt.table[i, j] +end + +Base.@propagate_inbounds function cache!(cmt::CachedMTable, i::Integer, j::Integer) + @boundscheck checkbounds(cmt, i, j) + if !_iscached(cmt, i, j) + b = basis(cmt) + g, h = b[i], b[j] + @debug "Caching $i, $j" g h + gh = _product(cmt, g, h) + gh in b || throw(ProductNotDefined(i, j, "$g · $h = $gh")) + cmt.table[i, j] = b[gh] + end + return cmt +end + +Base.@propagate_inbounds function cache!( + cmt::CachedMTable{T,I,B,M,false}, + suppX::AbstractVector{<:Integer}, + suppY::AbstractVector{<:Integer}, +) where {T,I,B,M} + Threads.@threads for j in suppY + for i in suppX + if !_iscached(cmt, i, j) + cache!(cmt, i, j) + end + end + end + return cmt +end + +Base.@propagate_inbounds function cache!( + cmt::CachedMTable{T,I,B,M,true}, + suppX::AbstractVector{<:Integer}, + suppY::AbstractVector{<:Integer}, +) where {T,I,B,M} + b = basis(cmt) + Threads.@threads for i in suppX + g = star(b[i]) + for j in suppY + if !_iscached(cmt, i, j) + gh = _product(Val(false), g, b[j]) + gh in b || throw(ProductNotDefined(i, j, "$g · $h = $gh")) + cmt.table[i, j] = b[gh] + end + end + end + return cmt +end diff --git a/test/cachedmtables.jl b/test/cachedmtables.jl new file mode 100644 index 0000000..7a08398 --- /dev/null +++ b/test/cachedmtables.jl @@ -0,0 +1,32 @@ +@testset "CachedMTable" begin + elts = collect(SymmetricGroup(3)) + b = New.Basis{UInt8}(elts) + + mstr = New.CachedMTable{false}(b) + @test all(iszero, mstr.table) + New.cache!(mstr, 1, 2) + @test mstr.table[1, 2] == 2 + @test mstr.table[1, 1] == 0 + + @test mstr.table[2, 3] == 0 + @test mstr[2, 3] == 4 + @test mstr.table[2, 3] == 4 + + @test b[b[2]*b[3]] == 4 + + + @test mstr.table[1, 3] == 0 + @test mstr.table[1, 4] == 0 + New.cache!(mstr, [1], [3, 4]) + @test mstr.table[1, 3] != 0 + @test mstr.table[1, 4] != 0 + + + tmstr = New.CachedMTable{true}(b) + + @test all(iszero, tmstr.table) + @test tmstr[1, 2] == 2 + @test tmstr[2, 3] == 4 + @test tmstr[3, 2] == b[inv(b[3])*b[2]] +end + diff --git a/test/runtests.jl b/test/runtests.jl index 70b0b53..5e7a718 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,6 +4,14 @@ using AbstractAlgebra using GroupRings using SparseArrays +using GroupRings.New + +@testset "StarAlgebras" begin + + New.star(p::Generic.Perm) = inv(p) + + include("cachedmtables.jl") +end @testset "GroupRings" begin @testset "Constructors: SymmetricGroup" begin