维护两个语言的代码不太方便,可能会有语义不一样的地方。如果出现语义不同的地方,以Kotlin的实现为准。

  在之前的博文中我们讨论了二叉搜索树平衡二叉搜索树的理论实现,这次让我们用C++Kotlin来实现一个弱平衡树。

注意:以下代码实现的弱平衡树不是任何网上的弱平衡树(比如:红黑树、Splay等),是我完全按照前面的理论写的


  2022-04-22更新:

  突然发现判断平衡没必要比较左右子树数量,直接判断左右子树高度就行了,左右子树高度一样的时候旋不旋转对树的高度的影响都一样。


结构设计

  首先,我们肯定需要先设计一下代码结构。

  1. 需要一个类表示
  2. 需要一个类(结构体)表示节点

  这么看来,我们至少需要两个类:

1
2
3
4
5
class WeakBalanceTree<T : Comparable<T>> {

class Node<T : Comparable<T>>: Comparable<T>

}
1
2
3
4
5
6
7
8
9
10
11
12
template<class T>
class WeakBalanceTree {

//先声明一下是因为我们要把Node的实现放在后面
//不先声明的话类中就无法看到这个结构体
struct Node;

struct Node {

}

}

Node设计

  现在,我们再来看Node中需要记录些什么信息:

  1. 节点的value
  2. 与当前节点相连的其它三个节点
  3. 节点深度

  那么,我们很容易就能写出下面的实现:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
class Node<T : Comparable<T>>(value: T) : Comparable<T> {

var value: T = value
internal set
var deep = 1
internal set
var left: Node<T>? = null
internal set
var right: Node<T>? = null
internal set
var father: Node<T>? = null
internal set

}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
struct Node {

/** 节点的值 */
T value;
/** 深度 */
int deep = 1;
/** 左子树 */
Node* left = nullptr;
/** 右子树 */
Node* right = nullptr;
/** 父亲节点 */
Node* father = nullptr;

explicit Node(const T& v): value(v) {}
}

  那么,一个节点需要干什么活呢?我们先假设它需要负责这些功能:

  1. 检查以该节点为根的树是否平衡
  2. 判断以该节点为根的树的重心
  3. 旋转树
  4. 维护树的平衡
  5. 维护树的深度
  6. 查找左子树最大值(或右子树最小值)

  不过其中的34都涉及到了树根节点的变换(这两个过程可能会修改树的根节点),放在结构体中不太合适,所以将其放在WeakBalanceTree中,剩余的放入结构体中。

  现在我们遇到了一个问题:如何快速判断树的重心?我们不可能每次判断重心的时候都把子树遍历一遍,这样子效率太低了,实际上我们直接比较左右子树的高度即可。

  同时,判断重心时有三种可能:偏左、偏右以及中心。为了用一个变量清晰地表示所有可能,我们选择再声明一个枚举类:

1
2
3
enum class BalanceDirection {
LEFT, RIGHT, CENTER
}
1
2
3
enum BalanceDirection {
LEFT, RIGHT, CENTER
};

  代码写起来就很简单了:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
class Node<T : Comparable<T>>(value: T) : Comparable<T> {

var value: T = value
internal set
var deep = 1
internal set
var left: Node<T>? = null
internal set
var right: Node<T>? = null
internal set
var father: Node<T>? = null
internal set

override fun compareTo(other: T) = value.compareTo(other)

fun hasLeft() = left != null
fun hasRight() = right != null
fun notLeft() = !hasLeft()
fun notRight() = !hasRight()

fun getLeftDeep() = if (notLeft()) 0 else left!!.deep
fun getRightDeep() = if (notRight()) 0 else right!!.deep

/**
* 把当前节点的其中一个子树替换为指定子树
* @param src 要被替换的子树对象
* @param dist 目的子树
*/

internal fun swapSonTree(src: Node<T>, dist: Node<T>?) {
if (src == left) left = dist
else right = dist
}

/** 交换两个节点中的值 */
internal fun swapNode(that: Node<T>) {
value = that.value.apply { that.value = value }
}

/** 检查以该节点为根的树是否平衡 */
fun checkBalance() = abs(getLeftDeep() - getRightDeep()) < 2

/** 判断平衡向哪边偏移 */
fun getBalance(): BalanceDirection {
val dif = getLeftDeep() - getRightDeep()
if (dif < 0) return RIGHT
if (dif == 0) return CENTER
return LEFT
}

/** 获取左子树中的最大值,如果不存在左子树则返回节点本身 */
internal fun findLeftMax(): Node<T> {
if (notLeft()) return this
var result = left
while (result!!.hasRight()) {
result = result.right
}
return result
}

/** 从指定节点开始向上维护节点深度 */
internal fun repairDeep() {
var point: Node<T>? = this
while (point != null) {
point = point.apply {
val newDeep = max(getLeftDeep(), getRightDeep()) + 1
if (newDeep == deep) return
deep = newDeep
father
}
}
}

}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
struct Node {

/** 节点的值 */
T value;
/** 深度 */
int deep = 1;
/** 左子树 */
Node* left = nullptr;
/** 右子树 */
Node* right = nullptr;
/** 父亲节点 */
Node* father = nullptr;

explicit Node(const T& v): value(v) {}

inline bool operator<(const T& that) const {
return value < that;
}

//判断节点是否含有左子树
inline bool hasLeft() const {
return left != nullptr;
}

//判断节点是否含有右子树
inline bool hasRight() const {
return right != nullptr;
}

//判断节点是否不含有左子树
inline bool notLeft() const {
return !hasLeft();
}

//判断节点是否不含有右子树
inline bool notRight() const {
return !hasRight();
}

//检查以该节点为根的树是否平衡
inline bool checkBalance() const {
return abs(getLeftDeep() - getRightDeep()) < 2;
}

/**
* 把当前节点的其中一个子树替换为指定子树
* @param src 要被替换的子树对象
* @param dist 目的子树
*/

inline void swapSonTree(const Node* src, Node* dist) {
if (src == left) left = dist;
else right = dist;
}

/** 交换两个节点中的值 */
inline void swapNode(Node* that) {
if (that == this) return;
swap(value, that->value);
}

/** 判断平衡向哪边偏移 */
BalanceDirection getBalance() const {
int dif = getLeftDeep() - getRightDeep();
if (dif < 0) return RIGHT;
if (dif == 0) return CENTER;
return LEFT;
}

/** 从指定节点开始向上维护节点深度 */
void repairDeep() {
Node* point = this;
do {
int newDeep = max(point->getLeftDeep(), point->getRightDeep()) + 1;
if (newDeep == point->deep) break;
point->deep = newDeep;
point = point->father;
} while (point != nullptr);
}

/** 获取指定节点的左子树中的最大值 */
Node* findLeftMax() {
if (notLeft()) return this;
Node* result = left;
while (result->hasRight()) {
result = result->right;
}
return result;
}

/** 获取左子树树高,左子树不存在返回0 */
inline int getLeftDeep() const {
return notLeft() ? 0 : left->deep;
}

/** 获取右子树树高,右子树不存在返回0 */
inline int getRightDeep() const {
return notRight() ? 0 : right->deep;
}

};

工具函数

  在实现树的插入/移除功能之前我们要实现一些需要用到的工具函数:

  1. 维护平衡的函数
  2. 查找指定元素
  3. 旋转

前置

  前两个函数实现起来非常简单,我们直接给出代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
class WeakBalanceTree<T : Comparable<T>> {

/** 检查指定元素在树中是否存在 */
fun contain(value: T) = find(value) != null

/**
* 在树中查找指定元素
* @return 元素在树中的数据,不存在则返回nullptr
*/

fun find(value: T): Node<T>? {
var it = root
while (it != null) {
val cmp = it.compareTo(value)
if (cmp < 0) it = it.right
else if (cmp != 0) it = it.left
else return it
}
return null
}

/** 以指定节点为根维护树的平衡 */
private fun repairBalance(start: Node<T>) {
if (!start.checkBalance()) {
if (start.getBalance() == LEFT) rotateRight(start)
else rotateLeft(start)
}
}

}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46

template<class T>
class WeakBalanceTree {

public:

/**
* 在树中查找指定元素
* @return 元素在树中的数据,不存在则返回nullptr
*/

const Node* find(const T& value) const {
auto it = root;
while (it != nullptr) {
if (*it < value) it = it->right;
else if (value < *it) it = it->left;
else return it;
}
return nullptr;
}

/** 判断指定元素在树中是否存在 */
inline const bool contain(const T& value) const {
return find(value) != nullptr;
}

private:

/**
* <p>在树中查找指定元素
* <p>写两个一样的函数是因为一个是<code>const</code>一个不是
* @return 元素在树中的数据,不存在则返回nullptr
*/

inline Node* _find(const T& value) {
return const_cast<Node*>(find(value));
}

/** 以某节点为根维护平衡 */
void repairBalance(Node* start) {
if (!start->checkBalance()) {
if (start->getBalance() == LEFT) rotateRight(start);
else rotateLeft(start);
}
}

}

旋转

  现在来到了重点,旋转树地代码应该怎么写?如果你还记得我们前面说的旋转的规律,那么实现起来就很简单:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
class WeakBalanceTree<T : Comparable<T>> {

/**
* 以指定节点为根进行左旋
* @throws IllegalArgumentException 如果右子树不存在
*/

private fun rotateLeft(point: Node<T>) {
if (point.notRight()) throw IllegalArgumentException("右子树不存在")
val srcRight = point.right!!
val srcLeftDeep = point.getLeftDeep()
//如果右子树平衡偏左,先右旋
if (srcRight.getBalance() == LEFT) rotateRight(srcRight)
with (point) {
//更新父亲节点的信息
if (father != null) father!!.swapSonTree(this, srcRight)
//将右子树的左子树设置为当前节点的左子树
right = srcRight.left
//将右子树的左子树设置为根
srcRight.left = this
//更新两个节点的父亲节点
srcRight.father = father
father = srcRight
//维护深度
deep = srcLeftDeep + 1
srcRight.repairDeep()
//必要时修改树的根节点
if (this == root) root = srcRight
}
}

/**
* 以指定节点为根进行右旋
* @throws IllegalArgumentException 如果左子树不存在
*/

private fun rotateRight(point: Node<T>) {
if (point.notLeft()) throw IllegalArgumentException("左子树不存在")
val srcLeft = point.left!!
val srcRightDeep = point.getRightDeep()
if (srcLeft.getBalance() == RIGHT) rotateLeft(srcLeft)
with (point) {
if (father != null) father!!.swapSonTree(this, srcLeft)
left = srcLeft.right
srcLeft.right = this
srcLeft.father = father
father = srcLeft
deep = srcRightDeep + 1
srcLeft.repairDeep()
if (this == root) root = srcLeft
}
}

}

KotlinC++里面的注释不太一样,Kotlin里面的注释更全一点

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
template<class T>
class WeakBalanceTree {

private:

/**
* 以指定节点为根进行左旋
* @throws invalid_argument 如果右子树不存在
*/

void rotateLeft(Node* point) {
if (point->notRight()) throw invalid_argument("右子树不存在");
auto srcRight = point->right;
int srcLeftDeep = point->getLeftDeep();
//如果右子树平衡偏左,先右旋
if (point->right->getBalance() == LEFT) rotateRight(point->right);
//更新父亲节点的信息
if (point->father != nullptr) point->father->swapSonTree(point, srcRight);
//将右子树的左子树设置为右子树
point->right = srcRight->left;
//将右子树的左子树设置为根节点
srcRight->left= point;
//更新父亲节点信息
srcRight->father = point->father;
point->father = srcRight;
//维护深度
point->deep = srcLeftDeep + 1;
srcRight->repairDeep();
if (point == root) root = srcRight;
}

/**
* 以指定节点为根进行右旋
* @throws invalid_argument 如果左子树不存在
*/

void rotateRight(Node* point) {
if (point->notLeft()) throw invalid_argument("左子树不存在");
auto srcLeft = point->left;
int srcRightDeep = point->getRightDeep();
//如果左子树平衡偏右,先左旋
if (point->left->getBalance() == RIGHT) rotateLeft(point->left);
if (point->father != nullptr) point->father->swapSonTree(point, srcLeft);
point->left = srcLeft->right;
srcLeft->right = point;
srcLeft->father = point->father;
point->father = srcLeft;
point->deep = srcRightDeep + 1;
srcLeft->repairDeep();
if (point == root) root = srcLeft;
}

}

编辑树

  有了上面的基础,那么我们实现插入和移除就很简单了:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
class WeakBalanceTree<T : Comparable<T>> {

/**
* 向树中插入指定元素
*
* @return 插入的元素再树中的节点
*/

fun insert(value: T): Node<T> {
if (root == null) {
root = Node(value)
return root!!
}
var it = root
val result: Node<T>
while (true) {
val cmp = it!!.compareTo(value)
if (cmp < 0) {
if (it.right == null) {
result = Node(value)
it.right = result
break
} else it = it.right
} else if (cmp != 0) {
if (it.left == null) {
result = Node(value)
it.left = result
break
} else it = it.left
} else return it
}
result.father = it
result.repairDeep()
if (it!!.father != null) repairBalance(it.father!!)
return result
}

/**
* 从树中删除指定节点
*
* 注意:**函数内会修改传入的point的内容**
*/

fun erase(point: Node<T>) {
val max = point.findLeftMax()
point.swapNode(max)
max.father!!.swapSonTree(max, null)
max.father!!.repairDeep()
repairBalance(max.father!!)
}

/** 从树中删除指定节点 */
fun erase(value: T): Boolean {
val point = find(value) ?: return false
erase(point)
return true
}

}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
template<class T>
class WeakBalanceTree {

/**
* 向树中插入指定元素
* @return 插入的元素在树中的节点
*/

const Node* insert(const T& value) {
if (root == nullptr) return root = new Node(value);
auto it = root;
Node* result;
while (true) {
if (value < *it) {
if (it->left == nullptr) {
result = it->left = new Node(value);
break;
} else it = it->left;
} else if (*it < value) {
if (it->right == nullptr) {
result = it->right = new Node(value);
break;
} else it = it->right;
} else return it;
}
result->father = it;
result->repairDeep();
if (it->father != nullptr) repairBalance(it->father);
return result;
}

/**
* <p>从树中移除指定节点
* <p>注意:<b>函数内会修改传入的指针</b>
*/

void erase(Node* point) {
auto max = point->findLeftMax();
point->swapNode(max);
max->father->swapSonTree(max, nullptr);
max->father->repairDeep();
repairBalance(max->father);
delete max;
}

/** 从树中移除指定节点 */
inline bool erase(const T& value) {
auto point = _find(value);
if (point == nullptr) return false;
erase(point);
return true;
}

}

  至此,我们就成功的写出了一个弱平衡树,我们这里再贴出完整的代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
import WeakBalanceTree.BalanceDirection.*
import kotlin.math.abs
import kotlin.math.max

/**
* 弱平衡树
* @author EmptyDreams
*/

class WeakBalanceTree<T : Comparable<T>> {

/** 根节点 */
private var root: Node<T>? = null

/**
* 向树中插入指定元素
*
* @return 插入的元素再树中的节点
*/

fun insert(value: T): Node<T> {
if (root == null) {
root = Node(value)
return root!!
}
var it = root
val result: Node<T>
while (true) {
val cmp = it!!.compareTo(value)
if (cmp < 0) {
if (it.right == null) {
result = Node(value)
it.right = result
break
} else it = it.right
} else if (cmp != 0) {
if (it.left == null) {
result = Node(value)
it.left = result
break
} else it = it.left
} else return it
}
result.father = it
result.repairDeep()
if (it!!.father != null) repairBalance(it.father!!)
return result
}

/** 检查指定元素在树中是否存在 */
fun contain(value: T) = find(value) != null

/**
* 在树中查找指定元素
* @return 元素在树中的数据,不存在则返回nullptr
*/

fun find(value: T): Node<T>? {
var it = root;
while (it != null) {
val cmp = it.compareTo(value)
if (cmp < 0) it = it.right
else if (cmp != 0) it = it.left
else return it
}
return null
}

/**
* 从树中删除指定节点
*
* 注意:**函数内会修改传入的point的内容**
*/

fun erase(point: Node<T>) {
val max = point.findLeftMax()
point.swapNode(max)
max.father!!.swapSonTree(max, null)
max.father!!.repairDeep()
repairBalance(max.father!!)
}

/** 从树中删除指定节点 */
fun erase(value: T): Boolean {
val point = find(value) ?: return false
erase(point)
return true
}

/** 以指定节点为根维护树的平衡 */
private fun repairBalance(start: Node<T>) {
if (!start.checkBalance()) {
if (start.getBalance() == LEFT) rotateRight(start)
else rotateLeft(start)
}
}

/**
* 以指定节点为根进行左旋
* @throws IllegalArgumentException 如果右子树不存在
*/

private fun rotateLeft(point: Node<T>) {
if (point.notRight()) throw IllegalArgumentException("右子树不存在")
val srcRight = point.right!!
val srcLeftDeep = point.getLeftDeep()
//如果右子树平衡偏左,先右旋
if (srcRight.getBalance() == LEFT) rotateRight(srcRight)
with (point) {
//更新父亲节点的信息
if (father != null) father!!.swapSonTree(this, srcRight)
//将右子树的左子树设置为当前节点的左子树
right = srcRight.left
//将右子树的左子树设置为根
srcRight.left = this
//更新两个节点的父亲节点
srcRight.father = father
father = srcRight
//维护深度
deep = srcLeftDeep + 1
srcRight.repairDeep()
//必要时修改树的根节点
if (this == root) root = srcRight
}
}

/**
* 以指定节点为根进行右旋
* @throws IllegalArgumentException 如果左子树不存在
*/

private fun rotateRight(point: Node<T>) {
if (point.notLeft()) throw IllegalArgumentException("左子树不存在")
val srcLeft = point.left!!
val srcRightDeep = point.getRightDeep()
if (srcLeft.getBalance() == RIGHT) rotateLeft(srcLeft)
with (point) {
if (father != null) father!!.swapSonTree(this, srcLeft)
left = srcLeft.right
srcLeft.right = this
srcLeft.father = father
father = srcLeft
deep = srcRightDeep + 1
srcLeft.repairDeep()
if (this == root) root = srcLeft
}
}

/** 树节点 */
class Node<T : Comparable<T>>(value: T) : Comparable<T> {

var value: T = value
internal set
var deep = 1
internal set
var left: Node<T>? = null
internal set
var right: Node<T>? = null
internal set
var father: Node<T>? = null
internal set

override fun compareTo(other: T) = value.compareTo(other)

fun hasLeft() = left != null
fun hasRight() = right != null
fun notLeft() = !hasLeft()
fun notRight() = !hasRight()

fun getLeftDeep() = if (notLeft()) 0 else left!!.deep
fun getRightDeep() = if (notRight()) 0 else right!!.deep

/**
* 把当前节点的其中一个子树替换为指定子树
* @param src 要被替换的子树对象
* @param dist 目的子树
*/

internal fun swapSonTree(src: Node<T>, dist: Node<T>?) {
if (src == left) left = dist
else right = dist
}

/** 交换两个节点中的值 */
internal fun swapNode(that: Node<T>) {
value = that.value.apply { that.value = value }
}

/** 检查以该节点为根的树是否平衡 */
fun checkBalance() = abs(getLeftDeep() - getRightDeep()) < 2

/** 判断平衡向哪边偏移 */
fun getBalance(): BalanceDirection {
val dif = getLeftDeep() - getRightDeep()
if (dif < 0) return RIGHT
if (dif == 0) return CENTER
return LEFT
}

/** 获取左子树中的最大值,如果不存在左子树则返回节点本身 */
internal fun findLeftMax(): Node<T> {
if (notLeft()) return this
var result = left
while (result!!.hasRight()) {
result = result.right
}
return result
}

/** 从指定节点开始向上维护节点深度 */
internal fun repairDeep() {
var point: Node<T>? = this
while (point != null) {
point = point.apply {
val newDeep = max(getLeftDeep(), getRightDeep()) + 1
if (newDeep == deep) return
deep = newDeep
father
}
}
}

}

enum class BalanceDirection {
LEFT, RIGHT, CENTER
}

}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
#include <regex>
#include <exception>

using std::swap;
using std::max;
using std::invalid_argument;

enum BalanceDirection { LEFT, RIGHT, CENTER };

template<class T>
class WeakBalanceTree {

public:

struct Node;

/**
* 向树中插入指定元素
* @return 插入的元素在树中的节点
*/

const Node* insert(const T& value) {
if (root == nullptr) return root = new Node(value);
auto it = root;
Node* result;
while (true) {
if (value < *it) {
if (it->left == nullptr) {
result = it->left = new Node(value);
break;
} else it = it->left;
} else if (*it < value) {
if (it->right == nullptr) {
result = it->right = new Node(value);
break;
} else it = it->right;
} else return it;
}
result->father = it;
result->repairDeep();
if (it->father != nullptr) repairBalance(it->father);
return result;
}

/**
* 在树中查找指定元素
* @return 元素在树中的数据,不存在则返回nullptr
*/

const Node* find(const T& value) const {
auto it = root;
while (it != nullptr) {
if (*it < value) it = it->right;
else if (value < *it) it = it->left;
else return it;
}
return nullptr;
}

/** 判断指定元素在树中是否存在 */
inline const bool contain(const T& value) const {
return find(value) != nullptr;
}

/**
* <p>从树中移除指定节点
* <p>注意:<b>函数内会修改传入的指针</b>
*/

void erase(Node* point) {
auto max = point->findLeftMax();
point->swapNode(max);
max->father->swapSonTree(max, nullptr);
max->father->repairDeep();
repairBalance(max->father);
delete max;
}

/** 从树中移除指定节点 */
inline bool erase(const T& value) {
auto point = _find(value);
if (point == nullptr) return false;
erase(point);
return true;
}

struct Node {

/** 节点的值 */
T value;
/** 深度 */
int deep = 1;
/** 左子树 */
Node* left = nullptr;
/** 右子树 */
Node* right = nullptr;
/** 父亲节点 */
Node* father = nullptr;

explicit Node(const T& v): value(v) {}

inline bool operator<(const T& that) const {
return value < that;
}

//判断节点是否含有左子树
inline bool hasLeft() const {
return left != nullptr;
}

//判断节点是否含有右子树
inline bool hasRight() const {
return right != nullptr;
}

//判断节点是否不含有左子树
inline bool notLeft() const {
return !hasLeft();
}

//判断节点是否不含有右子树
inline bool notRight() const {
return !hasRight();
}

//检查以该节点为根的树是否平衡
inline bool checkBalance() const {
return abs(getLeftDeep() - getRightDeep()) < 2;
}

/**
* 把当前节点的其中一个子树替换为指定子树
* @param src 要被替换的子树对象
* @param dist 目的子树
*/

inline void swapSonTree(const Node* src, Node* dist) {
if (src == left) left = dist;
else right = dist;
}

/** 交换两个节点中的值 */
inline void swapNode(Node* that) {
if (that == this) return;
swap(value, that->value);
}

/** 判断平衡向哪边偏移 */
BalanceDirection getBalance() const {
int dif = getLeftDeep() - getRightDeep();
if (dif < 0) return RIGHT;
if (dif == 0) return CENTER;
return LEFT;
}

/** 从指定节点开始向上维护节点深度 */
void repairDeep() {
Node* point = this;
do {
int newDeep = max(point->getLeftDeep(), point->getRightDeep()) + 1;
if (newDeep == point->deep) break;
point->deep = newDeep;
point = point->father;
} while (point != nullptr);
}

/** 获取指定节点的左子树中的最大值 */
Node* findLeftMax() {
if (notLeft()) return this;
Node* result = left;
while (result->hasRight()) {
result = result->right;
}
return result;
}

/** 获取左子树树高,左子树不存在返回0 */
inline int getLeftDeep() const {
return notLeft() ? 0 : left->deep;
}

/** 获取右子树树高,右子树不存在返回0 */
inline int getRightDeep() const {
return notRight() ? 0 : right->deep;
}

};

private:

Node* root = nullptr;

/**
* 在树中查找指定元素
* @return 元素在树中的数据,不存在则返回nullptr
*/

Node* _find(const T& value) {
return const_cast<Node*>(find(value));
}

/** 以某节点为根维护平衡 */
void repairBalance(Node* start) {
if (!start->checkBalance()) {
if (start->getBalance() == LEFT) rotateRight(start);
else rotateLeft(start);
}
}

/**
* 以指定节点为根进行左旋
* @throws invalid_argument 如果右子树不存在
*/

void rotateLeft(Node* point) {
if (point->notRight()) throw invalid_argument("右子树不存在");
auto srcRight = point->right;
int srcLeftDeep = point->getLeftDeep();
//如果右子树平衡偏左,先右旋
if (point->right->getBalance() == LEFT) rotateRight(point->right);
//更新父亲节点的信息
if (point->father != nullptr) point->father->swapSonTree(point, srcRight);
//将右子树的左子树设置为右子树
point->right = srcRight->left;
//将右子树的左子树设置为根节点
srcRight->left= point;
//更新父亲节点信息
srcRight->father = point->father;
point->father = srcRight;
//维护深度
point->deep = srcLeftDeep + 1;
srcRight->repairDeep();
if (point == root) root = srcRight;
}

/**
* 以指定节点为根进行右旋
* @throws invalid_argument 如果左子树不存在
*/

void rotateRight(Node* point) {
if (point->notLeft()) throw invalid_argument("左子树不存在");
auto srcLeft = point->left;
int srcRightDeep = point->getRightDeep();
//如果左子树平衡偏右,先左旋
if (point->left->getBalance() == RIGHT) rotateLeft(point->left);
if (point->father != nullptr) point->father->swapSonTree(point, srcLeft);
point->left = srcLeft->right;
srcLeft->right = point;
srcLeft->father = point->father;
point->father = srcLeft;
point->deep = srcRightDeep + 1;
srcLeft->repairDeep();
if (point == root) root = srcLeft;
}

};

template<class T>
constexpr bool operator<(const T& value, const typename WeakBalanceTree<T>::Node& node) {
return value < node.value;
}

创作不易,扫描下方打赏二维码支持一下吧ヾ(≧▽≦*)o