11.1. 数组 API 支持(实验性)#

数组 API 规范为所有具有 NumPy 式 API 的数组操作库定义了一个标准 API。scikit-learn 的数组 API 支持需要 array-api-compat 安装。

一些主要依赖 NumPy(而不是使用 Cython)来实现其 fitpredicttransform 方法的算法逻辑的 scikit-learn 估计器,可以配置为接受任何与数组 API 兼容的输入数据结构,并自动将操作调度到底层命名空间,而不是依赖 NumPy。

在此阶段,此支持被认为是 **实验性的**,必须如以下所述明确启用。

注意

目前,仅 cupy.array_apiarray-api-strictcupyPyTorch 已知可与 scikit-learn 的估计器一起使用。

11.1.1. 示例用法#

以下是一个代码片段示例,演示如何使用 CuPy 在 GPU 上运行 LinearDiscriminantAnalysis

>>> from sklearn.datasets import make_classification
>>> from sklearn import config_context
>>> from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
>>> import cupy

>>> X_np, y_np = make_classification(random_state=0)
>>> X_cu = cupy.asarray(X_np)
>>> y_cu = cupy.asarray(y_np)
>>> X_cu.device
<CUDA Device 0>

>>> with config_context(array_api_dispatch=True):
...     lda = LinearDiscriminantAnalysis()
...     X_trans = lda.fit_transform(X_cu, y_cu)
>>> X_trans.device
<CUDA Device 0>

模型训练完成后,作为数组的拟合属性也将来自与训练数据相同的数组 API 命名空间。例如,如果使用 CuPy 的数组 API 命名空间进行训练,那么拟合属性将位于 GPU 上。我们提供了一个实验性的 _estimator_with_converted_arrays 实用程序,它将估计器属性从数组 API 转换为 ndarray

>>> from sklearn.utils._array_api import _estimator_with_converted_arrays
>>> cupy_to_ndarray = lambda array : array.get()
>>> lda_np = _estimator_with_converted_arrays(lda, cupy_to_ndarray)
>>> X_trans = lda_np.transform(X_np)
>>> type(X_trans)
<class 'numpy.ndarray'>

11.1.1.1. PyTorch 支持#

通过设置 array_api_dispatch=True 并直接传入张量,支持 PyTorch 张量

>>> import torch
>>> X_torch = torch.asarray(X_np, device="cuda", dtype=torch.float32)
>>> y_torch = torch.asarray(y_np, device="cuda", dtype=torch.float32)

>>> with config_context(array_api_dispatch=True):
...     lda = LinearDiscriminantAnalysis()
...     X_trans = lda.fit_transform(X_torch, y_torch)
>>> type(X_trans)
<class 'torch.Tensor'>
>>> X_trans.device.type
'cuda'

11.1.2. 支持 Array API 兼容输入#

scikit-learn 中支持数组 API 兼容输入的估计器和其他工具。

11.1.2.1. 估计器#

11.1.2.2. 指标#

11.1.2.3. 工具#

预计覆盖范围会随着时间的推移而增长。请关注专门的 GitHub 上的元问题 以跟踪进度。

11.1.2.4. 返回值和拟合属性的类型#

使用与数组 API 兼容的输入调用函数或方法时,惯例是返回与输入数据具有相同数组容器类型和设备的数组值。

类似地,当估计器使用与数组 API 兼容的输入进行拟合时,拟合属性将是来自与输入相同库的数组,并存储在相同的设备上。随后,predicttransform 方法期望输入来自与传递给 fit 方法的数据相同的数组库和设备。

但是请注意,返回标量值的评分函数返回 Python 标量(通常是 float 实例),而不是数组标量值。

11.1.3. 通用估计器检查#

在估计器的标签集中添加 array_api_support 标签,以表明它支持 Array API。这将启用通用测试中专用的检查,以验证使用普通 NumPy 和 Array API 输入时估计器的结果是否相同。

要运行这些检查,您需要在测试环境中安装 array_api_compat。要运行完整的检查集,您需要安装 PyTorchCuPy 并在您的系统上安装 GPU。无法执行或缺少依赖项的检查将自动跳过。因此,使用 -v 标志运行测试以查看哪些检查被跳过非常重要。

pip install array-api-compat  # and other libraries as needed
pytest -k "array_api" -v

11.1.3.1. 关于 MPS 设备支持的说明#

在 macOS 上,PyTorch 可以使用 Metal Performance Shaders (MPS) 来访问硬件加速器(例如 M1 或 M2 芯片的内部 GPU 组件)。但是,截至撰写本文时,PyTorch 的 MPS 设备支持尚不完整。有关更多详细信息,请参阅以下 GitHub 问题

要在 PyTorch 中启用 MPS 支持,请在运行测试之前设置环境变量 PYTORCH_ENABLE_MPS_FALLBACK=1

PYTORCH_ENABLE_MPS_FALLBACK=1 pytest -k "array_api" -v

截至撰写本文时,所有 scikit-learn 测试都应该通过,但是,计算速度不一定比 CPU 设备快。