train_test_split#
- sklearn.model_selection.train_test_split(*arrays, test_size=None, train_size=None, random_state=None, shuffle=True, stratify=None)[source]#
将数组或矩阵拆分为随机训练集和测试子集。
一个快速工具函数,将输入验证、
next(ShuffleSplit().split(X, y))以及对输入数据的应用封装成一个单行调用,用于将数据拆分(并可选地进行子采样)。Read more in the User Guide.
- 参数:
- *arrays具有相同长度/shape[0]的可索引序列
允许的输入包括列表、numpy数组、scipy稀疏矩阵或pandas数据帧。
- test_size浮点数或整数,默认值=None
如果是浮点数,应介于0.0和1.0之间,表示用于测试拆分的数据集比例。如果是整数,表示测试样本的绝对数量。如果为None,则该值设置为训练大小的补集。如果
train_size也为None,则默认设置为0.25。- train_sizefloat or int, default=None
如果是浮点数,应介于0.0和1.0之间,表示用于训练拆分的数据集比例。如果是整数,表示训练样本的绝对数量。如果为None,则该值会自动设置为测试大小的补集。
- random_stateint, RandomState instance or None, default=None
控制在应用拆分之前对数据进行的洗牌。传入一个整数可以在多次函数调用中获得可重现的输出。请参阅词汇表。
- shufflebool, default=True
在拆分之前是否对数据进行洗牌。如果shuffle=False,则stratify必须为None。
- stratifyarray-like,默认值=None
如果不为None,则使用此作为类别标签,以分层方式拆分数据。在用户指南中了解更多信息。
- 返回:
- splitting列表,长度=2 * len(arrays)
包含输入数据的训练集和测试集拆分的列表。
版本0.16中新增: 如果输入是稀疏的,则输出将是
scipy.sparse.csr_matrix。否则,输出类型与输入类型相同。
示例
>>> import numpy as np >>> from sklearn.model_selection import train_test_split >>> X, y = np.arange(10).reshape((5, 2)), range(5) >>> X array([[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]) >>> list(y) [0, 1, 2, 3, 4]
>>> X_train, X_test, y_train, y_test = train_test_split( ... X, y, test_size=0.33, random_state=42) ... >>> X_train array([[4, 5], [0, 1], [6, 7]]) >>> y_train [2, 0, 3] >>> X_test array([[2, 3], [8, 9]]) >>> y_test [1, 4]
>>> train_test_split(y, shuffle=False) [[0, 1, 2], [3, 4]]
>>> from sklearn import datasets >>> iris = datasets.load_iris(as_frame=True) >>> X, y = iris['data'], iris['target'] >>> X.head() sepal length (cm) sepal width (cm) petal length (cm) petal width (cm) 0 5.1 3.5 1.4 0.2 1 4.9 3.0 1.4 0.2 2 4.7 3.2 1.3 0.2 3 4.6 3.1 1.5 0.2 4 5.0 3.6 1.4 0.2 >>> y.head() 0 0 1 0 2 0 3 0 4 0 ...
>>> X_train, X_test, y_train, y_test = train_test_split( ... X, y, test_size=0.33, random_state=42) ... >>> X_train.head() sepal length (cm) sepal width (cm) petal length (cm) petal width (cm) 96 5.7 2.9 4.2 1.3 105 7.6 3.0 6.6 2.1 66 5.6 3.0 4.5 1.5 0 5.1 3.5 1.4 0.2 122 7.7 2.8 6.7 2.0 >>> y_train.head() 96 1 105 2 66 1 0 0 122 2 ... >>> X_test.head() sepal length (cm) sepal width (cm) petal length (cm) petal width (cm) 73 6.1 2.8 4.7 1.2 18 5.7 3.8 1.7 0.3 118 7.7 2.6 6.9 2.3 78 6.0 2.9 4.5 1.5 76 6.8 2.8 4.8 1.4 >>> y_test.head() 73 1 18 0 118 2 78 1 76 1 ...