10. 模型持久化#
持久化方法 |
优点 |
风险 / 缺点 |
|---|---|---|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
在训练完 scikit-learn 模型后,需要有一种方法来持久化模型以便将来使用,而无需重新训练。根据您的用例,有几种不同的方法可以持久化 scikit-learn 模型,在这里我们帮助您选择最适合您的方法。为了做出决定,您需要回答以下问题
持久化后您是否需要Python对象,还是只需要持久化以服务模型并从中获取预测?
如果您只需要服务模型,并且不需要对Python对象本身进行进一步调查,那么 ONNX 可能最适合您。请注意,并非所有模型都受ONNX支持。
如果 ONNX 不适合您的用例,那么下一个问题是
您是否绝对信任模型的来源,或者对持久化模型的来源有任何安全顾虑?
如果您有安全顾虑,那么您应该考虑使用 skops.io,它会返回 Python 对象,但与基于 pickle 的持久化解决方案不同,加载持久化模型不会自动允许任意代码执行。请注意,这需要手动检查持久化文件,而 skops.io 允许您这样做。
其他解决方案假设您绝对信任要加载的文件源,因为它们在加载持久化文件时都容易受到任意代码执行的影响,因为它们都在底层使用 pickle 协议。
您是否关心模型加载的性能,以及在进程间共享磁盘上的内存映射对象是否有利?
如果是,那么您可以考虑使用 joblib。如果这不是您主要关心的问题,那么您可以使用内置的 pickle 模块。
如果是,那么您可以使用 cloudpickle,它可以序列化某些无法由 pickle 或 joblib 序列化的对象。
10.1. 工作流程概述#
在典型的工作流程中,第一步是使用 scikit-learn 和 scikit-learn 兼容库训练模型。请注意,对 scikit-learn 和第三方估计器的支持因不同的持久化方法而异。
10.1.1. 训练和持久化模型#
创建适当的模型取决于您的用例。例如,这里我们使用鸢尾花数据集训练一个 sklearn.ensemble.HistGradientBoostingClassifier
>>> from sklearn import ensemble
>>> from sklearn import datasets
>>> clf = ensemble.HistGradientBoostingClassifier()
>>> X, y = datasets.load_iris(return_X_y=True)
>>> clf.fit(X, y)
HistGradientBoostingClassifier()
一旦模型训练完成,您就可以使用您所需的方法对其进行持久化,然后您可以在独立的环境中加载模型并根据输入数据获取预测。这里根据您持久化和计划服务模型的方式,有两条主要路径
ONNX:您需要一个
ONNX运行时和一个安装了适当依赖项的环境来加载模型并使用运行时获取预测。这个环境可以是最小的,甚至不一定需要安装 Python 来加载模型和计算预测。另请注意,onnxruntime通常比 Python 需要更少的 RAM 来计算小型模型的预测。skops.io,pickle,joblib, cloudpickle:您需要一个安装了适当依赖项的 Python 环境来加载模型并从中获取预测。这个环境应该具有与模型训练环境相同的 **包** 和 **版本**。请注意,这些方法都不支持加载使用不同版本的 scikit-learn 训练的模型,以及可能不同版本的其他依赖项,例如numpy和scipy。另一个问题是在不同的硬件上运行持久化模型,在大多数情况下,您应该能够在不同的硬件上加载您的持久化模型。
10.2. ONNX#
ONNX,或 开放神经网络交换格式,最适合需要持久化模型,然后使用持久化工件进行预测而无需加载Python对象本身的用例。它也适用于服务环境需要精简和最小化的用例,因为 ONNX 运行时不需要 python。
ONNX 是模型的二进制序列化。它的开发旨在提高数据模型可互操作表示的可用性。它旨在促进不同机器学习框架之间数据模型的转换,并提高它们在不同计算架构上的可移植性。更多详细信息可从 ONNX 教程 中获取。为了将 scikit-learn 模型转换为 ONNX,sklearn-onnx 已经开发出来。然而,并非所有 scikit-learn 模型都受支持,它仅限于核心 scikit-learn,不支持大多数第三方估计器。可以为第三方或自定义估计器编写自定义转换器,但相关文档稀疏,可能很难做到。
使用 ONNX#
要将模型转换为 ONNX 格式,您还需要向转换器提供一些关于输入的信息,您可以在 此处 阅读更多内容
from skl2onnx import to_onnx
onx = to_onnx(clf, X[:1].astype(numpy.float32), target_opset=12)
with open("filename.onnx", "wb") as f:
f.write(onx.SerializeToString())
您可以在 Python 中加载模型并使用 ONNX 运行时获取预测
from onnxruntime import InferenceSession
with open("filename.onnx", "rb") as f:
onx = f.read()
sess = InferenceSession(onx, providers=["CPUExecutionProvider"])
pred_ort = sess.run(None, {"X": X_test.astype(numpy.float32)})[0]
10.3. skops.io#
skops.io 避免使用 pickle,仅加载具有默认或用户信任的类型和函数引用的文件。因此,它提供了比 pickle、joblib 和 cloudpickle 更安全的格式。
使用 skops#
API 与 pickle 非常相似,您可以按照 文档 中的说明,使用 skops.io.dump 和 skops.io.dumps 来持久化您的模型
import skops.io as sio
obj = sio.dump(clf, "filename.skops")
您可以使用 skops.io.load 和 skops.io.loads 将它们加载回来。但是,您需要指定您信任的类型。您可以使用 skops.io.get_untrusted_types 获取已转储对象/文件中的现有未知类型,并在检查其内容后,将其传递给加载函数
unknown_types = sio.get_untrusted_types(file="filename.skops")
# investigate the contents of unknown_types, and only load if you trust
# everything you see.
clf = sio.load("filename.skops", trusted=unknown_types)
请向 skops 问题跟踪器 报告与此格式相关的问题和功能请求。
10.4. pickle、joblib 和 cloudpickle#
这三个模块/包在底层都使用 pickle 协议,但略有不同
pickle是 Python 标准库中的一个模块。它可以序列化和反序列化任何 Python 对象,包括自定义 Python 类和对象。当处理大型机器学习模型或大型 NumPy 数组时,
joblib比pickle更高效。cloudpickle 可以序列化某些
pickle或joblib无法序列化的对象,例如用户定义的函数和 lambda 函数。例如,在使用FunctionTransformer并使用自定义函数转换数据时,可能会发生这种情况。
使用 pickle、joblib 或 cloudpickle#
根据您的用例,您可以选择这三种方法中的一种来持久化和加载您的 scikit-learn 模型,它们都遵循相同的 API
# Here you can replace pickle with joblib or cloudpickle
from pickle import dump
with open("filename.pkl", "wb") as f:
dump(clf, f, protocol=5)
建议使用 protocol=5 以减少内存使用量,并更快地存储和加载模型中作为拟合属性存储的任何大型 NumPy 数组。您也可以传递 protocol=pickle.HIGHEST_PROTOCOL,这在 Python 3.8 及更高版本(撰写本文时)中等同于 protocol=5。
之后需要时,您可以从持久化文件中加载相同的对象
# Here you can replace pickle with joblib or cloudpickle
from pickle import load
with open("filename.pkl", "rb") as f:
clf = load(f)
10.5. 安全与可维护性限制#
pickle(以及 joblib 和 cloudpickle)在设计上存在许多已记录的安全漏洞,仅当工件(即 pickle 文件)来自受信任和经过验证的来源时才应使用。您绝不应加载来自不受信任来源的 pickle 文件,就像您绝不应执行来自不受信任来源的代码一样。
另请注意,任意计算可以使用 ONNX 格式表示,因此建议在沙盒环境中部署使用 ONNX 的模型,以防止计算和内存漏洞。
另外请注意,不支持加载使用不同版本的 scikit-learn 训练的模型。虽然使用 skops.io、joblib、pickle 或 cloudpickle 保存的模型可能在其他版本中加载,但这完全不受支持且不建议。还应记住,对此类数据执行的操作可能会产生不同且意外的结果,甚至可能导致您的 Python 进程崩溃。
为了用未来版本的 scikit-learn 重建类似的模型,除了 pickled 模型之外,还应保存额外的元数据
训练数据,例如对不可变快照的引用
用于生成模型的 Python 源代码
scikit-learn 及其依赖项的版本
在训练数据上获得的交叉验证分数
这应该可以检查交叉验证分数是否与以前在同一范围内。
除了少数例外,持久化模型在相同的依赖项和 Python 版本下,应能在不同操作系统和硬件架构之间移植。如果您遇到不可移植的估计器,请在 GitHub 上提出问题。持久化模型通常使用 Docker 等容器在生产环境中部署,以冻结环境和依赖项。
如果您想了解更多关于这些问题的信息,请参考这些讲座
10.5.1. 在生产环境中复制训练环境#
如果所使用的依赖项版本在训练和生产环境中可能不同,则在使用训练模型时可能会导致意外行为和错误。为防止此类情况发生,建议在训练和生产环境中使用相同的依赖项和版本。这些传递性依赖项可以使用 pip、mamba、conda、poetry、conda-lock、pixi 等包管理工具进行固定。
并非总能将使用旧版本 scikit-learn 库及其依赖项训练的模型加载到更新的软件环境中。相反,您可能需要使用所有库的新版本重新训练模型。因此,在训练模型时,记录训练配方(例如 Python 脚本)和训练集信息,以及所有依赖项的元数据,以便能够自动重建更新软件的相同训练环境,这一点很重要。
InconsistentVersionWarning#
当估计器加载的 scikit-learn 版本与估计器被 pickled 时的版本不一致时,会引发 InconsistentVersionWarning。可以捕获此警告以获取估计器被 pickled 时的原始版本
from sklearn.exceptions import InconsistentVersionWarning
warnings.simplefilter("error", InconsistentVersionWarning)
try:
with open("model_from_previous_version.pickle", "rb") as f:
est = pickle.load(f)
except InconsistentVersionWarning as w:
print(w.original_sklearn_version)
10.5.2. 服务模型工件#
训练 scikit-learn 模型后的最后一步是服务模型。一旦训练好的模型成功加载,就可以通过它管理不同的预测请求。这可能涉及使用容器化或其他模型部署策略,根据具体规范将模型部署为 Web 服务。
10.6. 总结要点#
基于模型持久化的不同方法,每种方法的要点总结如下
ONNX:它为持久化任何机器学习或深度学习模型(除了 scikit-learn)提供统一格式,并可用于模型推理(预测)。然而,它可能导致与不同框架的兼容性问题。skops.io:经过训练的 scikit-learn 模型可以使用skops.io轻松共享并投入生产。与基于pickle的替代方法相比,它更安全,因为它不会加载任意代码,除非用户明确要求。此类代码需要打包并在目标 Python 环境中可导入。joblib:当在多个 Python 进程中使用mmap_mode="r"时,高效的内存映射技术使其在使用相同持久化模型时更快。它还提供了简单的快捷方式来压缩和解压缩持久化对象,而无需额外的代码。然而,与任何其他基于 pickle 的持久化机制一样,从不受信任的来源加载模型时,它可能会触发恶意代码的执行。pickle:它是 Python 原生的,大多数 Python 对象都可以使用pickle进行序列化和反序列化,包括自定义 Python 类和函数,只要它们定义在目标环境中可导入的包中。虽然pickle可以轻松保存和加载 scikit-learn 模型,但从不受信任的来源加载模型时可能会触发恶意代码的执行。pickle如果模型使用protocol=5持久化,在内存方面也非常高效,但它不支持内存映射。cloudpickle:它具有与
pickle和joblib相当的加载效率(不含内存映射),但提供了额外的灵活性,可以序列化自定义 Python 代码,例如 lambda 表达式和交互式定义的函数和类。它可能是持久化包含自定义 Python 组件的管道的最后手段,例如包装了在训练脚本本身或更普遍地说,在任何可导入 Python 包之外定义的函数的sklearn.preprocessing.FunctionTransformer。请注意,cloudpickle 不提供向前兼容性保证,您可能需要相同版本的 cloudpickle 以及用于定义模型的所有库的相同版本来加载持久化模型。与其他的基于 pickle 的持久化机制一样,从不受信任的来源加载模型时,它可能会触发恶意代码的执行。