获取 20 个新闻组数据集 #

sklearn.datasets.fetch_20newsgroups(*, data_home=None, subset='train', categories=None, shuffle=True, random_state=42, remove=(), download_if_missing=True, return_X_y=False, n_retries=3, delay=1.0)[source]#

加载20个新闻组数据集(分类)的文件名和数据。

必要时下载。

类别

20

样本总数

18846

维度

1

特征

文本

用户指南中了解更多信息。

参数:
data_homestr 或 path-like,默认为None

指定数据集的下载和缓存文件夹。如果为None,所有scikit-learn数据都存储在‘~/scikit_learn_data’子文件夹中。

subset{'train', 'test', 'all'},默认为'train'

选择要加载的数据集:'train'表示训练集,'test'表示测试集,'all'表示两者,顺序已打乱。

categoriesarray-like,dtype=str,默认为None

如果为None(默认),则加载所有类别。如果非None,则列出要加载的类别名称(忽略其他类别)。

shufflebool,默认为True

是否打乱数据:对于那些假设样本独立同分布(i.i.d.)的模型(如随机梯度下降)来说可能很重要。

random_stateint,RandomState 实例或 None,默认为42

确定数据集混洗的随机数生成。传递一个整数以在多次函数调用中获得可重复的输出。参见词汇表

removetuple,默认为()

可以包含 (‘headers’, ‘footers’, ‘quotes’) 的任何子集。这些都是将被检测并从新闻组帖子中删除的文本类型,防止分类器过度拟合元数据。

‘headers’ 删除新闻组标题,‘footers’ 删除帖子结尾看起来像签名的块,‘quotes’ 删除看起来像引用另一个帖子的行。

‘headers’ 遵循精确的标准;其他过滤器并不总是正确的。

download_if_missingbool,默认为True

如果为False,如果数据在本地不可用,则引发OSError,而不是尝试从源站点下载数据。

return_X_ybool,默认为False

如果为True,则返回(data.data, data.target)而不是Bunch对象。

在0.22版本中添加。

n_retriesint,默认为3

遇到HTTP错误时的重试次数。

在1.5版本中添加。

delayfloat,默认为1.0

两次重试之间的时间间隔(秒)。

在1.5版本中添加。

返回:
bunchBunch

类似字典的对象,具有以下属性。

data形状为 (n_samples,) 的列表

要学习的数据列表。

target: 形状为 (n_samples,) 的ndarray

目标标签。

filenames: 形状为 (n_samples,) 的列表

数据位置的路径。

DESCR: str

数据集的完整描述。

target_names: 形状为 (n_classes,) 的列表

目标类别的名称。

(data, target)如果 return_X_y=True 则为元组

两个ndarray的元组。第一个包含形状为(n_samples, n_classes)的二维数组,其中每一行代表一个样本,每一列代表特征。第二个形状为(n_samples,)的数组包含目标样本。

在0.22版本中添加。

示例

>>> from sklearn.datasets import fetch_20newsgroups
>>> cats = ['alt.atheism', 'sci.space']
>>> newsgroups_train = fetch_20newsgroups(subset='train', categories=cats)
>>> list(newsgroups_train.target_names)
['alt.atheism', 'sci.space']
>>> newsgroups_train.filenames.shape
(1073,)
>>> newsgroups_train.target.shape
(1073,)
>>> newsgroups_train.target[:10]
array([0, 1, 1, 1, 0, 1, 1, 0, 0, 0])