9. 模型持久化#
持久化方法 |
优点 |
风险/缺点 |
---|---|---|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
训练 scikit-learn 模型后,最好能够持久化模型以备将来使用,而无需重新训练。根据您的用例,有几种不同的方法可以持久化 scikit-learn 模型,在这里我们将帮助您选择最适合您的方法。为了做出决定,您需要回答以下问题:
持久化后是否需要 Python 对象,或者您只需要持久化以服务模型并从中获得预测结果?
如果您只需要服务模型并且不需要进一步研究 Python 对象本身,那么 ONNX 可能是最适合您的选择。请注意,并非所有模型都受 ONNX 支持。
如果 ONNX 不适合您的用例,下一个问题是:
您是否完全信任模型的来源,或者是否存在关于持久化模型来源的任何安全问题?
如果您存在安全问题,则应考虑使用 skops.io,它会将 Python 对象返回给您,但与基于 pickle
的持久化解决方案不同,加载持久化模型不会自动允许执行任意代码。请注意,这需要手动检查持久化文件,而 skops.io
允许您这样做。
其他解决方案都假定您完全信任要加载的文件的来源,因为它们在加载持久化文件时都容易受到任意代码执行的影响,因为它们都在底层使用 pickle 协议。
您是否关心加载模型的性能,以及在进程间共享磁盘上内存映射的对象是否有益?
如果是,则可以考虑使用 joblib。如果这不是您主要关注的问题,则可以使用内置的 pickle
模块。
如果是,则可以使用 cloudpickle,它可以序列化某些无法被 pickle
或 joblib
序列化的对象。
9.1. 工作流程概述#
在典型的流程中,第一步是使用 scikit-learn 和与 scikit-learn 兼容的库来训练模型。请注意,对 scikit-learn 和第三方估计器的支持在不同的持久化方法中有所不同。
9.1.1. 训练和持久化模型#
创建合适的模型取决于您的用例。例如,这里我们使用 sklearn.ensemble.HistGradientBoostingClassifier
对 iris 数据集进行训练
>>> from sklearn import ensemble
>>> from sklearn import datasets
>>> clf = ensemble.HistGradientBoostingClassifier()
>>> X, y = datasets.load_iris(return_X_y=True)
>>> clf.fit(X, y)
HistGradientBoostingClassifier()
模型训练完成后,您可以使用所需的方法将其持久化,然后可以在单独的环境中加载模型,并根据输入数据获得预测结果。这里有两个主要路径,取决于您如何持久化和计划提供模型服务
ONNX:您需要一个
ONNX
运行时和一个安装了适当依赖项的环境来加载模型并使用运行时获取预测结果。此环境可以是精简的,甚至不需要安装 Python 即可加载模型并计算预测结果。另请注意,onnxruntime
通常比 Python 需要更少的 RAM 来计算小型模型的预测结果。skops.io
、pickle
、joblib
、cloudpickle:您需要一个安装了适当依赖项的 Python 环境来加载模型并从中获取预测结果。此环境应具有与训练模型的环境相同的 **包** 和相同的 **版本**。请注意,这些方法都不支持加载使用不同版本的 scikit-learn 训练的模型,也可能不支持不同版本的其他依赖项,例如numpy
和scipy
。另一个问题是在不同的硬件上运行持久化的模型,在大多数情况下,您应该能够在不同的硬件上加载持久化的模型。
9.2. ONNX#
ONNX
,或 开放神经网络交换 格式最适合需要持久化模型,然后使用持久化工件获取预测结果而无需加载 Python 对象本身的用例。它在服务环境需要精简和最小化的情况下也很有用,因为 ONNX
运行时不需要 python
。
ONNX
是模型的二进制序列化。它旨在提高数据模型可互操作表示的可用性。它旨在促进数据模型在不同机器学习框架之间的转换,并提高其在不同计算架构上的可移植性。更多详细信息可在 ONNX 教程 中找到。要将 scikit-learn 模型转换为 ONNX
,已开发出 sklearn-onnx。但是,并非所有 scikit-learn 模型都受支持,它仅限于核心 scikit-learn,并且不支持大多数第三方估计器。可以为第三方或自定义估计器编写自定义转换器,但是相关的文档很少,这样做可能具有挑战性。
使用 ONNX#
要将模型转换为 ONNX
格式,您还需要向转换器提供一些关于输入的信息,您可以 在这里 阅读更多信息
from skl2onnx import to_onnx
onx = to_onnx(clf, X[:1].astype(numpy.float32), target_opset=12)
with open("filename.onnx", "wb") as f:
f.write(onx.SerializeToString())
您可以在 Python 中加载模型并使用 ONNX
运行时获取预测结果
from onnxruntime import InferenceSession
with open("filename.onnx", "rb") as f:
onx = f.read()
sess = InferenceSession(onx, providers=["CPUExecutionProvider"])
pred_ort = sess.run(None, {"X": X_test.astype(numpy.float32)})[0]
9.3. skops.io
#
skops.io
避免使用 pickle
,并且只加载其类型和函数引用受默认或用户信任的文件。因此,它提供了比 pickle
、joblib
和 cloudpickle 更安全的格式。
使用 skops#
该 API 与 pickle
非常相似,您可以按照 文档 中的说明使用 skops.io.dump
和 skops.io.dumps
持久化您的模型
import skops.io as sio
obj = sio.dump(clf, "filename.skops")
您可以使用 skops.io.load
和 skops.io.loads
将它们加载回来。但是,您需要指定您信任的类型。您可以使用 skops.io.get_untrusted_types
获取转储对象/文件中现有的未知类型,并在检查其内容后将其传递给加载函数。
unknown_types = sio.get_untrusted_types(file="filename.skops")
# investigate the contents of unknown_types, and only load if you trust
# everything you see.
clf = sio.load("filename.skops", trusted=unknown_types)
请在 skops 问题追踪器 上报告与该格式相关的错误和功能请求。
9.4. pickle
、joblib
和 cloudpickle
#
这三个模块/包在底层使用 pickle
协议,但略有不同。
pickle
是 Python 标准库中的一个模块。它可以序列化和反序列化任何 Python 对象,包括自定义 Python 类和对象。joblib
在处理大型机器学习模型或大型 numpy 数组时比pickle
更高效。cloudpickle 可以序列化某些无法被
pickle
或joblib
序列化的对象,例如用户定义的函数和 lambda 函数。例如,当使用FunctionTransformer
并使用自定义函数转换数据时,就会发生这种情况。
使用 pickle
、joblib
或 cloudpickle
#
根据您的用例,您可以选择这三种方法之一来持久化和加载您的 scikit-learn 模型,它们都遵循相同的 API。
# Here you can replace pickle with joblib or cloudpickle
from pickle import dump
with open("filename.pkl", "wb") as f:
dump(clf, f, protocol=5)
建议使用 protocol=5
来减少内存使用量,并加快存储和加载任何作为模型拟合属性存储的大型 NumPy 数组的速度。您也可以传递 protocol=pickle.HIGHEST_PROTOCOL
,这在 Python 3.8 及更高版本中与 protocol=5
等效(在撰写本文时)。
之后需要时,您可以从持久化文件中加载相同的对象。
# Here you can replace pickle with joblib or cloudpickle
from pickle import load
with open("filename.pkl", "rb") as f:
clf = load(f)
9.5. 安全性和可维护性限制#
pickle
(以及 joblib
和 clouldpickle
)从设计上存在许多已记录的安全漏洞,只有在工件(即 pickle 文件)来自可信和经过验证的来源时才应使用。您永远不应该从不受信任的来源加载 pickle 文件,这就像您永远不应该执行来自不受信任来源的代码一样。
另请注意,可以使用 ONNX
格式表示任意计算,因此建议在沙盒环境中使用 ONNX
来服务模型,以防止计算和内存漏洞。
另请注意,没有支持的方法可以加载使用不同版本的 scikit-learn 训练的模型。虽然可以使用 skops.io
、joblib
、pickle
或 cloudpickle,使用一个版本的 scikit-learn 保存的模型可能会在其他版本中加载,但是,这完全不受支持,也不建议这样做。还应记住,对这些数据执行的操作可能会产生不同且意外的结果,甚至可能导致 Python 进程崩溃。
为了使用 scikit-learn 的未来版本重建类似的模型,应沿 pickle 模型一起保存其他元数据。
训练数据,例如对不可变快照的引用。
用于生成模型的 Python 源代码。
scikit-learn 及其依赖项的版本。
在训练数据上获得的交叉验证分数。
这应该可以检查交叉验证分数是否与之前在同一范围内。
除了少数例外情况,假设使用相同的依赖项和 Python 版本,持久化模型应该可以在操作系统和硬件架构之间移植。如果您遇到不可移植的估计器,请在 GitHub 上创建一个问题。持久化模型通常使用像 Docker 这样的容器在生产环境中部署,以冻结环境和依赖项。
如果您想了解更多关于这些问题的信息,请参阅以下演讲。
9.5.1. 在生产环境中复制训练环境#
如果使用的依赖项版本可能与训练环境和生产环境不同,则在使用训练模型时可能会导致意外行为和错误。为了防止这种情况,建议在训练环境和生产环境中使用相同的依赖项和版本。这些传递依赖项可以在 pip
、mamba
、conda
、poetry
、conda-lock
、pixi
等包管理工具的帮助下固定。
并非总是能够在更新的软件环境中加载使用旧版本 scikit-learn 库及其依赖项训练的模型。相反,您可能需要使用所有库的新版本重新训练模型。因此,在训练模型时,务必记录训练方案(例如 Python 脚本)和训练集信息以及所有依赖项的元数据,以便能够自动重建更新软件的相同训练环境。
不一致版本警告#
当使用与估计器序列化版本不一致的 scikit-learn 版本加载估计器时,会引发 InconsistentVersionWarning
警告。可以捕获此警告以获取估计器最初序列化的版本。
from sklearn.exceptions import InconsistentVersionWarning
warnings.simplefilter("error", InconsistentVersionWarning)
try:
with open("model_from_prevision_version.pickle", "rb") as f:
est = pickle.load(f)
except InconsistentVersionWarning as w:
print(w.original_sklearn_version)
9.5.2. 模型工件的服务#
训练 scikit-learn 模型后的最后一步是提供模型服务。成功加载训练好的模型后,可以提供服务来管理不同的预测请求。这可能涉及根据规范使用容器化或其他模型部署策略将模型部署为 Web 服务。
9.6. 关键点总结#
基于不同的模型持久化方法,每种方法的关键点可以总结如下:
ONNX
:它提供了一种统一的格式来持久化任何机器学习或深度学习模型(scikit-learn 除外),并可用于模型推理(预测)。但是,它可能会导致与不同框架的兼容性问题。skops.io
:使用skops.io
可以轻松共享和投入生产已训练的 scikit-learn 模型。与基于pickle
的替代方法相比,它更安全,因为它不会加载任意代码,除非用户明确要求。此类代码需要打包并在目标 Python 环境中可导入。joblib
:使用mmap_mode="r"
时,高效的内存映射技术在多个 Python 进程中使用相同的持久化模型时使其更快。它还提供简单的快捷方式来压缩和解压缩持久化对象,无需额外的代码。但是,与任何其他基于 pickle 的持久化机制一样,它可能会在从不受信任的来源加载模型时触发恶意代码的执行。pickle
:它是 Python 的原生库,大多数 Python 对象都可以使用pickle
进行序列化和反序列化,包括自定义 Python 类和函数,只要它们在目标环境中可导入的包中定义即可。虽然pickle
可用于轻松保存和加载 scikit-learn 模型,但在从不受信任的来源加载模型时,它可能会触发恶意代码的执行。pickle
如果模型使用protocol=5
持久化,则在内存方面也可以非常高效,但它不支持内存映射。cloudpickle:它具有与
pickle
和joblib
(无内存映射)相当的加载效率,但提供了额外的灵活性来序列化自定义 Python 代码,例如 lambda 表达式以及交互式定义的函数和类。它可能是持久化包含自定义 Python 组件(例如包装在训练脚本本身或更一般地在任何可导入的 Python 包之外定义的函数的sklearn.preprocessing.FunctionTransformer
)的管道的最后手段。请注意,cloudpickle 没有提供前向兼容性保证,您可能需要相同版本的 cloudpickle 来加载持久化模型以及定义模型的所有库的相同版本。与其他基于 pickle 的持久化机制一样,它可能会在从不受信任的来源加载模型时触发恶意代码的执行。