10. 模型持久化#
持久化方法 |
优点 |
风险/缺点 |
---|---|---|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
在训练scikit-learn模型后,最好有一种方法可以持久化模型以供将来使用,而无需重新训练。根据您的用例,有几种不同的方法可以持久化scikit-learn模型,我们在此帮助您决定哪种最适合您。为了做出决定,您需要回答以下问题
持久化后您是否需要Python对象,还是只需要持久化模型以便提供服务并从中获取预测?
如果您只需要部署模型,并且不需要对Python对象本身进行进一步的探究,那么ONNX可能最适合您。请注意,并非所有模型都受ONNX支持。
如果ONNX不适合您的用例,那么下一个问题是
您是否绝对信任模型的来源,或者对持久化模型的来源有任何安全顾虑?
如果您有安全顾虑,那么您应该考虑使用skops.io,它会返回Python对象,但与基于pickle
的持久化解决方案不同,加载持久化模型不会自动允许任意代码执行。请注意,这需要手动检查持久化文件,而skops.io
允许您这样做。
其他解决方案假定您绝对信任要加载的文件来源,因为它们在底层都使用pickle协议,所以在加载持久化文件时都容易受到任意代码执行的攻击。
您是否关心模型加载的性能,以及在进程间共享磁盘上的内存映射对象是否有益?
如果关心,那么您可以考虑使用joblib。如果这不是您主要关注的问题,那么您可以使用内置的pickle
模块。
如果是,那么您可以使用cloudpickle,它可以序列化某些pickle
或joblib
无法序列化的对象。
10.1. 工作流程概述#
在典型的工作流程中,第一步是使用scikit-learn和scikit-learn兼容库训练模型。请注意,对scikit-learn和第三方估计器的支持在不同的持久化方法中有所不同。
10.1.1. 训练并持久化模型#
创建合适的模型取决于您的用例。例如,我们在此使用鸢尾花数据集训练一个sklearn.ensemble.HistGradientBoostingClassifier
>>> from sklearn import ensemble
>>> from sklearn import datasets
>>> clf = ensemble.HistGradientBoostingClassifier()
>>> X, y = datasets.load_iris(return_X_y=True)
>>> clf.fit(X, y)
HistGradientBoostingClassifier()
模型训练完成后,您可以使用所需的方法对其进行持久化,然后可以在独立的环境中加载模型,并根据输入数据获取预测。根据您持久化和计划部署模型的方式,主要有两种途径
ONNX: 您需要一个
ONNX
运行时以及一个安装了相应依赖的环境来加载模型并使用运行时获取预测。此环境可以非常精简,甚至不一定需要安装Python即可加载模型并计算预测。另请注意,onnxruntime
通常比Python需要少得多的RAM来计算小型模型的预测。skops.io
、pickle
、joblib
、cloudpickle:您需要一个安装了相应依赖的Python环境来加载模型并从中获取预测。此环境应与训练模型的环境具有相同的包和版本。请注意,这些方法都不支持加载使用不同版本的scikit-learn训练的模型,也可能不支持不同版本的其他依赖项,例如numpy
和scipy
。另一个问题是在不同的硬件上运行持久化模型,在大多数情况下,您应该能够在不同的硬件上加载持久化模型。
10.2. ONNX#
ONNX
,即开放神经网络交换格式,最适用于需要持久化模型,然后使用持久化的工件获取预测而无需加载Python对象本身的用例。它在部署环境需要精简和最小化的场景中也很有用,因为ONNX
运行时不需要python
。
ONNX
是模型的二进制序列化。它的开发旨在提高数据模型互操作表示的可用性。它旨在促进不同机器学习框架之间的数据模型转换,并提高它们在不同计算架构上的可移植性。更多详细信息可从ONNX教程获取。为了将scikit-learn模型转换为ONNX
,sklearn-onnx已经开发。然而,并非所有scikit-learn模型都受支持,它仅限于核心scikit-learn,不支持大多数第三方估计器。可以为第三方或自定义估计器编写自定义转换器,但相关文档很少,这样做可能具有挑战性。
使用ONNX#
要将模型转换为ONNX
格式,您还需要向转换器提供一些关于输入的信息,您可以在此处阅读更多内容
from skl2onnx import to_onnx
onx = to_onnx(clf, X[:1].astype(numpy.float32), target_opset=12)
with open("filename.onnx", "wb") as f:
f.write(onx.SerializeToString())
您可以在Python中加载模型并使用ONNX
运行时获取预测
from onnxruntime import InferenceSession
with open("filename.onnx", "rb") as f:
onx = f.read()
sess = InferenceSession(onx, providers=["CPUExecutionProvider"])
pred_ort = sess.run(None, {"X": X_test.astype(numpy.float32)})[0]
10.3. skops.io
#
skops.io
避免使用pickle
,仅加载默认或用户信任的类型和函数引用的文件。因此,它比基于pickle
、joblib
和cloudpickle的格式更安全。
使用skops#
API与pickle
非常相似,您可以使用文档中解释的方法,通过skops.io.dump
和skops.io.dumps
持久化您的模型
import skops.io as sio
obj = sio.dump(clf, "filename.skops")
您可以使用skops.io.load
和skops.io.loads
将其加载回来。但是,您需要指定您信任的类型。您可以使用skops.io.get_untrusted_types
获取已转储对象/文件中的现有未知类型,并在检查其内容后,将其传递给加载函数
unknown_types = sio.get_untrusted_types(file="filename.skops")
# investigate the contents of unknown_types, and only load if you trust
# everything you see.
clf = sio.load("filename.skops", trusted=unknown_types)
请在skops问题追踪器上报告与此格式相关的问题和功能请求。
10.4. pickle
、joblib
和cloudpickle
#
这三个模块/包底层都使用pickle
协议,但有一些细微差别
pickle
是Python标准库中的一个模块。它可以序列化和反序列化任何Python对象,包括自定义Python类和对象。当处理大型机器学习模型或大型numpy数组时,
joblib
比pickle
更高效。cloudpickle可以序列化某些
pickle
或joblib
无法序列化的对象,例如用户定义的函数和lambda函数。例如,当使用FunctionTransformer
并使用自定义函数转换数据时,可能会发生这种情况。
使用pickle
、joblib
或cloudpickle
#
根据您的用例,您可以选择这三种方法之一来持久化和加载scikit-learn模型,它们都遵循相同的API
# Here you can replace pickle with joblib or cloudpickle
from pickle import dump
with open("filename.pkl", "wb") as f:
dump(clf, f, protocol=5)
建议使用protocol=5
来减少内存使用,并加快存储和加载模型中作为拟合属性存储的任何大型NumPy数组的速度。您也可以传递protocol=pickle.HIGHEST_PROTOCOL
,这在Python 3.8及更高版本中(撰写本文时)等同于protocol=5
。
之后在需要时,您可以从持久化文件中加载相同的对象
# Here you can replace pickle with joblib or cloudpickle
from pickle import load
with open("filename.pkl", "rb") as f:
clf = load(f)
10.5. 安全性和可维护性限制#
pickle
(以及joblib
和cloudpickle
的扩展)在设计上存在许多已知的安全漏洞,因此仅当工件(即pickle文件)来自受信任和验证的来源时才应使用。您绝不应从不受信任的来源加载pickle文件,就像您绝不应执行来自不受信任的来源的代码一样。
另请注意,任意计算都可以使用ONNX
格式表示,因此建议在沙盒环境中部署使用ONNX
的模型,以防止计算和内存漏洞。
另请注意,没有支持的方法可以加载使用不同版本scikit-learn训练的模型。虽然使用skops.io
、joblib
、pickle
或cloudpickle保存的模型可能在其他版本中加载,但这完全不受支持且不建议。还应记住,对此类数据执行的操作可能会产生不同且意想不到的结果,甚至可能导致您的Python进程崩溃。
为了使用未来版本的scikit-learn重建类似模型,应随pickle模型保存额外的元数据
训练数据,例如对不可变快照的引用
用于生成模型的Python源代码
scikit-learn及其依赖的版本
在训练数据上获得的交叉验证分数
这应该可以检查交叉验证分数是否与之前在相同范围内。
除少数例外,持久化模型应在不同操作系统和硬件架构之间可移植,前提是使用相同版本的依赖项和Python。如果您遇到不可移植的估计器,请在GitHub上提出问题。持久化模型通常使用Docker等容器部署到生产环境,以冻结环境和依赖项。
如果您想了解更多关于这些问题,请参考以下讲座
10.5.1. 在生产环境中复制训练环境#
如果使用的依赖项版本在训练和生产环境中可能不同,则在使用训练模型时可能会导致意外行为和错误。为防止此类情况,建议在训练和生产环境中使用相同的依赖项和版本。这些传递性依赖项可以借助包管理工具(如pip
、mamba
、conda
、poetry
、conda-lock
、pixi
等)进行固定。
并非总是能够在更新的软件环境中加载使用旧版本scikit-learn库及其依赖项训练的模型。相反,您可能需要使用所有库的新版本重新训练模型。因此,在训练模型时,重要的是记录训练配方(例如Python脚本)和训练集信息,以及所有依赖项的元数据,以便能够自动重建更新软件的相同训练环境。
InconsistentVersionWarning#
当估计器加载的scikit-learn版本与估计器序列化时使用的版本不一致时,会引发InconsistentVersionWarning
。可以捕获此警告以获取估计器原始序列化的版本。
from sklearn.exceptions import InconsistentVersionWarning
warnings.simplefilter("error", InconsistentVersionWarning)
try:
with open("model_from_previous_version.pickle", "rb") as f:
est = pickle.load(f)
except InconsistentVersionWarning as w:
print(w.original_sklearn_version)
10.5.2. 部署模型工件#
训练scikit-learn模型后的最后一步是部署模型。一旦训练好的模型成功加载,就可以部署它来管理不同的预测请求。这可以涉及根据规范,使用容器化将模型部署为Web服务,或采用其他模型部署策略。
10.6. 总结要点#
基于模型持久化的不同方法,每种方法的要点可以总结如下
ONNX
:它为持久化任何机器学习或深度学习模型(scikit-learn除外)提供了一种统一格式,并可用于模型推理(预测)。然而,它可能会导致与不同框架的兼容性问题。skops.io
:使用skops.io
可以轻松共享和部署训练好的scikit-learn模型。与基于pickle
的其他方法相比,它更安全,因为它不会加载任意代码,除非用户明确要求。此类代码需要在目标Python环境中打包并可导入。joblib
:当在多个Python进程中使用mmap_mode="r"
时,高效的内存映射技术使得使用相同的持久化模型更快。它还提供了简单的快捷方式来压缩和解压缩持久化对象,无需额外代码。然而,与任何其他基于pickle的持久化机制一样,从不受信任的来源加载模型时,它可能会触发恶意代码的执行。pickle
:它是Python原生的,大多数Python对象都可以使用pickle
进行序列化和反序列化,包括自定义Python类和函数,只要它们在目标环境中可导入的包中定义。虽然pickle
可以轻松保存和加载scikit-learn模型,但在从不受信任的来源加载模型时,它可能会触发恶意代码的执行。pickle
如果模型使用protocol=5
进行持久化,其内存效率也会很高,但它不支持内存映射。cloudpickle:它具有与
pickle
和joblib
(不含内存映射)相当的加载效率,但提供了额外的灵活性,可以序列化自定义Python代码,例如lambda表达式和交互式定义的函数和类。当需要持久化包含自定义Python组件(例如包装了训练脚本中定义或更一般地在任何可导入Python包之外定义的函数的sklearn.preprocessing.FunctionTransformer
)的管道时,它可能是最后的选择。请注意,cloudpickle不提供向前兼容性保证,您可能需要相同版本的cloudpickle以及定义模型所用的所有库的相同版本才能加载持久化模型。与其他基于pickle的持久化机制一样,在从不受信任的来源加载模型时,它可能会触发恶意代码的执行。