开发者实用工具#

Scikit-learn 包含许多有助于开发的实用工具。这些工具位于 sklearn.utils 中,并涵盖多个类别。以下所有函数和类都位于 sklearn.utils 模块中。

警告

这些实用工具旨在 scikit-learn 包内部使用。它们不保证在 scikit-learn 版本之间保持稳定。特别是,随着 scikit-learn 依赖项的发展,反向移植(Backports)将被移除。

验证工具#

这些工具用于检查和验证输入。当您编写接受数组、矩阵或稀疏矩阵作为参数的函数时,应在适用时使用以下工具。

  • assert_all_finite: 如果数组包含 NaNs 或 Infs,则抛出错误。

  • as_float_array: 将输入转换为浮点数组。如果传入稀疏矩阵,将返回稀疏矩阵。

  • check_array: 检查输入是否为二维数组,对稀疏矩阵引发错误。可以可选地指定允许的稀疏矩阵格式,以及允许一维或N维数组。默认情况下调用 assert_all_finite

  • check_X_y: 检查 X 和 y 的长度是否一致,对 X 调用 check_array,对 y 调用 column_or_1d。对于多标签分类或多目标回归,请指定 multi_output=True,在这种情况下将对 y 调用 check_array。

  • indexable: 检查所有输入数组是否长度一致,并且可以使用 safe_index 进行切片或索引。这用于验证交叉验证的输入。

  • validation.check_memory 检查输入是否类似于 joblib.Memory,这意味着它可以转换为 sklearn.utils.Memory 实例(通常是表示 cachedir 的字符串)或具有相同的接口。

如果您的代码依赖于随机数生成器,则不应使用诸如 numpy.random.randomnumpy.random.normal 等函数。这种方法可能导致单元测试中的可重复性问题。相反,应使用一个 numpy.random.RandomState 对象,该对象是根据传递给类或函数的 random_state 参数构建的。然后,可以使用下面的函数 check_random_state 来创建随机数生成器对象。

  • check_random_state: 根据参数 random_state 创建一个 np.random.RandomState 对象。

    • 如果 random_stateNonenp.random,则返回一个随机初始化的 RandomState 对象。

    • 如果 random_state 是一个整数,则将其用作新的 RandomState 对象的种子。

    • 如果 random_state 是一个 RandomState 对象,则直接传递该对象。

例如

>>> from sklearn.utils import check_random_state
>>> random_state = 0
>>> random_state = check_random_state(random_state)
>>> random_state.rand(4)
array([0.5488135 , 0.71518937, 0.60276338, 0.54488318])

在开发您自己的 scikit-learn 兼容估算器时,可以使用以下辅助函数。

  • validation.check_is_fitted: 在调用 transformpredict 或类似方法之前,检查估算器是否已拟合。此辅助函数允许在所有估算器中引发标准化的错误消息。

  • validation.has_fit_parameter: 检查给定参数是否受给定估算器的 fit 方法支持。

高效线性代数与数组操作#

  • extmath.randomized_range_finder: 构建一个正交矩阵,其范围近似于输入的范围。这在下面的 extmath.randomized_svd 中使用。

  • extmath.randomized_svd: 计算 k 截断的随机 SVD。该算法通过随机化来加速计算,从而找到精确的截断奇异值分解。在大型矩阵上,如果您只想提取少量分量,该算法特别快。

  • arrayfuncs.cholesky_delete: (在 lars_path 中使用)从乔利斯基分解中移除一个项。

  • arrayfuncs.min_pos: (在 sklearn.linear_model.least_angle 中使用)查找数组中正值的最小值。

  • extmath.fast_logdet: 高效计算矩阵行列式的对数。

  • extmath.density: 高效计算稀疏向量的密度。

  • extmath.safe_sparse_dot: 点积函数,能正确处理 scipy.sparse 输入。如果输入是稠密的,它等同于 numpy.dot

  • extmath.weighted_mode: scipy.stats.mode 的扩展,允许每个项具有实值权重。

  • resample: 以一致的方式对数组或稀疏矩阵进行重采样。在下面的 shuffle 中使用。

  • shuffle: 以一致的方式打乱数组或稀疏矩阵。在 k_means 中使用。

高效随机抽样#

稀疏矩阵的高效例程#

cython 模块 sklearn.utils.sparsefuncs 包含用于高效处理 scipy.sparse 数据的编译扩展。

图例程#

  • graph.single_source_shortest_path_length: (目前未在 scikit-learn 中使用)返回从单个源到图上所有连接节点的最短路径。代码改编自 networkx。如果未来需要再次使用,使用 graph_shortest_path 中的一次 Dijkstra 算法迭代将快得多。

测试函数#

  • discovery.all_estimators : 返回 scikit-learn 中所有估算器的列表,用于测试其行为和接口的一致性。

  • discovery.all_displays : 返回 scikit-learn 中所有显示器(与绘图 API 相关)的列表,用于测试其行为和接口的一致性。

  • discovery.all_functions : 返回 scikit-learn 中所有函数的列表,用于测试其行为和接口的一致性。

多类和多标签实用函数#

辅助函数#

  • gen_even_slices: 用于创建最多 n 个切片的 n 组切片的生成器。在 dict_learningk_means 中使用。

  • gen_batches: 用于创建包含从 0 到 n 的批次大小元素的切片的生成器。

  • safe_mask: 辅助函数,用于将掩码转换为 numpy 数组或 scipy 稀疏矩阵所需的格式(稀疏矩阵仅支持整数索引,而 numpy 数组支持布尔掩码和整数索引)。

  • safe_sqr: 用于统一对类数组、矩阵和稀疏矩阵进行平方 (**2) 的辅助函数。

哈希函数#

  • murmurhash3_32 提供了 MurmurHash3_x86_32 C++ 非加密哈希函数的 Python 封装。该哈希函数适用于实现查找表、布隆过滤器、Count Min Sketch、特征哈希和隐式定义的稀疏随机投影。

    >>> from sklearn.utils import murmurhash3_32
    >>> murmurhash3_32("some feature", seed=0) == -384616559
    True
    
    >>> murmurhash3_32("some feature", seed=0, positive=True) == 3910350737
    True
    

    模块 sklearn.utils.murmurhash 也可以从其他 cython 模块“cimport”导入,以便利用 MurmurHash 的高性能,同时避免 Python 解释器的开销。

警告和异常#

  • deprecated: 用于将函数或类标记为已弃用的装饰器。

  • ConvergenceWarning: 用于捕获收敛问题的自定义警告。在 sklearn.covariance.graphical_lasso 中使用。