brier_score_loss#
- sklearn.metrics.brier_score_loss(y_true, y_proba, *, sample_weight=None, pos_label=None, labels=None, scale_by_half='auto')[source]#
计算布里尔分数损失。
布里尔分数损失越小越好,因此命名为“损失”。布里尔分数衡量的是预测概率与实际结果之间的均方差。布里尔分数是一个严格适当的评分规则。
在用户指南中阅读更多。
- 参数:
- y_true形状为 (n_samples,) 的类数组
真实目标。
- y_proba形状为 (n_samples,) 或 (n_samples, n_classes) 的类数组
预测概率。如果
y_proba.shape = (n_samples,)
,则提供的概率被假定为正类的概率。如果y_proba.shape = (n_samples, n_classes)
,则y_proba
中的列被假定按字母顺序对应于标签,如LabelBinarizer
所做。- sample_weight形状为 (n_samples,) 的类数组,默认为 None
样本权重。
- pos_labelint、float、bool 或 str,默认为 None
当
y_proba.shape = (n_samples,)
时正类的标签。如果未提供,pos_label
将按以下方式推断:如果
y_true
在 {-1, 1} 或 {0, 1} 中,pos_label
默认为 1;否则,如果
y_true
包含字符串,将引发错误,并且应明确指定pos_label
;否则,
pos_label
默认为较大的标签,即np.unique(y_true)[-1]
。
- labels形状为 (n_classes,) 的类数组,默认为 None
当
y_proba.shape = (n_samples, n_classes)
时类的标签。如果未提供,标签将从y_true
推断。版本 1.7 新增。
- scale_by_halfbool 或 “auto”,默认为 “auto”
当为 True 时,将布里尔分数除以 1/2 以使其落在 [0, 1] 范围内,而不是 [0, 2] 范围内。默认的“auto”选项仅在二分类(按惯例)时实现重新缩放到 [0, 1] 范围,但对于多分类仍保持原始的 [0, 2] 范围。
版本 1.7 新增。
- 返回:
- score浮点数
布里尔分数损失。
注释
对于从 \(C\) 个可能类别中标记的 \(N\) 个观测值,布里尔分数定义为
\[\frac{1}{N}\sum_{i=1}^{N}\sum_{c=1}^{C}(y_{ic} - \hat{p}_{ic})^{2}\]其中 \(y_{ic}\) 为 1 表示观测值
i
属于类别c
,否则为 0;\(\hat{p}_{ic}\) 是观测值i
属于类别c
的预测概率。布里尔分数范围在 \([0, 2]\) 之间。在二分类任务中,布里尔分数通常除以二,使其范围在 \([0, 1]\) 之间。它也可以写为
\[\frac{1}{N}\sum_{i=1}^{N}(y_{i} - \hat{p}_{i})^{2}\]其中 \(y_{i}\) 是二元目标,\(\hat{p}_{i}\) 是正类的预测概率。
参考
[1]示例
>>> import numpy as np >>> from sklearn.metrics import brier_score_loss >>> y_true = np.array([0, 1, 1, 0]) >>> y_true_categorical = np.array(["spam", "ham", "ham", "spam"]) >>> y_prob = np.array([0.1, 0.9, 0.8, 0.3]) >>> brier_score_loss(y_true, y_prob) 0.0375 >>> brier_score_loss(y_true, 1-y_prob, pos_label=0) 0.0375 >>> brier_score_loss(y_true_categorical, y_prob, pos_label="ham") 0.0375 >>> brier_score_loss(y_true, np.array(y_prob) > 0.5) 0.0 >>> brier_score_loss(y_true, y_prob, scale_by_half=False) 0.075 >>> brier_score_loss( ... ["eggs", "ham", "spam"], ... [[0.8, 0.1, 0.1], [0.2, 0.7, 0.1], [0.2, 0.2, 0.6]], ... labels=["eggs", "ham", "spam"] ... ) 0.146