12.1. Array API 支持(实验性)#
Array API 规范定义了所有具有 NumPy 风格 API 的数组操作库的标准 API。Scikit-learn 提供了 array-api-compat 和 array-api-extra 的固定版本。
Scikit-learn 对 Array API 标准的支持要求在导入 scipy
和 scikit-learn
之前,将环境变量 SCIPY_ARRAY_API
设置为 1
。
export SCIPY_ARRAY_API=1
请注意,此环境变量仅供临时使用。有关更多详细信息,请参阅 SciPy 的 Array API 文档。
某些 scikit-learn 估计器主要依赖 NumPy(而不是 Cython)来实现其 fit
、predict
或 transform
方法的算法逻辑,可以配置为接受任何 Array API 兼容的输入数据结构,并自动将操作调度到底层命名空间,而不是依赖 NumPy。
在此阶段,此支持被**视为实验性**功能,必须按以下说明明确启用。
注意
目前,已知只有 array-api-strict
、cupy
和 PyTorch
可与 scikit-learn 的估计器协同工作。
以下视频概述了该标准的设计原则以及它如何促进数组库之间的互操作性
使用 Array API 在 GPU 上运行 Scikit-learn,由 Thomas Fan 在 PyData NYC 2023 上发布。
12.1.1. 示例用法#
这是一个代码片段示例,演示如何使用 CuPy 在 GPU 上运行 LinearDiscriminantAnalysis
>>> from sklearn.datasets import make_classification
>>> from sklearn import config_context
>>> from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
>>> import cupy
>>> X_np, y_np = make_classification(random_state=0)
>>> X_cu = cupy.asarray(X_np)
>>> y_cu = cupy.asarray(y_np)
>>> X_cu.device
<CUDA Device 0>
>>> with config_context(array_api_dispatch=True):
... lda = LinearDiscriminantAnalysis()
... X_trans = lda.fit_transform(X_cu, y_cu)
>>> X_trans.device
<CUDA Device 0>
模型训练完成后,作为数组的拟合属性也将与训练数据来自相同的 Array API 命名空间。例如,如果使用 CuPy 的 Array API 命名空间进行训练,则拟合属性将位于 GPU 上。我们提供了一个实验性的 _estimator_with_converted_arrays
工具,用于将估计器属性从 Array API 转移到 ndarray。
>>> from sklearn.utils._array_api import _estimator_with_converted_arrays
>>> cupy_to_ndarray = lambda array : array.get()
>>> lda_np = _estimator_with_converted_arrays(lda, cupy_to_ndarray)
>>> X_trans = lda_np.transform(X_np)
>>> type(X_trans)
<class 'numpy.ndarray'>
12.1.1.1. PyTorch 支持#
通过设置 array_api_dispatch=True
并直接传入张量,可以支持 PyTorch 张量。
>>> import torch
>>> X_torch = torch.asarray(X_np, device="cuda", dtype=torch.float32)
>>> y_torch = torch.asarray(y_np, device="cuda", dtype=torch.float32)
>>> with config_context(array_api_dispatch=True):
... lda = LinearDiscriminantAnalysis()
... X_trans = lda.fit_transform(X_torch, y_torch)
>>> type(X_trans)
<class 'torch.Tensor'>
>>> X_trans.device.type
'cuda'
12.1.2. 支持 Array API
兼容输入#
scikit-learn 中支持 Array API 兼容输入的估计器和其他工具。
12.1.2.1. 估计器#
decomposition.PCA
(当svd_solver="full"
,svd_solver="randomized"
和power_iteration_normalizer="QR"
时)linear_model.Ridge
(当solver="svd"
时)discriminant_analysis.LinearDiscriminantAnalysis
(当solver="svd"
时)
12.1.2.2. 元估计器#
接受 Array API 输入的元估计器,前提是基础估计器也接受此类输入
12.1.2.3. 指标#
sklearn.metrics.cluster.entropy
sklearn.metrics.mean_poisson_deviance
(需要为 SciPy 启用 Array API 支持)sklearn.metrics.pairwise.euclidean_distances
(参见关于 float64 设备支持的注意事项)sklearn.metrics.pairwise.rbf_kernel
(参见关于 float64 设备支持的注意事项)
12.1.2.4. 工具#
覆盖范围预计会随着时间推移而增长。请关注 GitHub 上的专用元议题以跟踪进度。
12.1.2.5. 返回值和拟合属性的类型#
当使用 Array API 兼容的输入调用函数或方法时,约定是返回与输入数据具有相同数组容器类型和设备的数组值。
类似地,当估计器使用 Array API 兼容输入进行拟合时,拟合属性将是来自与输入相同库的数组,并存储在相同的设备上。随后,predict
和 transform
方法期望的输入来自与传递给 fit
方法的数据相同的数组库和设备。
但请注意,返回标量值的评分函数会返回 Python 标量(通常是 float
实例),而不是数组标量值。
12.1.3. 通用估计器检查#
将 array_api_support
标签添加到估计器的标签集中,以表明它支持 Array API。这将启用通用测试中的专用检查,以验证当使用普通 NumPy 和 Array API 输入时,估计器的结果是否相同。
要运行这些检查,您需要在测试环境中安装 array-api-strict。这使您无需 GPU 即可运行检查。要运行所有检查,您还需要安装 PyTorch、CuPy 并拥有一个 GPU。无法执行或缺少依赖项的检查将自动跳过。因此,务必使用 -v
标志运行测试,以查看哪些检查被跳过。
pip install array-api-strict # and other libraries as needed
pytest -k "array_api" -v
针对 array-api-strict
运行 scikit-learn 测试应有助于揭示大多数与通过使用模拟非 CPU 设备处理多个设备输入相关的代码问题。这有助于 Array API 相关代码的快速迭代开发和调试。
然而,为确保完全处理分配在实际 GPU 设备上的 PyTorch 或 CuPy 输入,有必要针对这些库和硬件运行测试。这可以通过使用 Google Colab 实现,或者利用我们在拉取请求上的 CI 基础设施(出于成本原因由维护人员手动触发)。
12.1.3.1. 关于 MPS 设备支持的注意事项#
在 macOS 上,PyTorch 可以使用 Metal Performance Shaders (MPS) 访问硬件加速器(例如 M1 或 M2 芯片的内部 GPU 组件)。然而,在撰写本文时,PyTorch 对 MPS 设备的支持尚不完善。有关更多详细信息,请参阅以下 GitHub 问题:
要在 PyTorch 中启用 MPS 支持,请在运行测试之前设置环境变量 PYTORCH_ENABLE_MPS_FALLBACK=1
PYTORCH_ENABLE_MPS_FALLBACK=1 pytest -k "array_api" -v
在撰写本文时,所有 scikit-learn 测试都应通过,但计算速度不一定优于 CPU 设备。
12.1.3.2. 关于 float64
设备支持的注意事项#
scikit-learn 中的某些操作将自动对浮点值执行 float64
精度操作,以防止溢出并确保正确性(例如,metrics.pairwise.euclidean_distances
)。然而,某些数组命名空间和设备组合(例如 PyTorch on MPS
,参见关于 MPS 设备支持的注意事项)不支持 float64
数据类型。在这些情况下,scikit-learn 将转而使用 float32
数据类型。这可能导致与不使用 Array API 调度或使用支持 float64
的设备相比,行为有所不同(通常是数值不稳定的结果)。