9. 模型持久化#

在训练 scikit-learn 模型后,最好有一种方法可以持久化模型以便将来使用,而无需重新训练。根据您的用例,有几种不同的方法可以持久化 scikit-learn 模型,在这里我们将帮助您确定哪种方法最适合您。为了做出决定,您需要回答以下问题

  1. 您在持久化后是否需要 Python 对象,或者您只需要持久化以便服务模型并从中获取预测?

如果您只需要服务模型,并且不需要对 Python 对象本身进行进一步调查,那么ONNX 可能是最适合您的选择。请注意,并非所有模型都受 ONNX 支持。

如果 ONNX 不适合您的用例,那么下一个问题是

  1. 您是否完全信任模型的来源,或者对持久化模型的来源是否存在任何安全问题?

如果您存在安全问题,那么您应该考虑使用skops.io,它会将 Python 对象返回给您,但与基于 pickle 的持久化解决方案不同,加载持久化模型不会自动允许任意代码执行。请注意,这需要手动检查持久化文件,skops.io 允许您这样做。

其他解决方案假设您完全信任要加载文件的来源,因为它们在加载持久化文件时都容易受到任意代码执行的影响,因为它们都在内部使用 pickle 协议。

  1. 您是否关心加载模型的性能,以及在磁盘上使用内存映射对象对进程之间共享模型是否有益?

如果是,那么您可以考虑使用joblib。如果这不是您主要关注的问题,那么您可以使用内置的pickle 模块。

  1. 您是否尝试过picklejoblib 并发现模型无法持久化?例如,当您的模型中包含用户定义函数时,就会发生这种情况。

如果是,那么您可以使用cloudpickle,它可以序列化某些无法被picklejoblib 序列化的对象。

9.1. 工作流程概述#

在典型的工作流程中,第一步是使用 scikit-learn 和与 scikit-learn 兼容的库来训练模型。请注意,对 scikit-learn 和第三方估计器的支持在不同的持久化方法之间有所不同。

9.1.1. 训练并持久化模型#

创建合适的模型取决于您的用例。例如,这里我们在 iris 数据集上训练一个sklearn.ensemble.HistGradientBoostingClassifier

>>> from sklearn import ensemble
>>> from sklearn import datasets
>>> clf = ensemble.HistGradientBoostingClassifier()
>>> X, y = datasets.load_iris(return_X_y=True)
>>> clf.fit(X, y)
HistGradientBoostingClassifier()

模型训练完成后,您可以使用您想要的方法将其持久化,然后您可以在单独的环境中加载模型,并根据输入数据从中获取预测。这里有两个主要路径,具体取决于您如何持久化模型以及计划如何服务模型

  • ONNX: 您需要一个 ONNX 运行时和一个安装了适当依赖项的环境来加载模型并使用运行时获取预测。此环境可以是极简的,甚至不需要安装 Python 就可以加载模型并计算预测。还要注意,onnxruntime 通常比 Python 需要更少的 RAM 来从小型模型中计算预测。

  • skops.iopicklejoblibcloudpickle:您需要一个安装了相应依赖项的 Python 环境来加载模型并从中获取预测。此环境应具有与训练模型的环境相同的 **包** 和相同的 **版本**。请注意,这些方法都不支持加载使用不同版本的 scikit-learn 训练的模型,以及可能不同版本的其他依赖项,例如 numpyscipy。另一个问题是在不同的硬件上运行持久化的模型,在大多数情况下,您应该能够在不同的硬件上加载持久化的模型。

9.2. ONNX#

ONNX开放神经网络交换 格式最适合在需要持久化模型,然后使用持久化的工件来获取预测而无需加载 Python 对象本身的用例中。它在服务环境需要精简和最小时也很有用,因为 ONNX 运行时不需要 python

ONNX 是模型的二进制序列化。它旨在提高数据模型可互操作表示的可使用性。它旨在促进不同机器学习框架之间的数据模型转换,并提高它们在不同计算架构上的可移植性。更多详细信息请参阅 ONNX 教程。要将 scikit-learn 模型转换为 ONNXsklearn-onnx 已开发出来。但是,并非所有 scikit-learn 模型都受支持,它仅限于核心 scikit-learn,不支持大多数第三方估计器。可以为第三方或自定义估计器编写自定义转换器,但执行此操作的文档很少,可能具有挑战性。

使用 ONNX#

要将模型转换为 ONNX 格式,您还需要向转换器提供有关输入的一些信息,您可以在 此处 了解更多信息。

from skl2onnx import to_onnx
onx = to_onnx(clf, X[:1].astype(numpy.float32), target_opset=12)
with open("filename.onnx", "wb") as f:
    f.write(onx.SerializeToString())

您可以在 Python 中加载模型,并使用 ONNX 运行时来获取预测。

from onnxruntime import InferenceSession
with open("filename.onnx", "rb") as f:
    onx = f.read()
sess = InferenceSession(onx, providers=["CPUExecutionProvider"])
pred_ort = sess.run(None, {"X": X_test.astype(numpy.float32)})[0]

9.3. skops.io#

skops.io 避免使用 pickle,并且只加载具有类型和函数引用的文件,这些类型和函数引用默认情况下或由用户信任。因此,它提供了一种比 picklejoblibcloudpickle 更安全的格式。

使用 skops#

API 与 pickle 非常相似,您可以按照 文档 中的说明,使用 skops.io.dumpskops.io.dumps 持久化模型。

import skops.io as sio
obj = sio.dump(clf, "filename.skops")

您可以使用 skops.io.loadskops.io.loads 将它们加载回来。但是,您需要指定您信任的类型。您可以使用 skops.io.get_untrusted_types 获取转储对象/文件中现有的未知类型,并在检查其内容后将其传递给加载函数。

unknown_types = sio.get_untrusted_types(file="filename.skops")
# investigate the contents of unknown_types, and only load if you trust
# everything you see.
clf = sio.load("filename.skops", trusted=unknown_types)

请在 skops 问题跟踪器 上报告与这种格式相关的错误和功能请求。

9.4. picklejoblibcloudpickle#

这三个模块/包在幕后使用 pickle 协议,但略有不同。

  • pickle 是 Python 标准库中的一个模块。它可以序列化和反序列化任何 Python 对象,包括自定义 Python 类和对象。

  • joblib 在处理大型机器学习模型或大型 numpy 数组时比 pickle 更有效。

  • cloudpickle 可以序列化某些无法被 picklejoblib 序列化的对象,例如用户定义的函数和 lambda 函数。例如,当使用 FunctionTransformer 并使用自定义函数来转换数据时,就会发生这种情况。

使用 picklejoblibcloudpickle#

根据您的用例,您可以选择这三种方法之一来持久化和加载您的 scikit-learn 模型,它们都遵循相同的 API。

# Here you can replace pickle with joblib or cloudpickle
from pickle import dump
with open("filename.pkl", "wb") as f:
    dump(clf, f, protocol=5)

建议使用 protocol=5 来减少内存使用量,并加快存储和加载作为模型中拟合属性存储的任何大型 NumPy 数组的速度。您也可以传递 protocol=pickle.HIGHEST_PROTOCOL,它在 Python 3.8 及更高版本中等效于 protocol=5(截至撰写本文时)。

之后,您可以在需要时从持久化的文件中加载相同的对象。

# Here you can replace pickle with joblib or cloudpickle
from pickle import load
with open("filename.pkl", "rb") as f:
    clf = load(f)

9.5. 安全性和可维护性限制#

pickle(以及 joblibclouldpickle 扩展)在设计上存在许多记录在案的安全漏洞,并且仅应在工件(即 pickle 文件)来自可信且经过验证的来源时使用。您永远不应该从不可信的来源加载 pickle 文件,就像您永远不应该执行来自不可信的来源的代码一样。

另请注意,可以使用 ONNX 格式表示任意计算,因此建议在沙箱环境中使用 ONNX 来服务模型,以防范计算和内存攻击。

另请注意,没有支持的方法来加载使用不同版本的 scikit-learn 训练的模型。虽然使用 skops.iojoblibpicklecloudpickle,使用一个版本的 scikit-learn 保存的模型可能会在其他版本中加载,但是,这完全不受支持,也不建议这样做。还应牢记,对这种数据执行的操作可能会产生不同的、意外的结果,甚至可能使您的 Python 进程崩溃。

为了使用 scikit-learn 的未来版本重建类似的模型,应将其他元数据与腌制的模型一起保存。

  • 训练数据,例如对不可变快照的引用。

  • 用于生成模型的 Python 源代码。

  • scikit-learn 及其依赖项的版本。

  • 在训练数据上获得的交叉验证分数

这应该可以检查交叉验证分数是否与之前在同一范围内。

除了少数例外,持久化模型应该可以在不同的操作系统和硬件架构之间移植,假设使用相同的依赖项版本和 Python。如果您遇到不可移植的估计器,请在 GitHub 上打开一个问题。持久化模型通常使用像 Docker 这样的容器在生产环境中部署,以便冻结环境和依赖项。

如果您想了解更多关于这些问题的信息,请参考以下演讲

9.5.1. 在生产环境中复制训练环境#

如果使用的依赖项版本可能与训练到生产环境不同,则在使用训练后的模型时可能会导致意外行为和错误。为了防止这种情况,建议在训练和生产环境中使用相同的依赖项和版本。这些传递依赖项可以使用包管理工具(如 pipmambacondapoetryconda-lockpixi 等)固定。

并非总是能够在更新的软件环境中加载使用旧版本的 scikit-learn 库及其依赖项训练的模型。相反,您可能需要使用所有库的新版本重新训练模型。因此,在训练模型时,记录训练配方(例如 Python 脚本)和训练集信息以及所有依赖项的元数据非常重要,以便能够自动为更新的软件重建相同的训练环境。

InconsistentVersionWarning#

当使用与估计器被腌制时版本不一致的 scikit-learn 版本加载估计器时,将引发 InconsistentVersionWarning。可以捕获此警告以获取估计器被腌制时的原始版本

from sklearn.exceptions import InconsistentVersionWarning
warnings.simplefilter("error", InconsistentVersionWarning)

try:
    with open("model_from_prevision_version.pickle", "rb") as f:
        est = pickle.load(f)
except InconsistentVersionWarning as w:
    print(w.original_sklearn_version)

9.5.2. 提供模型工件#

训练 scikit-learn 模型后的最后一步是提供模型。成功加载训练后的模型后,可以提供它来管理不同的预测请求。这可能涉及根据规范将模型部署为使用容器化的 Web 服务或其他模型部署策略。

9.6. 总结关键点#

基于不同的模型持久化方法,每种方法的关键点可以总结如下

  • ONNX:它提供了一种统一的格式来持久化任何机器学习或深度学习模型(除了 scikit-learn),并且对模型推理(预测)很有用。但是,它可能会导致与不同框架的兼容性问题。

  • skops.io:可以使用 skops.io 轻松共享和投入生产训练好的 scikit-learn 模型。与基于 pickle 的替代方法相比,它更安全,因为它不会加载任意代码,除非用户明确要求。此类代码需要打包并在目标 Python 环境中可导入。

  • joblib:当使用 mmap_mode="r" 在多个 Python 进程中使用相同的持久化模型时,高效的内存映射技术使其更快。它还提供简便的快捷方式来压缩和解压缩持久化对象,而无需额外的代码。但是,当从不受信任的来源加载模型时,它可能会触发恶意代码的执行,就像其他基于 pickle 的持久化机制一样。

  • pickle:它是 Python 的原生方法,大多数 Python 对象可以使用 pickle 序列化和反序列化,包括自定义 Python 类和函数,只要它们在目标环境中可导入的包中定义。虽然 pickle 可用于轻松保存和加载 scikit-learn 模型,但它可能会在从不受信任的来源加载模型时触发恶意代码的执行。 pickle 在内存方面也非常高效,如果模型使用 protocol=5 持久化,但它不支持内存映射。

  • cloudpickle:它具有与 picklejoblib(没有内存映射)相当的加载效率,但提供了额外的灵活性来序列化自定义 Python 代码,例如 lambda 表达式以及交互式定义的函数和类。它可能是持久化具有自定义 Python 组件的管道(例如 sklearn.preprocessing.FunctionTransformer,它包装了在训练脚本本身或更一般地在任何可导入的 Python 包之外定义的函数)的最后手段。请注意,cloudpickle 不提供向前兼容性保证,您可能需要使用相同版本的 cloudpickle 来加载持久化模型,以及定义模型时使用的所有库的相同版本。与其他基于 pickle 的持久化机制一样,它可能会在从不受信任的来源加载模型时触发恶意代码的执行。