3RNN/Lib/site-packages/sklearn/utils/_chunking.py
2024-05-26 19:49:15 +02:00

176 lines
5.2 KiB
Python

import warnings
from itertools import islice
from numbers import Integral
import numpy as np
from .._config import get_config
from ._param_validation import Interval, validate_params
def chunk_generator(gen, chunksize):
    """Chunk generator, ``gen`` into lists of length ``chunksize``.

    The last chunk may have a length less than ``chunksize``.

    Parameters
    ----------
    gen : iterable
        Source of items. Any iterable is accepted; it is turned into an
        iterator once, so re-iterable containers such as lists are consumed
        progressively rather than restarted for every chunk.
    chunksize : int
        Maximum number of items per chunk.

    Yields
    ------
    list
        Consecutive items taken from ``gen``. Every list holds ``chunksize``
        items except possibly the last one.
    """
    # islice() on a re-iterable object (e.g. a list) starts from its first
    # element each time, which would yield the first chunk forever. iter()
    # returns an existing iterator unchanged, so generators behave as before.
    iterator = iter(gen)
    while True:
        chunk = list(islice(iterator, chunksize))
        if not chunk:
            return
        yield chunk
@validate_params(
    {
        "n": [Interval(Integral, 1, None, closed="left")],
        "batch_size": [Interval(Integral, 1, None, closed="left")],
        "min_batch_size": [Interval(Integral, 0, None, closed="left")],
    },
    prefer_skip_nested_validation=True,
)
def gen_batches(n, batch_size, *, min_batch_size=0):
    """Generator to create slices containing `batch_size` elements from 0 to `n`.

    The last slice may contain less than `batch_size` elements, when
    `batch_size` does not divide `n`.

    Parameters
    ----------
    n : int
        Size of the sequence.
    batch_size : int
        Number of elements in each batch.
    min_batch_size : int, default=0
        Minimum number of elements in the last batch. A full batch of
        `batch_size` elements is only emitted when at least `min_batch_size`
        elements would remain after it; otherwise all remaining elements are
        folded into one final slice. Consequently the last slice holds at
        least `min_batch_size` elements unless `n < min_batch_size`, in which
        case the single slice ``slice(0, n)`` is yielded.

    Yields
    ------
    slice
        Contiguous ``slice(start, stop)`` objects (step ``None``) that cover
        ``range(n)`` in order. All but the last contain exactly `batch_size`
        elements.

    See Also
    --------
    gen_even_slices: Generator to create n_packs slices going up to n.

    Examples
    --------
    >>> from sklearn.utils import gen_batches
    >>> list(gen_batches(7, 3))
    [slice(0, 3, None), slice(3, 6, None), slice(6, 7, None)]
    >>> list(gen_batches(6, 3))
    [slice(0, 3, None), slice(3, 6, None)]
    >>> list(gen_batches(2, 3))
    [slice(0, 2, None)]
    >>> list(gen_batches(7, 3, min_batch_size=0))
    [slice(0, 3, None), slice(3, 6, None), slice(6, 7, None)]
    >>> list(gen_batches(7, 3, min_batch_size=2))
    [slice(0, 3, None), slice(3, 7, None)]
    """
    start = 0
    for _ in range(int(n // batch_size)):
        end = start + batch_size
        if end + min_batch_size > n:
            # `start` does not advance when no batch is emitted, so this
            # check would fail on every remaining iteration as well.
            break
        yield slice(start, end)
        start = end
    # Whatever is left (a short tail, or a tail merged to honour
    # `min_batch_size`) goes into one final slice.
    if start < n:
        yield slice(start, n)
@validate_params(
    {
        "n": [Interval(Integral, 1, None, closed="left")],
        "n_packs": [Interval(Integral, 1, None, closed="left")],
        "n_samples": [Interval(Integral, 1, None, closed="left"), None],
    },
    prefer_skip_nested_validation=True,
)
def gen_even_slices(n, n_packs, *, n_samples=None):
    """Generator to create `n_packs` evenly spaced slices going up to `n`.

    The sequence ``range(n)`` is split into consecutive slices whose sizes
    differ by at most one. When `n_packs` does not divide `n`, the first
    ``n % n_packs`` slices each hold one more element than the others. Packs
    that would be empty (only possible when ``n_packs > n``) are not yielded,
    so at most `n_packs` slices are produced.

    Parameters
    ----------
    n : int
        Size of the sequence.
    n_packs : int
        Number of slices to generate.
    n_samples : int, default=None
        Number of samples. Pass `n_samples` when the slices are to be used for
        sparse matrix indexing; slicing off-the-end raises an exception, while
        it works for NumPy arrays.

    Yields
    ------
    slice
        ``slice(start, stop, None)`` objects representing a set of indices
        from 0 to n. When `n_samples` is given, every ``stop`` is clipped to
        `n_samples`; slices that begin at or after it come out empty.

    See Also
    --------
    gen_batches: Generator to create slices containing batch_size elements
        from 0 to n.

    Examples
    --------
    >>> from sklearn.utils import gen_even_slices
    >>> list(gen_even_slices(10, 1))
    [slice(0, 10, None)]
    >>> list(gen_even_slices(10, 10))
    [slice(0, 1, None), slice(1, 2, None), ..., slice(9, 10, None)]
    >>> list(gen_even_slices(10, 5))
    [slice(0, 2, None), slice(2, 4, None), ..., slice(8, 10, None)]
    >>> list(gen_even_slices(10, 3))
    [slice(0, 4, None), slice(4, 7, None), slice(7, 10, None)]
    """
    # Every pack gets `base_size` elements; the first `n_larger` packs get
    # one extra so that the sizes add up to `n`.
    base_size, n_larger = divmod(n, n_packs)
    lower = 0
    for pack_idx in range(n_packs):
        size = base_size + 1 if pack_idx < n_larger else base_size
        if size <= 0:
            # Empty pack: nothing to yield and the cursor does not move.
            continue
        upper = lower + size
        if n_samples is not None:
            upper = min(n_samples, upper)
        yield slice(lower, upper, None)
        lower = upper
def get_chunk_n_rows(row_bytes, *, max_n_rows=None, working_memory=None):
    """Calculate how many rows can be processed within `working_memory`.

    Parameters
    ----------
    row_bytes : int
        The expected number of bytes of memory that will be consumed
        during the processing of each row.
    max_n_rows : int, default=None
        The maximum return value. The result is nevertheless never smaller
        than 1.
    working_memory : int or float, default=None
        The number of rows to fit inside this number of MiB will be
        returned. When None (default), the value of
        ``sklearn.get_config()['working_memory']`` is used.

    Returns
    -------
    int
        The number of rows which can be processed within `working_memory`,
        capped by `max_n_rows` and always at least 1.

    Warns
    -----
    Issues a UserWarning if `row_bytes` exceeds `working_memory` MiB.
    """
    if working_memory is None:
        working_memory = get_config()["working_memory"]

    chunk_n_rows = int(working_memory * (2**20) // row_bytes)
    if chunk_n_rows < 1:
        # Not even one row fits in the memory budget: warn, then fall back to
        # processing a single row at a time.
        warnings.warn(
            "Could not adhere to working_memory config. "
            "Currently %.0fMiB, %.0fMiB required."
            % (working_memory, np.ceil(row_bytes * 2**-20))
        )
        chunk_n_rows = 1
    if max_n_rows is not None:
        # The cap is applied after the memory check so that a small cap
        # (e.g. 0 for an empty input) does not trigger the memory warning.
        chunk_n_rows = min(chunk_n_rows, max_n_rows)
    # A chunk always contains at least one row, whatever the cap.
    return max(chunk_n_rows, 1)