gen_batches#

sklearn.utils.gen_batches(n, batch_size, *, min_batch_size=0)[source]#

生成器,用于创建包含 batch_size 个元素的切片,范围从 0 到 n

batch_size 不能整除 n 时,最后一个切片可能包含少于 batch_size 个元素。

参数:
nint

序列的大小。

batch_sizeint

每个批次中的元素数量。

min_batch_sizeint, 默认为 0

每个批次中的最小元素数量。

返回:
包含 batch_size 个元素的切片

另请参阅

gen_even_slices

生成器,用于创建最多包含 n 个元素的 n_packs 切片。

示例

>>> from sklearn.utils import gen_batches
>>> list(gen_batches(7, 3))
[slice(0, 3, None), slice(3, 6, None), slice(6, 7, None)]
>>> list(gen_batches(6, 3))
[slice(0, 3, None), slice(3, 6, None)]
>>> list(gen_batches(2, 3))
[slice(0, 2, None)]
>>> list(gen_batches(7, 3, min_batch_size=0))
[slice(0, 3, None), slice(3, 6, None), slice(6, 7, None)]
>>> list(gen_batches(7, 3, min_batch_size=2))
[slice(0, 3, None), slice(3, 7, None)]