diff --git a/src/constraint_matrix.jl b/src/constraint_matrix.jl index 1e95aa1..997a037 100644 --- a/src/constraint_matrix.jl +++ b/src/constraint_matrix.jl @@ -83,26 +83,29 @@ Base.@propagate_inbounds function Base.getindex( return pos - neg end -struct NZPairsIter{T} - m::ConstraintMatrix{T} +struct NZPairsIter{T,I} + m::ConstraintMatrix{T,I} end Base.eltype(::Type{NZPairsIter{T}}) where {T} = Pair{Int,T} Base.IteratorSize(::Type{<:NZPairsIter}) = Base.SizeUnknown() # TODO: iterate over (idx=>val) pairs combining vals -function Base.iterate(itr::NZPairsIter, state::Tuple{Int,Int} = (1, 1)) +function Base.iterate( + itr::NZPairsIter, + state::Tuple{Int,Nothing} = (1, nothing), +) k = iterate(itr.m.pos, state[1]) - isnothing(k) && return iterate(itr, state[2]) + isnothing(k) && return iterate(itr, (nothing, 1)) idx, st = k - return idx => itr.m.val, (st, 1) + return idx => itr.m.val, (st, nothing) end -function Base.iterate(itr::NZPairsIter, state::Int) - k = iterate(itr.m.neg, state[1]) +function Base.iterate(itr::NZPairsIter, state::Tuple{Nothing,Int}) + k = iterate(itr.m.neg, state[2]) isnothing(k) && return nothing idx, st = k - return idx => -itr.m.val, st + return idx => -itr.m.val, (nothing, st) end """