Abracadabra

kNN and kd-tree

k-近邻算法

工作原理

存在一组带标签的训练集[1],每当有新的不带标签的样本[2]出现时,将训练集中数据的特征与测试集的特征逐个比较,通过某种测度来提取出与测试集最相似的k个训练集样本,然后将这k个样本中占大多数[4]的标签赋予测试集样本。

伪代码

对测试集中的每个点依次执行如下操作:

  1. 计算训练集中的每个点与当前点的距离

  2. 按照距离递增次序排序

  3. 在排序好的点中选取前k个点

  4. 统计出k个点中不同类别的出现频率

  5. 选择频率最高的类别为当前点的预测分类

代码实现

首先创建测试数据集

1
from numpy import *
1
2
3
4
def createDataset():
group = array([[1.0, 1.1], [1.0, 1.0], [0, 0], [0, 0.1]])
labels = ['A', 'A', 'B', 'B']
return group, labels

返回预测分类

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
def classify0(inX, dataSet, labels, k):
dataSetSize = dataSet.shape[0]
diffMat = tile(inX, (dataSetSize, 1)) - dataSet
sqDiffMat = diffMat ** 2
sqDistances = sqDiffMat.sum(axis=1)
distances = sqDistances**0.5
sortedDistIndices = distances.argsort()
classCount = {}
for i in range(k):
voteIlabel = labels[sortedDistIndices[i]]
classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1
# the return of sorted() is a list and its item is a tuple
sortedClassCount = sorted(classCount.items(),
key=operator.itemgetter(1),reverse=True)
return sortedClassCount[0][0] # returns the predict class label

进一步探索

k-近邻算法的缺点在于当数据量很大时,拥有不可接受的空间复杂度以及时间复杂度

其次该算法最关键的地方在与超参k的选取。当k选取的过小时容易造成过拟合,反之容易造成欠拟合。考虑两个极端情况,当k=1时,该算法又叫最近邻算法;当k=N[3]时,表示直接从原始数据中选取占比最大的类别,显然这个算法太naive了。

为了解决kNN算法时间复杂度的问题,最关键的便是在于如何对数据进行快速的k近邻搜索,一种解决方法是引入kd树来进行加速。

kd树

简介

以二维空间为例,假设有6个二维数据点{(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)},可以用下图来表明kd树所能达到的效果:

kd-overview

kd树算法主要分为两个部分:

  1. kd树数据结构的建立
  2. 在kd树上进行查找

kd树是一种对k维空间上的数据点进行存储以便进行高效查找的树形数据结构,属于二叉树。构造kd树相当于不断用垂直于坐标轴的超平面对k维空间进行划分,构成一系列k维超矩形区域。kd树的每一个结点对应一个超矩形区域,表示一个空间范围。

数据结构

下表给出每个结点主要包含的数据结构:

域名数据类型描述
Node-data数据矢量数据集中的某个数据点,k维矢量
Split整数垂直于分割超平面的方向轴序号
Leftkd树由位于该节点分割超平面左子空间内所有数据点构成的kd树
Rightkd树由位于该节点分割超平面右子空间内所有数据点构成的kd树
建立树伪代码

下面给出构建kd树的伪代码:

算法:构建k-d树(createKDTree)
输入:数据点集Data-set
输出:Kd,类型为k-d tree
1. If Data-set为空,则返回空的k-d tree
2. 调用节点生成程序: (1)确定split域:对于所有描述子数据(特征矢量),统计它们在每个维上的数据方差。以SURF特征为例,描述子为64维,可计算64个方差。挑选出最大值,对应的维就是split域的值。数据方差大表明沿该坐标轴方向上的数据分散得比较开,在这个方向上进行数据分割有较好的分辨率; (2)确定Node-data域:数据点集Data-set按其第split域的值排序。位于正中间的那个数据点被选为Node-data。此时新的Data-set’ = Data-set\Node-data(除去其中Node-data这一点)。
3. dataleft = {d属于Data-set’ && d[split] ≤ Node-data[split]} dataright = {d属于Data-set’ && d[split] > Node-data[split]}
4. left = 由(dataleft)建立的k-d tree,即递归调用createKDTree(dataleft)并设置left的parent域为Kd; right = 由(dataright)建立的k-d tree,即调用createKDTree(dataleft)并设置right的parent域为Kd。
实例

用最开始的6个二维数据点的例子,来具体化这个过程:

  1. 确定split域的首先该取的值。分别计算x,y方向上数据的方差得知x方向上的方差最大,所以split域值首先取0,也就是x轴方向;

  2. 确定Node-data的域值。根据x轴方向的值2,5,9,4,8,7排序选出中值为7,所以Node-data = (7, 2)。这样,该节点的分割超平面就是通过(7, 2)并垂直于split = 0(x轴)的直线x = 7;

  3. 确定左子空间和右子空间。分割超平面x = 7将整个空间分为两部分,如下图所示。x < = 7的部分为左子空间,包含3个节点{(2, 3), (5, 4), (4, 7)};另一部分为右子空间,包含2个节点{(9, 6), (8, 1)}。

    kd-construct-step1

如算法所述,k-d树的构建是一个递归的过程。然后对左子空间和右子空间内的数据重复根节点的过程就可以得到下一级子节点(5,4)和(9,6)(也就是左右子空间的’根’节点),同时将空间和数据集进一步细分。如此反复直到空间中只包含一个数据点,如图1所示。最后生成的k-d树如下图所示。

kd-construct-step2

注意:每一级节点旁边的’x’和’y’表示以该节点分割左右子空间时split所取的值。

这里进行一点补充说明,kd树其实就是二叉树,其与普通的二叉查找树不同之处在于,其每一层根据split的维度进行二叉拆分。具体来说,根据上图,第一层的拆分是根据x,那么其左孩子的x值就小于根结点的x值,右孩子则反之。y值则没有规定(这里出现的左大右小只是纯粹的巧合)。第二层是根据y值来进行split,因此第三层的规律显而易见。

代码实现

运行环境:Windows 10 Pro 64-bit x64-based(Ver. 10.0.14393), Python 3.5.2, Anaconda 4.1.1(64-bit), IPython 5.0.0, Windows CMD,

kdTreeCreate.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import numpy as np
from kdTreeNode import *
def createDataSet():
""" Create the test dataset.
Returns:
A numpy array that contains the test data.
"""
dataSet = np.array([[2, 3], [5, 4], [9, 6],
[4, 7], [8, 1], [7, 2]])
return dataSet
def split(dataSet):
""" Split the given dataset.
Returns:
LeftDataSet: A kdTreeNode object.
RightDataSet: A kdTreeNode object.
NodeData: A tuple.
"""
# Ensure the dimension to split
dimenIndex = np.var(dataSet, axis=0).argmax()
partitionDataSet = dataSet[:, dimenIndex]
# print(partitionDataSet)
# Ensure the position to split
partitionDataSetArgSort = partitionDataSet.argsort()
# print(partitionDataSetArgSort)
lenOfPartitionDataSetArgSort = len(partitionDataSetArgSort)
if lenOfPartitionDataSetArgSort % 2 == 0:
posIndex = lenOfPartitionDataSetArgSort // 2
splitIndex = partitionDataSetArgSort[posIndex]
else:
posIndex = (lenOfPartitionDataSetArgSort - 1) // 2
splitIndex = partitionDataSetArgSort[posIndex]
# print(splitIndex)
# Split
nodeData = dataSet[splitIndex]
leftIndeies = partitionDataSetArgSort[:posIndex]
rightIndeies = partitionDataSetArgSort[posIndex + 1:]
leftDataSet = dataSet[leftIndeies]
rightDataSet = dataSet[rightIndeies]
return nodeData, dimenIndex, leftDataSet, rightDataSet
def createKDTree(dataSet):
""" Create the KD tree.
Returns:
A kdTreeNode object.
"""
if len(dataSet) == 0:
return
nodeData, dimenIndex, leftDataSet, rightDataSet = split(dataSet)
# print(nodeData, dimenIndex, leftDataSet, rightDataSet)
node = kdTreeNode(nodeData, dimenIndex)
node.setLeft(createKDTree(leftDataSet))
node.setRight(createKDTree(rightDataSet))
return node
def midTravel(node):
if node is None:
return
midTravel(node.getLeft())
print(node.getData())
midTravel(node.getRight())
if __name__ == "__main__":
dataSet = createDataSet()
node = createKDTree(dataSet)
midTravel(node)

kdTreeNode.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
class kdTreeNode(object):
""" Class of k-d tree nodes
"""
def __init__(self, data=None, split=None, left=None, right=None):
self.__data = data
self.__split = split
self.__left = left
self.__right = right
def getData(self):
return self.__data
def setData(self, data):
self.__data = data
def getSplit(self):
return self.__split
def setSplit(self, split):
self.__split = split
def getLeft(self):
return self.__left
def setLeft(self, left):
self.__left = left
def getRight(self):
return self.__right
def setRight(self, right):
self.__right = right

运行结果:

1
2
3
4
5
6
7
8
In [1]: run kdTreeCreate.py
-------------------------------------
[2 3]
[5 4]
[4 7]
[7 2]
[8 1]
[9 6]

时间复杂度:

N个K维数据进行查找操作时时间复杂度为 $t=O(KN^{2})$

下面就要在已经建立好的kd树上进行查找操作。

查找

kd树中进行的查找与普通的查找操作存在较大的差异,其目的是为了找出与查询点距离最近的点。

星号表示要查询的点(2.1, 3.1)。通过二叉搜索,顺着搜索路径很快就能找到最邻近的近似点,也就是叶子节点(2, 3)。而找到的叶子节点并不一定就是最邻近的,最邻近肯定距离查询点更近,应该位于以查询点为圆心且通过叶子节点的圆域内。为了找到真正的最近邻,还需要进行’回溯’操作:算法沿搜索路径反向查找是否有距离查询点更近的数据点。此例中先从(7, 2)点开始进行二叉查找,然后到达(5, 4),最后到达(2, 3),此时搜索路径中的节点为<(7, 2), (5, 4), (2, 3)>,首先以(2, 3)作为当前最近邻点,计算其到查询点(2.1, 3.1)的距离为0.1414,然后回溯到其父节点(5, 4),并判断在该父节点的其他子节点空间中是否有距离查询点更近的数据点。以(2.1, 3.1)为圆心,以0.1414为半径画圆,如图4所示。发现该圆并不和超平面y = 4交割,因此不用进入(5, 4)节点右子空间中去搜索。

再回溯到(7, 2),以(2.1, 3.1)为圆心,以0.1414为半径的圆更不会与x = 7超平面交割,因此不用进入(7, 2)右子空间进行查找。至此,搜索路径中的节点已经全部回溯完,结束整个搜索,返回最近邻点(2, 3),最近距离为0.1414。

kd-tree-search-1

一个复杂点了例子如查找点为(2, 4.5)。同样先进行二叉查找,先从(7, 2)查找到(5, 4)节点,在进行查找时是由y = 4为分割超平面的,由于查找点为y值为4.5,因此进入右子空间查找到(4, 7),形成搜索路径<(7, 2), (5, 4), (4, 7)>,取(4, 7)为当前最近邻点,计算其与目标查找点的距离为3.202。然后回溯到(5, 4),计算其与查找点之间的距离为3.041。以(2, 4.5)为圆心,以3.041为半径作圆。

kd-tree-search-2

可见该圆和y = 4超平面交割,所以需要进入(5, 4)左子空间进行查找。此时需将(2, 3)节点加入搜索路径中得<(7, 2), (2, 3)>。回溯至(2, 3)叶子节点,(2, 3)距离(2, 4.5)比(5, 4)要近,所以最近邻点更新为(2, 3),最近距离更新为1.5。回溯至(7, 2),以(2, 4.5)为圆心1.5为半径作圆,并不和x = 7分割超平面交割。至此,搜索路径回溯完。返回最近邻点(2, 3),最近距离1.5。

kd-tree-search-2

k-d树查询算法的伪代码如下所示。

查找伪代码
算法: k-d树最邻近查找
输入:Kd, //k-d tree类型
target //查询数据点
输出:nearest, //最邻近数据点
dist //最邻近数据点和查询点间的距离
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
1. If Kd为NULL,则设dist为infinite并返回
2. //进行二叉查找,生成搜索路径
Kd_point = &Kd; //Kd-point中保存k-d tree根节点地址
nearest = Kd_point -> Node-data; //初始化最近邻点
while(Kd_point)
  push(Kd_point)到search_path中; //search_path是一个堆栈结构,存储着搜索路径节点指针
/*** If Dist(nearest,target) > Dist(Kd_point -> Node-data,target)
    nearest = Kd_point -> Node-data; //更新最近邻点
    Max_dist = Dist(Kd_point,target); //更新最近邻点与查询点间的距离 ***/
  s = Kd_point -> split; //确定待分割的方向
  If target[s] <= Kd_point -> Node-data[s] //进行二叉查找
    Kd_point = Kd_point -> left;
  else
    Kd_point = Kd_point ->right;
nearest = search_path中最后一个叶子节点; //注意:二叉搜索时不比计算选择搜索路径中的最邻近点,这部分已被注释
Max_dist = Dist(nearest,target); //直接取最后叶子节点作为回溯前的初始最近邻点
3. //回溯查找
while(search_path != NULL)
  back_point = 从search_path取出一个节点指针; //从search_path堆栈弹栈
  s = back_point -> split; //确定分割方向
  If Dist(target[s],back_point -> Node-data[s]) < Max_dist //判断还需进入的子空间
    If target[s] <= back_point -> Node-data[s]
      Kd_point = back_point -> right; //如果target位于左子空间,就应进入右子空间
    else
      Kd_point = back_point -> left; //如果target位于右子空间,就应进入左子空间
    将Kd_point压入search_path堆栈;
  If Dist(nearest,target) > Dist(Kd_Point -> Node-data,target)
    nearest = Kd_point -> Node-data; //更新最近邻点
    Min_dist = Dist(Kd_point -> Node-data,target); //更新最近邻点与查询点间的距离
代码实现

kdTreeSearch.py

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import numpy as np
def cal_dist(node, target):
""" Calculate the distance between the node
and the target.
Arguments:
node: The kd-tree's one node.
target: Search target.
Returns:
dist: The distance between the two nodes.
"""
node_data = np.array(node)
target_data = np.array(target)
square_dist_vector = (node_data - target_data) ** 2
square_dist = np.sum(square_dist_vector)
dist = square_dist ** 0.5
return dist
def search(root_node, target):
""" Search the nearest node of the target node
in the kd-tree that root node is the root_node
Arguments:
root_node: The kd-tree's root node.
target: Search target.
Returns:
nearest: The nearest node of the target node in the kd-tree.
min_dist: The nearest distance.
"""
if root_node is None:
min_dist = float('inf')
return min_dist
# Two-fork search
kd_point = root_node # Save the root node
nearest = kd_point.getData() # Initial the nearest node
search_path = [] # Initial the search stack
while kd_point:
search_path.append(kd_point)
split_index = kd_point.getSplit() # Ensure the split path
if target[split_index] <= kd_point.getData()[split_index]:
kd_point = kd_point.getLeft()
else:
kd_point = kd_point.getRight()
nearest = search_path.pop().getData()
min_dist = cal_dist(nearest, target)
# Retrospect search
while search_path:
back_point = search_path.pop()
# Ensure the back-split path
back_split_index = back_point.getSplit()
# Judge if needs to enter the subspace
if cal_dist(target[back_split_index],
back_point.getData()[back_split_index]) < min_dist:
# If the target is in the left subspace, then enter the right
if target[back_split_index] <= back_point.getData()[back_split_index]:
kd_point = back_point.getRight()
# Otherwise enter the left
else:
kd_point = back_point.getLeft()
# Add the node to the search path
if kd_point is not None:
search_path.append(kd_point)
if cal_dist(nearest, target) > cal_dist(kd_point.getData(), target):
# Update the nearest node
nearest = kd_point.getData()
# Update the maximum distance
min_dist = cal_dist(kd_point.getData(), target)
return nearest, min_dist

运行结果:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
In [1]: run kdTreeCreate.py
-------------------------------------
[2 3]
[5 4]
[4 7]
[7 2]
[8 1]
[9 6]
In [2]: node
-------------------------------------
Out [2]: <kdTreeNode.kdTreeNode at 0x26bff22f160>
In [3]: import kdTreeSearch
In [4]: nearest, min_dist = kdTreeSearch.search(node, [2.1, 3.1])
In [5]: nearest
-------------------------------------
Out [5]: array([2, 3])
In [6]: min_dist
-------------------------------------
Out [6]: 0.14142135623730964
In [7]: nearest, min_dist = kdTreeSearch.search(node, [2, 4.5])
In [8]: nearest
-------------------------------------
Out [8]: array([2, 3])
In [9]: min_dist
-------------------------------------
Out [9]: 1.5

时间复杂度:

N个结点的K维kd树进行查找操作时最坏时间复杂度为 $t_{worst}=O(KN^{1-1/k})$

根据相关研究,当数据维度为K时,只有当数据量N满足 $N>>2^K$ 时,才能达到高效的搜索(K<20,超过20维时可采用ball-tree算法),所以引出了一系列的改进算法(BBF算法,和一系列M树、VP树、MVP树等高维空间索引树),留待后续补充。

采用kd树的k-近邻算法

接下来便是将两者相结合。

[1] 说是训练集其实是不准确的,因为k-近邻算法是一个无参数方法,只存在一个超参k,因此其不存在一个训练的过程

[2] 测试集

[3] N代表训练集的数目

[4] 多数表决