BallTree#
- class sklearn.neighbors.BallTree#
用于快速解决广义 N 点问题的 BallTree
在用户指南中阅读更多信息。
- 参数:
- Xarray-like of shape (n_samples, n_features)
n_samples 是数据集中点的数量,n_features 是参数空间的维度。注意:如果 X 是一个 C 连续的双精度浮点数数组,则不会复制数据。否则,将进行内部复制。
- leaf_size正整数, default=40
切换到暴力搜索的点数量。更改 leaf_size 不会影响查询结果,但会显著影响查询速度和存储构建树所需的内存。存储树所需的内存量大约是 n_samples / leaf_size。对于指定的
leaf_size
,叶节点保证满足leaf_size <= n_points <= 2 * leaf_size
,除非n_samples < leaf_size
。- metricstr 或 DistanceMetric64 对象, default=’minkowski’
用于距离计算的度量。默认是“minkowski”,当 p = 2 时,这会得到标准的欧氏距离。BallTree 有效度量的列表由属性
valid_metrics
给出。有关任何距离度量的更多信息,请参阅 scipy.spatial.distance 的文档以及distance_metrics
中列出的度量。- 附加关键字被传递给距离度量类。
- 注意:KDTree 和 Ball Tree 不支持 metric 参数中的可调用函数。函数调用开销将导致非常差的性能。
- 和 球树。函数调用开销会导致非常差的性能。
- 属性:
- data内存视图
训练数据
- valid_metrics: 字符串列表
有效距离度量列表。
示例
查询 K 近邻
>>> import numpy as np >>> from sklearn.neighbors import BallTree >>> rng = np.random.RandomState(0) >>> X = rng.random_sample((10, 3)) # 10 points in 3 dimensions >>> tree = BallTree(X, leaf_size=2) >>> dist, ind = tree.query(X[:1], k=3) >>> print(ind) # indices of 3 closest neighbors [0 3 1] >>> print(dist) # distances to 3 closest neighbors [ 0. 0.19662693 0.29473397]
Pickle 和 Unpickle 一棵树。请注意,树的状态保存在 pickle 操作中:unpickling 时无需重建树。
>>> import numpy as np >>> import pickle >>> rng = np.random.RandomState(0) >>> X = rng.random_sample((10, 3)) # 10 points in 3 dimensions >>> tree = BallTree(X, leaf_size=2) >>> s = pickle.dumps(tree) >>> tree_copy = pickle.loads(s) >>> dist, ind = tree_copy.query(X[:1], k=3) >>> print(ind) # indices of 3 closest neighbors [0 3 1] >>> print(dist) # distances to 3 closest neighbors [ 0. 0.19662693 0.29473397]
在给定半径内查询邻居
>>> import numpy as np >>> rng = np.random.RandomState(0) >>> X = rng.random_sample((10, 3)) # 10 points in 3 dimensions >>> tree = BallTree(X, leaf_size=2) >>> print(tree.query_radius(X[:1], r=0.3, count_only=True)) 3 >>> ind = tree.query_radius(X[:1], r=0.3) >>> print(ind) # indices of neighbors within distance 0.3 [3 0 1]
计算高斯核密度估计
>>> import numpy as np >>> rng = np.random.RandomState(42) >>> X = rng.random_sample((100, 3)) >>> tree = BallTree(X) >>> tree.kernel_density(X[:3], h=0.1, kernel='gaussian') array([ 6.94114649, 7.83281226, 7.2071716 ])
计算两点自相关函数
>>> import numpy as np >>> rng = np.random.RandomState(0) >>> X = rng.random_sample((30, 3)) >>> r = np.linspace(0, 1, 5) >>> tree = BallTree(X) >>> tree.two_point_correlation(X, r) array([ 30, 62, 278, 580, 820])
- get_arrays()#
获取数据和节点数组。
- 返回:
- arrays: 数组元组
用于存储树数据、索引、节点数据和节点边界的数组。
- get_n_calls()#
获取调用次数。
- 返回:
- n_calls: 整数
距离计算的调用次数
- get_tree_stats()#
获取树状态。
- 返回:
- tree_stats: 整数元组
(修剪次数,叶子数量,分裂次数)
- kernel_density(X, h, kernel='gaussian', atol=0, rtol=1E-8, breadth_first=True, return_log=False)#
使用创建树时指定的距离度量,在点 X 处计算给定核的核密度估计。
- 参数:
- Xarray-like of shape (n_samples, n_features)
要查询的点数组。最后一维应与训练数据的维度匹配。
- h浮点数
核的带宽
- kernel字符串, default=”gaussian”
指定要使用的核。选项有 - 'gaussian' - 'tophat' - 'epanechnikov' - 'exponential' - 'linear' - 'cosine' 默认是 kernel = 'gaussian'
- atol浮点数, default=0
指定所需结果的绝对容差。如果真实结果是
K_true
,则返回的结果K_ret
满足abs(K_true - K_ret) < atol + rtol * K_ret
默认值为零(即机器精度)。- rtol浮点数, default=1e-8
指定所需结果的相对容差。如果真实结果是
K_true
,则返回的结果K_ret
满足abs(K_true - K_ret) < atol + rtol * K_ret
默认值为1e-8
(即机器精度)。- breadth_first布尔值, default=False
如果为 True,则使用广度优先搜索。如果为 False(默认),则使用深度优先搜索。对于紧凑核和/或高容差,广度优先通常更快。
- return_log布尔值, default=False
返回结果的对数。对于窄核,这可能比直接返回结果更精确。
- 返回:
- density形状为 X.shape[:-1] 的 ndarray
(对数)-密度评估数组
- query(X, k=1, return_distance=True, dualtree=False, breadth_first=False)#
查询树以获取 k 个最近邻
- 参数:
- Xarray-like of shape (n_samples, n_features)
要查询的点数组
- k整数, default=1
要返回的最近邻数量
- return_distance布尔值, default=True
如果为 True,则返回距离和索引的元组 (d, i);如果为 False,则返回数组 i
- dualtree布尔值, default=False
如果为 True,则对查询使用双树形式:为查询点构建一棵树,并使用这对树来高效搜索此空间。当点数量变得很大时,这可以带来更好的性能。
- breadth_first布尔值, default=False
如果为 True,则以广度优先方式查询节点。否则,以深度优先方式查询节点。
- sort_results布尔值, default=True
如果为 True,则返回时每个点的距离和索引会排序,使第一列包含最近的点。否则,邻居以任意顺序返回。
- 返回:
- i如果 return_distance == False
- (d,i)如果 return_distance == True
- d形状为 X.shape[:-1] + (k,) 的 ndarray, dtype=double
每个条目给出对应点的邻居距离列表。
- i形状为 X.shape[:-1] + (k,) 的 ndarray, dtype=int
每个条目给出对应点的邻居索引列表。
- query_radius(X, r, return_distance=False, count_only=False, sort_results=False)#
查询树以获取半径 r 内的邻居
- 参数:
- Xarray-like of shape (n_samples, n_features)
要查询的点数组
- r返回邻居的距离范围
r 可以是单个值,也可以是形状为 x.shape[:-1] 的值数组,如果每个点需要不同的半径。
- return_distance布尔值, default=False
如果为 True,则返回到每个点的邻居距离;如果为 False,则仅返回邻居。请注意,与 query() 方法不同,在此处设置 return_distance=True 会增加计算时间。对于 return_distance=False,并非所有距离都需要显式计算。结果默认不排序:请参阅
sort_results
关键字。- count_only布尔值, default=False
如果为 True,则仅返回距离 r 内的点的数量;如果为 False,则返回距离 r 内所有点的索引。如果 return_distance==True,设置 count_only=True 将导致错误。
- sort_results布尔值, default=False
如果为 True,则在返回之前将对距离和索引进行排序。如果为 False,则结果将不排序。如果 return_distance == False,设置 sort_results = True 将导致错误。
- 返回:
- count如果 count_only == True
- ind如果 count_only == False and return_distance == False
- (ind, dist)如果 count_only == False and return_distance == True
- count形状为 X.shape[:-1] 的 ndarray, dtype=int
每个条目给出对应点距离 r 内的邻居数量。
- ind形状为 X.shape[:-1] 的 ndarray, dtype=object
每个元素是一个 numpy 整数数组,列出对应点邻居的索引。注意,与 k 近邻查询的结果不同,返回的邻居默认不按距离排序。
- dist形状为 X.shape[:-1] 的 ndarray, dtype=object
每个元素是一个 numpy 双精度浮点数数组,列出与索引 i 对应的距离。
- reset_n_calls()#
将调用次数重置为 0。
- two_point_correlation(X, r, dualtree=False)#
计算两点相关函数
- 参数:
- Xarray-like of shape (n_samples, n_features)
要查询的点数组。最后一维应与训练数据的维度匹配。
- rarray-like
一个一维距离数组
- dualtree布尔值, default=False
如果为 True,则使用双树算法。否则,使用单树算法。双树算法对于大型 N 可以有更好的扩展性。
- 返回:
- countsndarray
counts[i] 包含距离小于或等于 r[i] 的点对数量