图拉索 (graphical_lasso)#
- sklearn.covariance.graphical_lasso(emp_cov, alpha, *, mode='cd', tol=0.0001, enet_tol=0.0001, max_iter=100, verbose=False, return_costs=False, eps=np.float64(2.220446049250313e-16), return_n_iter=False)[source]#
- L1正则化协方差估计器。 - 更多信息请参考 用户指南。 - v0.20版本变更: graph_lasso 已重命名为 graphical_lasso - 参数:
- emp_covshape 为 (n_features, n_features) 的 array-like
- 用于计算协方差估计的经验协方差。 
- alphafloat
- 正则化参数:alpha 越高,正则化越多,逆协方差越稀疏。范围为 (0, inf]。 
- mode{‘cd’, ‘lars’}, 默认值为 ‘cd’
- 使用的 Lasso 求解器:坐标下降或 LARS。对于非常稀疏的潜在图(其中 p > n),使用 LARS。其他情况下,建议使用数值上更稳定的 cd。 
- tolfloat, 默认值为 1e-4
- 声明收敛的容差:如果对偶间隙低于此值,则停止迭代。范围为 (0, inf]。 
- enet_tolfloat, 默认值为 1e-4
- 用于计算下降方向的弹性网络求解器的容差。此参数控制给定列更新的搜索方向的精度,而不是整体参数估计的精度。仅在 mode='cd' 时使用。范围为 (0, inf]。 
- max_iterint, 默认值为 100
- 最大迭代次数。 
- verbosebool, 默认值为 False
- 如果 verbose 为 True,则在每次迭代时打印目标函数和对偶间隙。 
- return_costsbool, 默认值为 False
- 如果 return_costs 为 True,则返回每次迭代的目标函数值和对偶间隙。 
- epsfloat, 默认值为 eps
- 计算 Cholesky 对角因子时的机器精度正则化。对于病态系统,请增加此值。默认为 - np.finfo(np.float64).eps。
- return_n_iterbool, 默认值为 False
- 是否返回迭代次数。 
 
- 返回:
- covarianceshape 为 (n_features, n_features) 的 ndarray
- 估计的协方差矩阵。 
- precisionshape 为 (n_features, n_features) 的 ndarray
- 估计的(稀疏)精度矩阵。 
- costs(目标函数, 对偶间隙) 对的列表
- 每次迭代的目标函数值和对偶间隙的列表。仅当 return_costs 为 True 时返回。 
- n_iterint
- 迭代次数。仅当 - return_n_iter设置为 True 时返回。
 
 - 另请参阅 - GraphicalLasso
- 使用 l1 正则化估计器的稀疏逆协方差估计。 
- GraphicalLassoCV
- 具有交叉验证选择的 l1 惩罚的稀疏逆协方差。 
 - 备注 - 用于解决此问题的算法是来自 Friedman 2008 Biostatistics 论文的 GLasso 算法。它与 R - glasso包中的算法相同。- 与 - glassoR 包的一个可能区别在于,对角系数未被惩罚。- 示例 - >>> import numpy as np >>> from sklearn.datasets import make_sparse_spd_matrix >>> from sklearn.covariance import empirical_covariance, graphical_lasso >>> true_cov = make_sparse_spd_matrix(n_dim=3,random_state=42) >>> rng = np.random.RandomState(42) >>> X = rng.multivariate_normal(mean=np.zeros(3), cov=true_cov, size=3) >>> emp_cov = empirical_covariance(X, assume_centered=True) >>> emp_cov, _ = graphical_lasso(emp_cov, alpha=0.05) >>> emp_cov array([[ 1.68..., 0.21..., -0.20...], [ 0.21..., 0.22..., -0.08...], [-0.20..., -0.08..., 0.23...]]) 
