Skip to content

11221: 【原1221】bst

题目

题目描述

author: DS TA 原OJ链接:https://acm.sjtu.edu.cn/OnlineJudge-old/problem/1221

Description

实现二叉查找树,支持插入,删除,查找,删除小于 x 的所有元素,删除大于 x 的所有元素,删除大于 a 且小于 b 的所有元素,查找第 i 小的元素。

Input Format

第 1 行: 一个数, n, 表示总操作个数。

接下来 n 行: 每行首先一个单词, 表示操作的名称, 这一行接下来的格式每种操作不同:

"insert": 插入, 接下来一个整数, x, 表示被插入的元素

"delete": 删除, 接下来一个整数, x, 表示被删除的元素(若树中有重复删除任意一个)

"delete_less_than": 删除小于 x 的所有元素, 接下来一个整数, x

"delete_greater_than": 删除大于 x 的所有元素, 接下来一个整数, x

"delete_interval": 删除大于 a 且小于 b 的所有元素, 接下来两个整数, a, b

"find": 查找, 接下来一个整数, x, 表示被查找的元素

"find_ith": 查找第 i 小的元素, 接下来一个整数, i

Output Format

对于每个 "find" 和 "find_ith" 操作,输出一行。

其中对于 "find" 操作,输出 Y/N,表示是否查找到询问的元素。

对于 "find_ith" 操作,输出第 i 小的元素值,若不存在输出 N。

注意: 对于数列 1, 2, 2, 3, 我们认为第 1 小的元素是 1, 第 2 小的元素是 2, 第 3 小的元素还是 2, 第 4 小的元素是 3

Sample Input

22
insert 42
insert 42
insert 43
find 42
find 44
find_ith 2
find_ith 4
delete 42
delete_greater_than 42
find_ith 1
insert 1
insert 2
insert 3
insert 4
insert 5
delete_less_than 2
delete_interval 3 5
find 1
find 2
find 3
find 4
find 5

Sample Output

Y
N
42
N
42
N
Y
Y
N
Y

Limits

保证任意时刻树中元素不超过 5000 个,操作数小于300000

BugenZhao's solution

//
// Created by BugenZhao on 2019/4/25.
//

/// Key/value pair stored in every node of BBinarySearchTree.
/// Deliberately not default-constructible: both fields must be supplied.
template<typename Key, typename Val>
class BPair {
public:
    Key key;
    Val val;

    BPair(Key key, Val val) : key(key), val(val) {}

    BPair() = delete;
};

/// Unbalanced binary search tree with unique keys.
///
/// `put` on an existing key overwrites its value instead of adding a node.
/// `find_ith` treats each node's `val` as a multiplicity (how many copies of
/// the key are present), so Val must be an integer-like count for that query.
/// Not copyable: the tree owns its nodes through raw pointers.
template<typename Key, typename Val>
class BBinarySearchTree {
    class Node {
    public:
        BPair<Key, Val> data;
        Node *left;
        Node *right;

        explicit Node(const BPair<Key, Val> &data, Node *left = nullptr, Node *right = nullptr) :
                data(data), left(left), right(right) {}
    };

    Node *root;
    int size;               // number of nodes currently allocated
    int find_ith_curCount;  // rank still to be consumed during a find_ith walk
    Node *find_ith_ans;     // node located by the current find_ith walk, if any

private:
    // Returns the pair stored under `key` below `node`, or nullptr when absent.
    BPair<Key, Val> *get(Node *node, const Key &key) {
        while (node != nullptr) {
            if (node->data.key == key)
                return &node->data;
            node = (key < node->data.key) ? node->left : node->right;
        }
        return nullptr;
    }

    // Inserts (key, val), or overwrites the value when key is already present.
    void put(Node *&node, const Key &key, const Val &val) {
        if (node == nullptr) {
            node = new Node({key, val});
            ++size;
        } else if (key < node->data.key)
            put(node->left, key, val);
        else if (key > node->data.key)
            put(node->right, key, val);
        else
            node->data.val = val;
    }

    // Frees `node`, which must have at most one child, and splices that child into its place.
    void unlink(Node *&node) {
        Node *oldNode = node;
        node = (node->left != nullptr) ? node->left : node->right;
        delete oldNode;
        --size;
    }

    // Removes the smallest node of the subtree rooted at `node`.
    void removeMin(Node *&node) {
        if (node == nullptr) return;
        if (node->left == nullptr)
            unlink(node);
        else
            removeMin(node->left);
    }

    // Removes the node holding `key`, if any.
    void remove(Node *&node, const Key &key) {
        if (node == nullptr)
            return;
        if (key < node->data.key)
            remove(node->left, key);
        else if (key > node->data.key)
            remove(node->right, key);
        else if (node->left == nullptr || node->right == nullptr)
            unlink(node);
        else {
            // Two children: copy the in-order successor here, then drop the successor.
            Node *p = node->right;
            while (p->left != nullptr) p = p->left;
            node->data = p->data;
            removeMin(node->right);
        }
    }

    // Frees the whole subtree and nulls the pointer that referred to it.
    void clear(Node *&node) {
        if (node == nullptr) return;
        clear(node->left);
        clear(node->right);
        delete node;
        node = nullptr;
        --size;
    }

    // Removes every key < `key`, visiting only nodes that can be affected.
    void delete_less_than(Node *&node, const Key &key) {
        while (node != nullptr && node->data.key < key) {
            clear(node->left);  // every key there is smaller still
            unlink(node);       // left is now empty, so the right child moves up
        }
        if (node != nullptr)
            delete_less_than(node->left, key);  // right keys > node key >= bound: untouched
    }

    // Removes every key > `key`, visiting only nodes that can be affected.
    void delete_greater_than(Node *&node, const Key &key) {
        while (node != nullptr && node->data.key > key) {
            clear(node->right);  // every key there is larger still
            unlink(node);        // right is now empty, so the left child moves up
        }
        if (node != nullptr)
            delete_greater_than(node->right, key);  // left keys < node key <= bound: untouched
    }

    // Removes every key k with a < k < b.
    void delete_interval(Node *&node, const Key &a, const Key &b) {
        if (node == nullptr) return;
        if (!(node->data.key > a)) {
            // key <= a: the left subtree is entirely below a, only the right side can match.
            delete_interval(node->right, a, b);
        } else if (!(node->data.key < b)) {
            // key >= b: the right subtree is entirely above b, only the left side can match.
            delete_interval(node->left, a, b);
        } else {
            delete_interval(node->left, a, b);
            delete_interval(node->right, a, b);
            remove(node, node->data.key);
        }
    }

    // In-order walk that consumes each node's count until the requested rank is reached.
    void find_ith(Node *node, int i) {
        if (find_ith_ans) return;
        if (node == nullptr) return;
        find_ith(node->left, i);
        if (find_ith_ans) return;  // the answer was found in the left subtree
        if ((find_ith_curCount -= node->data.val) <= 0) {
            find_ith_ans = node;
            return;
        }
        find_ith(node->right, i);
    }

public:
    BBinarySearchTree() : root(nullptr), size(0), find_ith_curCount(0), find_ith_ans(nullptr) {}

    BBinarySearchTree(const BBinarySearchTree &) = delete;
    BBinarySearchTree &operator=(const BBinarySearchTree &) = delete;

    /// Returns the pair stored under `key`, or nullptr if the key is absent.
    BPair<Key, Val> *get(const Key &key) {
        return get(root, key);
    }

    /// Inserts `key` with `val`, overwriting the value of an existing key.
    void put(const Key &key, const Val &val) {
        put(root, key, val);
    }

    /// Removes the node holding `key` (no-op if absent).
    void remove(const Key &key) {
        remove(root, key);
    }

    /// Frees every node; the tree is empty and reusable afterwards.
    void clear() {
        clear(root);
    }

    virtual ~BBinarySearchTree() {
        clear();
    }

    /// Removes all keys strictly smaller than `key`.
    void delete_less_than(const Key &key) {
        delete_less_than(root, key);
    }

    /// Removes all keys strictly greater than `key`.
    void delete_greater_than(const Key &key) {
        delete_greater_than(root, key);
    }

    /// Removes all keys strictly between `a` and `b`.
    void delete_interval(const Key &a, const Key &b) {
        delete_interval(root, a, b);
    }

    /// Returns the pair holding the i-th smallest element (1-based, counting
    /// each node `val` times), or nullptr if i is out of range.
    BPair<Key, Val> *find_ith(int i) {
        if (i <= 0) return nullptr;
        find_ith_curCount = i;
        find_ith_ans = nullptr;
        find_ith(root, i);
        if (find_ith_ans) return &find_ith_ans->data;
        else return nullptr;
    }
};

#include <iostream>
#include <string>

using std::ios, std::cin, std::cout, std::endl, std::string;
using ll = long long;

int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    // Each distinct value is one tree node; its `val` is the number of copies present.
    BBinarySearchTree<int, int> bst;

    int n;
    string cmd;
    int a, b;

    if (!(cin >> n)) return 0;
    while (n-- > 0 && cin >> cmd) {
        if (cmd == "insert") {
            cin >> a;
            auto p = bst.get(a);
            if (p) ++(p->val);
            else bst.put(a, 1);
        } else if (cmd == "delete") {
            cin >> a;
            auto p = bst.get(a);
            // A count of 0 already means "absent". Letting it go negative would make
            // a later insert leave the value at 0 copies, so it would never be found.
            if (p && p->val > 0) --(p->val);
        } else if (cmd == "find") {
            cin >> a;
            auto p = bst.get(a);
            if (p && p->val) cout << "Y\n";
            else cout << "N\n";
        } else if (cmd == "delete_less_than") {
            cin >> a;
            bst.delete_less_than(a);
        } else if (cmd == "delete_greater_than") {
            cin >> a;
            bst.delete_greater_than(a);
        } else if (cmd == "delete_interval") {
            cin >> a >> b;
            bst.delete_interval(a, b);
        } else if (cmd == "find_ith") {
            cin >> a;
            auto p = bst.find_ith(a);
            if (p) cout << p->key << '\n';
            else cout << "N\n";
        }
    }
}

ligongzzz's solution

#include "iostream"
#include "cstring"
using namespace std;

// Unbalanced BST over ints that keeps duplicates as separate nodes (equal values
// always go to the right). Nodes carry parent links so the iterator can walk
// in order without a stack. Not copyable: the tree owns its nodes.
class bst {
public:
    class node {
    public:
        int val = 0;
        node* lchild = nullptr, * rchild = nullptr;
        node* parent = nullptr;
    };

    // Ascending in-order iterator; cur_pos == nullptr represents end().
    class iterator {
    public:
        node* cur_pos = nullptr;

        // ++it: move to the in-order successor.
        iterator& operator++() {
            if (cur_pos == nullptr)
                return *this;
            if (cur_pos->rchild != nullptr) {
                // Successor is the leftmost node of the right subtree.
                cur_pos = cur_pos->rchild;
                while (cur_pos->lchild != nullptr)
                    cur_pos = cur_pos->lchild;
                return *this;
            }
            // Climb until we arrive from a left child; leaving the root yields end().
            node* last = cur_pos;
            cur_pos = cur_pos->parent;
            while (cur_pos != nullptr && cur_pos->rchild == last) {
                last = cur_pos;
                cur_pos = cur_pos->parent;
            }
            return *this;
        }

        // it++: advance, returning the previous position.
        iterator operator++(int) {
            iterator temp = *this;
            ++*this;
            return temp;
        }

        bool operator==(const iterator& other) const {
            return cur_pos == other.cur_pos;
        }

        bool operator!=(const iterator& other) const {
            return cur_pos != other.cur_pos;
        }

        int& operator*() {
            return cur_pos->val;
        }
    };

    node* root = nullptr;

    bst() = default;
    bst(const bst&) = delete;
    bst& operator=(const bst&) = delete;
    ~bst() { destroy(root); }

    // Iterator to the smallest element, or end() when the tree is empty.
    iterator begin() {
        iterator temp;
        temp.cur_pos = root;
        if (temp.cur_pos != nullptr)
            while (temp.cur_pos->lchild != nullptr)
                temp.cur_pos = temp.cur_pos->lchild;
        return temp;
    }

    iterator end() {
        iterator temp;
        temp.cur_pos = nullptr;
        return temp;
    }

    // Iterator to the first node equal to val met on the search path, or end().
    iterator find(const int& val) {
        iterator ans;
        for (auto p = root; p;) {
            if (p->val == val) {
                ans.cur_pos = p;
                return ans;
            }
            p = (val < p->val) ? p->lchild : p->rchild;
        }
        ans.cur_pos = nullptr;
        return ans;
    }

    // Inserts val; a duplicate becomes the new right child of the equal node.
    void insert(const int& val) {
        if (!root) {
            root = new node;
            root->val = val;
            return;
        }
        // Walk down to the insertion point.
        auto p = root;
        for (; p;) {
            if (p->val == val) {
                // Splice a new node between p and its current right child.
                if (p->rchild) {
                    p->rchild->parent = new node;
                    p->rchild->parent->val = val;
                    p->rchild->parent->rchild = p->rchild;
                    p->rchild = p->rchild->parent;
                    p->rchild->parent = p;
                }
                else {
                    p->rchild = new node;
                    p->rchild->val = val;
                    p->rchild->parent = p;
                }
                return;
            }
            if (val < p->val) {
                if (p->lchild)
                    p = p->lchild;
                else {
                    p->lchild = new node;
                    p->lchild->parent = p;
                    p = p->lchild;
                    break;
                }
            }
            else {
                if (p->rchild)
                    p = p->rchild;
                else {
                    p->rchild = new node;
                    p->rchild->parent = p;
                    p = p->rchild;
                    break;
                }
            }
        }
        // Fill in the freshly attached leaf.
        p->val = val;
    }

    // Erases the element at pos and returns an iterator to the next element.
    // Erasing end() is a no-op.
    iterator erase(const iterator& pos) {
        iterator ans = pos;
        node* p = pos.cur_pos;
        if (p == nullptr)
            return ans;
        if (p->rchild == nullptr) {
            // Leaf or left child only: the successor lies above p, so step before unlinking.
            ++ans;
            replace_in_parent(p, p->lchild);
            delete p;
        }
        else {
            // Right child present: pull the successor's value into p and delete the successor.
            // p now holds the next value, so ans (still p) is already the right result.
            node* q = p->rchild;
            while (q->lchild)
                q = q->lchild;
            p->val = q->val;
            replace_in_parent(q, q->rchild);
            delete q;
        }
        return ans;
    }

    // Erases one node equal to val, if any.
    void erase(const int& val) {
        iterator it = find(val);
        if (it != end())
            erase(it);
    }

private:
    // Makes `child` take p's place under p's parent (or as root) and fixes its parent link.
    void replace_in_parent(node* p, node* child) {
        if (p == root)
            root = child;
        else if (p->parent->lchild == p)
            p->parent->lchild = child;
        else
            p->parent->rchild = child;
        if (child)
            child->parent = p->parent;
    }

    // Frees a whole subtree.
    static void destroy(node* p) {
        if (!p)
            return;
        destroy(p->lchild);
        destroy(p->rchild);
        delete p;
    }
};

int main() {
    ios::sync_with_stdio(false);
    cin.tie(0);

    bst setData;

    int num;
    cin >> num;
    for (; num > 0; num--) {
        char op[100];
        cin.width(sizeof(op));  // never write past the end of op
        cin >> op;
        if (strcmp(op, "insert") == 0) {
            int temp;
            cin >> temp;
            setData.insert(temp);
        }
        else if (strcmp(op, "delete") == 0) {
            int temp;
            cin >> temp;
            if (auto iter = setData.find(temp); iter != setData.end())
                setData.erase(iter);
        }
        else if (strcmp(op, "delete_less_than") == 0) {
            int temp;
            cin >> temp;
            // Elements come out in ascending order: stop at the first one >= temp.
            for (auto p = setData.begin(); p != setData.end(); ) {
                if (*p < temp) {
                    p = setData.erase(p);
                }
                else break;
            }
        }
        else if (strcmp(op, "delete_greater_than") == 0) {
            int temp;
            cin >> temp;
            for (auto p = setData.begin(); p != setData.end(); ) {
                if (*p > temp) {
                    p = setData.erase(p);
                }
                else ++p;
            }
        }
        else if (strcmp(op, "delete_interval") == 0) {
            int l, r;
            cin >> l >> r;
            for (auto p = setData.begin(); p != setData.end();) {
                if (*p > l && (*p) < r)
                    p = setData.erase(p);
                else if (*p >= r)
                    break;
                else ++p;
            }
        }
        // Answers end with '\n' rather than endl: flushing after each of up to
        // 300000 queries is needlessly slow.
        else if (strcmp(op, "find") == 0) {
            int temp;
            cin >> temp;
            if (setData.find(temp) != setData.end())
                cout << "Y" << '\n';
            else cout << "N" << '\n';
        }
        else if (strcmp(op, "find_ith") == 0) {
            int temp;
            bool flag = false;
            cin >> temp;
            int i = 1;
            for (auto p = setData.begin(); p != setData.end(); ++p, ++i) {
                if (i == temp) {
                    flag = true;
                    cout << *p << '\n';
                    break;
                }
            }
            if (!flag)
                cout << "N" << '\n';
        }
    }


    return 0;
}

Neight99's solution

#include <cstring>
#include <iostream>

using namespace std;

// Scratch state shared by BinarySearchTree::find_ith and main:
// findN counts nodes visited in order; flag is set once the requested rank was printed.
int findN;
bool flag;

// Unbalanced BST over ints that stores duplicates as separate nodes (equal keys
// go left). `sum` is the number of stored elements. Not copyable: nodes are
// owned through raw pointers.
class BinarySearchTree {
    struct Node {
        int data;
        Node *left;
        Node *right;

        Node(int x = 0, Node *l = nullptr, Node *r = nullptr)
            : data(x), left(l), right(r) {}

        Node &operator=(const Node &other) {
            if (&other != this) {
                data = other.data;
                left = other.left;
                right = other.right;
            }

            return *this;
        }
    };

    Node *root;
    int sum;
    void clear(Node *&rhs);
    void insert(int x, Node *&rhs);
    bool find(int x, Node *&rhs);
    void find_ith(int x, Node *&rhs);
    void deleteNode(int x, Node *&rhs);
    void deleteLess(int x, Node *&rhs);
    void deleteGreater(int x, Node *&rhs);
    void deleteInterval(int low, int high, Node *&rhs);

   public:
    BinarySearchTree() : root(nullptr), sum(0) {}
    BinarySearchTree(const BinarySearchTree &) = delete;
    BinarySearchTree &operator=(const BinarySearchTree &) = delete;
    ~BinarySearchTree() { clear(root); }
    void insert(int x) { insert(x, root); }
    bool find(int x) { return find(x, root); }
    // Prints the pos-th smallest element (1-based) and sets the global flag;
    // leaves flag cleared when no such rank exists (including pos <= 0).
    void find_ith(int pos) {
        flag = 0;
        findN = 0;
        if (pos <= 0 || pos > sum) {
            return;
        }
        find_ith(pos, root);
    }
    void deleteEqual(int x) {
        deleteNode(x, root);
        if (sum == 0) {
            root = nullptr;
        }
    }
    void deleteLess(int x) {
        deleteLess(x, root);
        if (sum == 0) {
            root = nullptr;
        }
    }
    void deleteGreater(int x) {
        deleteGreater(x, root);
        if (sum == 0) {
            root = nullptr;
        }
    }
    void deleteInterval(int low, int high) {
        deleteInterval(low, high, root);
        if (sum == 0) {
            root = nullptr;
        }
    }
};

// Frees the subtree, updates sum, and nulls the pointer that referred to it.
void BinarySearchTree::clear(Node *&rhs) {
    if (rhs == nullptr) {
        return;
    }
    clear(rhs->left);
    clear(rhs->right);
    delete rhs;
    rhs = nullptr;
    sum--;
}

// Inserts x; values equal to a node go into its left subtree.
void BinarySearchTree::insert(int x, Node *&rhs) {
    if (rhs == nullptr) {
        sum++;
        rhs = new Node(x);
    } else if (x <= rhs->data) {
        insert(x, rhs->left);
    } else {
        insert(x, rhs->right);
    }
}

bool BinarySearchTree::find(int x, Node *&rhs) {
    if (rhs == nullptr) {
        return false;
    } else if (rhs->data == x) {
        return true;
    } else if (x < rhs->data) {
        return find(x, rhs->left);
    } else {
        return find(x, rhs->right);
    }
}

// In-order walk that prints the x-th node; stops as soon as flag is set.
void BinarySearchTree::find_ith(int x, Node *&rhs) {
    if (rhs == nullptr || flag) {
        return;
    }
    find_ith(x, rhs->left);
    if (flag) {
        return;
    }
    if (x == ++findN) {
        std::cout << rhs->data << '\n';
        flag = 1;
        return;
    }
    find_ith(x, rhs->right);
}

// Deletes one node equal to x (two children: replace with the in-order successor).
void BinarySearchTree::deleteNode(int x, Node *&rhs) {
    if (rhs == nullptr) {
        return;
    }
    if (x < rhs->data) {
        deleteNode(x, rhs->left);
    } else if (x > rhs->data) {
        deleteNode(x, rhs->right);
    } else if (rhs->left != nullptr && rhs->right != nullptr) {
        Node *p = rhs->right;
        while (p->left != nullptr) {
            p = p->left;
        }
        rhs->data = p->data;
        deleteNode(rhs->data, rhs->right);
    } else {
        Node *clean = rhs;
        rhs = (rhs->left != nullptr) ? rhs->left : rhs->right;
        delete clean;
        sum--;
    }
}

// Deletes every value < x.
void BinarySearchTree::deleteLess(int x, Node *&rhs) {
    // While the subtree root is too small, it and its left subtree all go.
    while (rhs != nullptr && x > rhs->data) {
        clear(rhs->left);
        Node *temp = rhs->right;
        delete rhs;
        rhs = temp;
        sum--;
    }
    // Otherwise only the left subtree can still hold values < x.
    if (rhs != nullptr) {
        deleteLess(x, rhs->left);
    }
}

// Deletes every value > x.
void BinarySearchTree::deleteGreater(int x, Node *&rhs) {
    // While the subtree root is too large, it and its right subtree all go.
    while (rhs != nullptr && x < rhs->data) {
        clear(rhs->right);
        Node *temp = rhs->left;
        delete rhs;
        rhs = temp;
        sum--;
    }
    // Otherwise only the right subtree can still hold values > x.
    if (rhs != nullptr) {
        deleteGreater(x, rhs->right);
    }
}

// Deletes every value v with low < v < high.
void BinarySearchTree::deleteInterval(int low, int high, Node *&rhs) {
    if (low >= high || rhs == nullptr) {
        return;
    }
    while (rhs != nullptr && rhs->data < high && rhs->data > low) {
        deleteNode(rhs->data, rhs);
    }

    if (rhs != nullptr && rhs->data >= high) {
        deleteInterval(low, high, rhs->left);
    }

    if (rhs != nullptr && rhs->data <= low) {
        deleteInterval(low, high, rhs->right);
    }
}

int main() {
    ios::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);

    BinarySearchTree bst;
    char order[100] = {0};  // current command word
    int n1, n2;             // command operands
    int times = 0;

    cin >> times;

    while (times-- > 0) {
        cin >> order;

        if (strcmp(order, "insert") == 0) {
            cin >> n1;
            bst.insert(n1);
            continue;
        }
        if (strcmp(order, "delete") == 0) {
            cin >> n1;
            bst.deleteEqual(n1);
            continue;
        }
        if (strcmp(order, "delete_less_than") == 0) {
            cin >> n1;
            bst.deleteLess(n1);
            continue;
        }
        if (strcmp(order, "delete_greater_than") == 0) {
            cin >> n1;
            bst.deleteGreater(n1);
            continue;
        }
        if (strcmp(order, "delete_interval") == 0) {
            cin >> n1 >> n2;
            bst.deleteInterval(n1, n2);
            continue;
        }
        if (strcmp(order, "find") == 0) {
            cin >> n1;
            const bool found = bst.find(n1);
            cout << (found ? "Y" : "N") << '\n';
            continue;
        }
        if (strcmp(order, "find_ith") == 0) {
            cin >> n1;
            // find_ith prints the answer itself and reports success through the global flag.
            bst.find_ith(n1);
            if (!flag) {
                cout << "N\n";
            }
        }
    }

    return 0;
}

skyzh's solution

#include <iostream>
#include <climits>
#include <cstring>

using namespace std;

// Unbalanced BST that keeps duplicates as separate nodes. Invariant maintained by
// every operation: left subtree <= node value < right subtree.
// Not copyable: nodes are owned through raw pointers.
template<typename T>
struct BST {
    struct Node {
        T x;
        Node *l, *r;

        Node(Node *l = nullptr, Node *r = nullptr) : l(l), r(r) {}

        Node(const T &x, Node *l = nullptr, Node *r = nullptr) : x(x), l(l), r(r) {}

        // Dumps this subtree indented by depth. Printing is switched off;
        // set kEnabled to true to use it while debugging.
        void debug(int depth = 0) {
            constexpr bool kEnabled = false;
            if (!kEnabled) return;
            for (int i = 0; i < depth; i++) std::cout << " ";
            std::cout << x << std::endl;
            for (int i = 0; i < depth; i++) std::cout << " ";
            std::cout << "L" << std::endl;
            if (l) l->debug(depth + 1);
            for (int i = 0; i < depth; i++) std::cout << " ";
            std::cout << "R" << std::endl;
            if (r) r->debug(depth + 1);
        }
    } *root;

    BST() : root(nullptr) {}
    BST(const BST &) = delete;
    BST &operator=(const BST &) = delete;
    ~BST() { clear(root); }

    // Frees a whole subtree.
    void clear(Node *ptr) {
        if (!ptr) return;
        clear(ptr->l);
        clear(ptr->r);
        delete ptr;
    }

    // Ordinary BST search along one path: O(height) instead of visiting every node.
    bool find(Node *ptr, const T &x) {
        while (ptr) {
            if (ptr->x == x) return true;
            ptr = (x < ptr->x) ? ptr->l : ptr->r;
        }
        return false;
    }

    bool find(const T &x) {
        return find(root, x);
    }

    // Inserts x; equal values go to the left.
    Node *insert(Node *ptr, const T &x) {
        if (!ptr) return new Node(x);
        if (x <= ptr->x) ptr->l = insert(ptr->l, x);
        if (x > ptr->x) ptr->r = insert(ptr->r, x);
        return ptr;
    }

    void insert(const T &x) {
        root = insert(root, x);
    }

    // In-order search for the i-th node (1-based); i is consumed as nodes are passed.
    Node *find_ith(Node *ptr, int &i) {
        if (!ptr) return nullptr;
        Node *l = find_ith(ptr->l, i);
        if (l) return l;
        if (i == 1) return ptr;
        --i;
        Node *r = find_ith(ptr->r, i);
        if (r) return r;
        return nullptr;
    }

    // Returns the i-th smallest node, or nullptr if it does not exist.
    Node *find_ith(int i) {
        return find_ith(root, i);
    }

    // Deletes every value <= x.
    void delete_less_than(const T &x) {
        delete_interval(INT_MIN, x);
    }

    // Deletes every value >= x.
    void delete_greater_than(const T &x) {
        delete_interval(x, INT_MAX);
    }

    // Returns the subtree that replaces ptr once ptr is detached (ptr itself is not freed).
    // Uses the in-order predecessor (maximum of the left subtree) as the replacement.
    Node *delete_node_at(Node *ptr) {
        if (!ptr->l) {
            return ptr->r;
        } else {
            Node *prev = nullptr, *c = ptr->l;
            while (c->r) {
                prev = c;
                c = c->r;
            }
            if (!prev) {
                c->r = ptr->r;
                return c;
            }
            prev->r = delete_node_at(c);
            c->l = ptr->l;
            c->r = ptr->r;
            return c;
        }
    }

    // Deletes one node equal to x, returning the new subtree root.
    Node *delete_node(Node *ptr, const T &x) {
        if (!ptr) return nullptr;
        if (x < ptr->x) ptr->l = delete_node(ptr->l, x);
        if (x == ptr->x) {
            Node *result = delete_node_at(ptr);
            delete ptr;
            return result;
        }
        if (x > ptr->x) ptr->r = delete_node(ptr->r, x);
        return ptr;
    }

    void delete_node(const T &x) {
        root = delete_node(root, x);
    }

    // Deletes every value in the closed range [t1, t2] from a subtree whose values are
    // known to lie in [x1, x2]; a subtree entirely inside the range is freed in one go.
    Node *delete_interval(Node *ptr, const T &x1, const T &x2, const T &t1, const T &t2) {
        if (!ptr) return nullptr;
        if (t1 <= x1 && x2 <= t2) {
            clear(ptr);
            return nullptr;
        }
        ptr->l = delete_interval(ptr->l, x1, ptr->x, t1, t2);
        // A right child holds values > ptr->x, so ptr->x + 1 cannot overflow here.
        if (ptr->r) ptr->r = delete_interval(ptr->r, ptr->x + 1, x2, t1, t2);
        if (t1 <= ptr->x && ptr->x <= t2) {
            Node *tmp = delete_node_at(ptr);
            delete ptr;
            return tmp;
        }
        return ptr;
    }

    // Deletes every value in [t1, t2] and returns the new root.
    Node *delete_interval(const T &t1, const T &t2) {
        root = delete_interval(root, INT_MIN, INT_MAX, t1, t2);
        return root;
    }
};

int main() {
    /*
    "insert": 插入, 接下来一个整数, x, 表示被插入的元素

    "delete": 删除, 接下来一个整数, x, 表示被删除的元素(若树中有重复删除任意一个)

    "delete_less_than": 删除小于 x 的所有元素, 接下来一个整数, x

    "delete_greater_than": 删除大于 x 的所有元素, 接下来一个整数, x

    "delete_interval": 删除大于 a 且小于 b 的所有元素, 接下来两个整数, a, b

    "find": 查找, 接下来一个整数, x, 表示被查找的元素

    "find_ith": 查找第 i 小的元素, 接下来一个整数, i
    */

    char cmd[100];
    int N;
    int op1, op2;

    BST<int> tree;

    cin >> N;
    for (int i = 0; i < N; i++) {
        cin >> cmd;
        if (strcmp(cmd, "insert") == 0) {
            cin >> op1;
            tree.insert(op1);
            tree.root->debug();
        } else if (strcmp(cmd, "delete") == 0) {
            cin >> op1;
            tree.delete_node(op1);
            tree.root->debug();
        } else if (strcmp(cmd, "delete_less_than") == 0) {
            cin >> op1;
            tree.delete_less_than(op1 - 1);
            tree.root->debug();
        } else if (strcmp(cmd, "delete_greater_than") == 0) {
            cin >> op1;
            tree.delete_greater_than(op1 + 1);
            tree.root->debug();
        } else if (strcmp(cmd, "delete_interval") == 0) {
            cin >> op1 >> op2;
            tree.delete_interval(op1 + 1, op2 - 1);
            tree.root->debug();
        } else if (strcmp(cmd, "find") == 0) {
            cin >> op1;
            if (tree.find(op1)) cout << "Y" << endl; else cout << "N" << endl;
        } else if (strcmp(cmd, "find_ith") == 0) {
            cin >> op1;
            BST<int>::Node *ith = tree.find_ith(op1);
            if (ith) cout << ith->x << endl; else cout << "N" << endl;
        }
    }
    return 0;
}

yyong119's solution

#include <cstdio>
#include <cstring>

// BST of distinct values with a per-node copy count (number_). The real tree
// hangs off head_->l_son_; head_ is a sentinel whose value is 0x7fffffff.
// Not copyable: nodes are owned through raw pointers.
class SearchTree {
public:
    struct Node {
        // The default value (0x7fffffff) is what the sentinel head_ is built with.
        Node(int data = 0x7fffffff) : data_(data), number_(1), l_son_(NULL), r_son_(NULL) {}
        // Frees the entire left subtree.
        void DeleteLSon() {
            if (l_son_) {
                l_son_->DeleteSon();
                delete l_son_;
                l_son_ = NULL;
            }
        }
        // Frees the entire right subtree.
        void DeleteRSon() {
            if (r_son_) {
                r_son_->DeleteSon();
                delete r_son_;
                r_son_ = NULL;
            }
        }
        // Frees both subtrees (not this node).
        void DeleteSon() {
            DeleteLSon();
            DeleteRSon();
        }
        int data_, number_;
        Node *l_son_, *r_son_;
    };

    SearchTree() : head_(new Node()) {}
    SearchTree(const SearchTree &) = delete;
    SearchTree &operator=(const SearchTree &) = delete;
    ~SearchTree() {
        head_->DeleteSon();
        delete head_;
    }

    // Inserts one copy of data (bumps the count if the value already exists).
    void Insert(int data) {
        Node *p = head_, *q = head_->l_son_;
        while (q) {
            if (q->data_ == data) {
                ++q->number_;
                return;
            }
            p = q;
            if (q->data_ < data) q = q->r_son_;
            else q = q->l_son_;
        }
        if (p->data_ >= data) p->l_son_ = new Node(data);
        else p->r_son_ = new Node(data);
    }

    // Removes one copy of data, if present.
    void Remove(int data) {
        Node *p = head_, *q = head_->l_son_;
        while (q && q->data_ != data) {
            p = q;
            if (q->data_ < data) q = q->r_son_;
            else q = q->l_son_;
        }
        Remove(p, q);
    }

    // Removes one copy of q (child of p), or the whole node when total is set.
    void Remove(Node *p, Node *q, bool total = 0) {
        if (q == NULL) return;
        if (!total && q->number_ > 1) {
            --q->number_;
            return;
        }
        if (q->l_son_ && q->r_son_) {
            // Two children: move the in-order predecessor (max of the left subtree) into q.
            Node *p2 = q, *q2 = q->l_son_;
            for (; q2->r_son_; p2 = q2, q2 = q2->r_son_);
            q->data_ = q2->data_;
            q->number_ = q2->number_;
            if (p2->l_son_ == q2) p2->l_son_ = q2->l_son_;
            else p2->r_son_ = q2->l_son_;
            delete q2;
        }
        else {
            Node *new_node = NULL;
            if (q->l_son_)
                new_node = q->l_son_;
            else if (q->r_son_)
                new_node = q->r_son_;

            if (p->l_son_ == q)
                p->l_son_ = new_node;
            else
                p->r_son_ = new_node;
            delete q;
        }
    }

    // Removes every value < data.
    void RemoveLessThan(int data) {
        RemoveLessThan(head_, head_->l_son_, data);
    }

    // Removes every value < data from the subtree q hanging under p.
    void RemoveLessThan(Node *p, Node *q, int data) {
        Node *tmp;
        while (q) {
            if (q->data_ <= data) {
                q->DeleteLSon();  // all of it is < q->data_ <= data
                if (q->data_ == data) return;
                tmp = q;
                q = q->r_son_;
                if (p->l_son_ == tmp) p->l_son_ = q;
                else p->r_son_ = q;
                delete tmp;
            }
            else {
                p = q;
                q = q->l_son_;
            }
        }
    }

    // Removes every value > data.
    void RemoveGreaterThan(int data) {
        RemoveGreaterThan(head_, head_->l_son_, data);
    }

    // Removes every value > data from the subtree q hanging under p.
    void RemoveGreaterThan(Node *p, Node *q, int data) {
        Node *tmp;
        for (; q;) {
            if (q->data_ >= data) {
                q->DeleteRSon();  // all of it is > q->data_ >= data
                if (q->data_ == data)
                    return;
                tmp = q;
                q = q->l_son_;
                if (p->l_son_ == tmp) p->l_son_ = q;
                else p->r_son_ = q;
                delete tmp;
            }
            else {
                p = q;
                q = q->r_son_;
            }
        }
    }

    // Removes every value v with lower < v < upper.
    void RemoveInvterval(int lower, int upper) {
        Node *p = head_, *q = head_->l_son_;
        while (q) {
            if (q->data_ < lower) {
                p = q;
                q = q->r_son_;
            }
            else if (q->data_ > upper) {
                p = q;
                q = q->l_son_;
            }
            else {
                // q is the topmost node in [lower, upper]: trim both sides, then q itself.
                RemoveGreaterThan(q, q->l_son_, lower);
                RemoveLessThan(q, q->r_son_, upper);
                if (q->data_ == lower || q->data_ == upper) return;
                Remove(p, q, 1);
                return;
            }
        }
    }

    bool Find(int data) {
        for (Node *p = head_->l_son_; p;) {
            if (p->data_ == data) return true;
            else if (p->data_ < data)
                p = p->r_son_;
            else
                p = p->l_son_;
        }
        return false;
    }

    // Node holding the i-th smallest element (copies counted), or NULL if there is none.
    Node *FindIth(int i) {
        if (i <= 0) return NULL;  // ranks start at 1
        return FindIth(head_->l_son_, i);
    }

    Node *FindIth(Node *node, int &i) {
        if (!node) return NULL;
        Node *tmp = FindIth(node->l_son_, i);
        if (tmp) return tmp;
        if (node->number_ >= i) return node;
        i -= node->number_;
        return FindIth(node->r_son_, i);
    }

    // Debug dump: in-order "address,value,count" triples on one line.
    void Output() {
        if (head_->l_son_) Output(head_->l_son_);
        printf("\n");
    }

    void Output(Node *p) {
        if (p->l_son_) Output(p->l_son_);
        printf("%p,%d,%d ", static_cast<void *>(p), p->data_, p->number_);
        if (p->r_son_) Output(p->r_son_);
    }
    Node *head_;
};

// Driver state for this solution: the tree itself, the number of commands,
// the one or two integer operands of the current command, the command-word
// buffer, and the node returned by the last find_ith query.
SearchTree tree;
int total_command_number, tmp1, tmp2;
char command[30];
SearchTree::Node *tmp;

int main() {
    // Input: the number of commands, then one command word per line followed
    // by its integer operand(s); only delete_interval has a second one.
    scanf("%d", &total_command_number);
    for (int i = 0; i < total_command_number; ++i) {
        // %29s keeps the word inside command[30]. Every command carries at
        // least one integer, so anything other than two conversions means the
        // input ended early; stop instead of replaying stale values.
        if (scanf("%29s%d", command, &tmp1) != 2) break;
        if (!strcmp(command, "insert"))
            tree.Insert(tmp1);
        else if (!strcmp(command, "delete"))
            tree.Remove(tmp1);
        else if (!strcmp(command, "delete_less_than"))
            tree.RemoveLessThan(tmp1);
        else if (!strcmp(command, "delete_greater_than"))
            tree.RemoveGreaterThan(tmp1);
        else if (!strcmp(command, "delete_interval")) {
            scanf("%d", &tmp2);
            tree.RemoveInvterval(tmp1, tmp2);
        }
        else if (!strcmp(command, "find"))
            printf("%c\n", tree.Find(tmp1) ? 'Y' : 'N');
        else if (!strcmp(command, "find_ith")) {
            tmp = tree.FindIth(tmp1);
            if (tmp) printf("%d\n", tmp->data_);
            else printf("N\n");
        }
    }
    return 0;
}

zqy2018's solution

/*
    See the editorial at https://github.com/zqy1018/tutorials/blob/master/bst_tutorial/bst.pdf
*/
#include <bits/stdc++.h>
#define INF 2000000000
using namespace std;
typedef long long ll;
// Reads the next integer from stdin. Non-digit characters before it are
// skipped, and every '-' skipped along the way flips the sign. Returns 0 if
// the input ends before any digit is seen.
int read(){
    int f = 1;
    // Wide accumulator so the magnitude of -2147483648 does not overflow int.
    long long x = 0;
    // int, not char: a plain char cannot reliably hold EOF, which made the
    // skip loop spin forever at end of input.
    int c = getchar();
    while (c != EOF && (c < '0' || c > '9')) {
        if (c == '-') f = -f;
        c = getchar();
    }
    while (c >= '0' && c <= '9') {
        x = x * 10 + (c - '0');
        c = getchar();
    }
    return static_cast<int>(f * x);
}
// One treap node, addressed by its index into the tr[] pool; index 0 plays
// the role of the empty child.
struct Tr {
    int siz;   // number of nodes in this subtree
    int v;     // stored key
    int prio;  // random heap priority
    int lch;   // index of the left child (0 = none)
    int rch;   // index of the right child (0 = none)
};
// Node pool for the treap. tr[0] is never handed out (tree_new pre-increments
// S), so it stays all-zero and serves as the null child of size 0. S is the
// last index allocated and root is the current root index (0 = empty tree).
// Slots are never recycled, so the pool needs one slot per insert
// (presumably fewer than 300000 here, per the problem limits).
Tr tr[400005];
int S = 0, root = 0;
// Recomputes the subtree size of node x from its two children; the empty
// child (index 0) contributes 0.
void maintain(int x){
    int left = tr[x].lch;
    int right = tr[x].rch;
    tr[x].siz = tr[left].siz + tr[right].siz + 1;
}
int tree_new(int k){
    ++S;
    tr[S].siz = 1, tr[S].v = k, 
    tr[S].prio = rand(), 
    tr[S].lch = tr[S].rch = 0;
    return S;
}
// Two treap root indices, as produced by the split routines (x = left part,
// y = right part).
struct pair_of_int{
    int x;
    int y;
    pair_of_int(int first, int second) : x(first), y(second) {}
};
// Splits the treap rooted at `now` by key and returns (x, y): x holds every
// node with v <= k and y every node with v > k. Subtree sizes are refreshed
// via maintain() on the way back up.
pair_of_int Split(int now, int k){
    if (!now) return pair_of_int(0, 0);
    else {
        int x, y;
        if (tr[now].v <= k){
            // now and its left subtree go to the "<= k" side; only the right
            // subtree still has to be split.
            x = now;
            pair_of_int res = Split(tr[now].rch, k); 
            tr[now].rch = res.x;
            y = res.y;
        }else {
            // now and its right subtree go to the "> k" side; split the left.
            y = now;
            pair_of_int res = Split(tr[now].lch, k);
            x = res.x;
            tr[now].lch = res.y;
        }
        maintain(now);
        return pair_of_int(x, y);
    }
}
// Splits the treap rooted at `now` by rank: x receives the k smallest nodes
// in in-order, y the rest. If k is at least the subtree size, everything
// goes to x. Sizes are refreshed via maintain().
pair_of_int Split_K(int now, int k){
    if (!now) return pair_of_int(0, 0);
    else {
        int x, y;
        if (k > tr[tr[now].lch].siz){
            // The whole left subtree plus `now` fit into the first k; take the
            // remaining k - lsize - 1 nodes from the right subtree.
            x = now;
            pair_of_int res = Split_K(tr[now].rch, k - tr[tr[now].lch].siz - 1);
            tr[now].rch = res.x;
            y = res.y;
        }else {
            // All k nodes come from the left subtree; `now` goes right.
            y = now;
            pair_of_int res = Split_K(tr[now].lch, k);
            x = res.x;
            tr[now].lch = res.y;
        }
        maintain(now);
        return pair_of_int(x, y);
    }
}
// Joins treaps x and y, where every key of x precedes every key of y in
// order; the root with the smaller priority stays on top (ties go to y).
// Returns the index of the merged root.
int Merge(int x, int y){
    if (x == 0) return y;
    if (y == 0) return x;
    if (tr[x].prio >= tr[y].prio){
        // y stays on top: x is merged into y's left subtree.
        tr[y].lch = Merge(x, tr[y].lch);
        maintain(y);
        return y;
    }
    // x stays on top: y is merged into x's right subtree.
    tr[x].rch = Merge(tr[x].rch, y);
    maintain(x);
    return x;
}
void Insert(int k){
    int z = tree_new(k);
    pair_of_int res = Split(root, k);
    root = Merge(Merge(res.x, z), res.y);
}
void Del(int k){
    pair_of_int res1 = Split(root, k - 1);
    pair_of_int res2 = Split_K(res1.y, 1);
    root = Merge(res1.x, res2.y);
}
// Returns whether key k is stored, by a plain BST descent from root.
bool Lookup(int k){
    for (int t = root; t != 0; ) {
        const int key = tr[t].v;
        if (key == k) return true;
        t = (key < k) ? tr[t].rch : tr[t].lch;
    }
    return false;
}
// Removes every key strictly less than k (the removed pool slots are simply
// abandoned).
void Del_less(int k){
    // Nothing is below INT_MIN, and k - 1 would overflow (UB) for it.
    if (k == INT_MIN) return;
    pair_of_int res = Split(root, k - 1);   // res.x: keys < k, res.y: keys >= k
    root = res.y;
}
void Del_greater(int k){
    pair_of_int res = Split(root, k);
    root = res.x;
}
// Removes every key e with l < e < r (both bounds exclusive).
void Del_interval(int l, int r){
    // An empty open interval removes nothing. Returning early also keeps the
    // r - 1 below from overflowing when r == INT_MIN.
    if (r <= l) return;
    pair_of_int res1 = Split(root, l);        // res1.x: <= l, res1.y: > l
    pair_of_int res2 = Split(res1.y, r - 1);  // res2.x: in (l, r), res2.y: >= r
    root = Merge(res1.x, res2.y);
}
// Writes the k-th smallest key (1-based, duplicates counted) into res and
// returns true. Returns false, leaving res untouched, when k is out of range.
bool Kth(int k, int &res){
    if (k < 1 || k > tr[root].siz) return false;
    int t = root;
    while (true) {
        const int leftSize = tr[tr[t].lch].siz;
        if (k == leftSize + 1) {
            res = tr[t].v;
            return true;
        }
        if (k <= leftSize) {
            t = tr[t].lch;
        } else {
            k -= leftSize + 1;
            t = tr[t].rch;
        }
    }
}
int Q;  // number of operations still to process
// Seeds rand() (used for treap priorities) and reads the operation count.
void init(){
    const unsigned seed = static_cast<unsigned>(time(NULL));
    srand(seed);
    Q = read();
}
// Processes the Q operations from stdin. Commands are told apart by a few
// characters: o[0] ('i' / 'd' / 'f'), whether the word ends right after
// "delete" / "find", and o[7] for the delete_* variants.
void solve(){
    char o[30];
    while (Q--){
        // %29s keeps the word inside o[]. On truncated input, stop rather
        // than re-dispatching a stale (or, on the first pass, uninitialized)
        // buffer.
        if (scanf("%29s", o) != 1) break;
        if (o[0] == 'i'){
            int k = read();
            Insert(k);
        }
        if (o[0] == 'd'){
            if (o[6] == '\0'){
                // "delete": remove one copy, but only if it is present.
                int k = read();
                if (Lookup(k))
                    Del(k);
            }else {
                if (o[7] == 'l'){
                    int k = read();
                    Del_less(k);
                }
                if (o[7] == 'g'){
                    int k = read();
                    Del_greater(k);
                }
                if (o[7] == 'i'){
                    int l = read(), r = read();
                    Del_interval(l, r);
                }
            }
        }
        if (o[0] == 'f'){
            if (o[4] == '\0'){
                // "find"
                int k = read();
                printf("%s\n", (Lookup(k) ? "Y": "N"));
            }else {
                // "find_ith"
                int k = read(), res;
                if (Kth(k, res)) printf("%d\n", res);
                else printf("N\n");
            }
        }
    }
}
// Entry point: init() reads the operation count and seeds rand(); solve()
// then executes every operation.
int main(){
    init();
    solve();
    return 0;
}

Zsi-r's solution

#include <iostream>
#include <cstring>

using namespace std;

// Node of the unbalanced BST: one stored value plus left/right child links.
struct bstnode
{
    int data;
    bstnode *left, *right;
    bstnode(int d, bstnode *l = NULL, bstnode *r = NULL) : data(d), left(l), right(r){};
    // Default construction used to leave every member indeterminate; give it
    // a well-defined empty state instead.
    bstnode() : data(0), left(NULL), right(NULL){};
};

// Entry of the explicit stack used by bst::find_ith: a tree node plus the
// number of times it has been popped so far (0 when first pushed).
struct StNode{
    bstnode *node;
    int TimesPop;

    StNode(bstnode *n = NULL) : node(n), TimesPop(0) {}
};

// One cell of the linked-list stack: a StNode payload and a link to the
// cell beneath it (NULL at the bottom).
struct stacknode
{
    StNode data;
    stacknode *next;

    stacknode() : next(NULL) {}
    stacknode(const StNode d, stacknode *n = NULL) : data(d), next(n) {}
    ~stacknode() {}
};


class stack
{
private:
    stacknode *top_p;

public:
    stack() { top_p = NULL; }
    ~stack()
    {
        stacknode *temp = top_p;
        while(top_p!=NULL){
            top_p = top_p->next;
            delete temp;
            temp = top_p;
        }
    }
    bool isempty() { return top_p == NULL; }
    void push(const StNode&x)
    {
        top_p = new stacknode(x, top_p);
    }
    StNode pop()
    {
        stacknode *temp = top_p;
        StNode value = temp->data;
        top_p = top_p->next;
        delete temp;
        return value;
    }
};


class bst
{
public:
    bstnode *root;

    void insert(int d,bstnode* &n)
    {
        if (n == NULL)
            n = new bstnode(d);
        else if (n->data<d)
            insert(d, n->right);
        else if (d<=n->data)
            insert(d, n->left);
    }
    void remove(int d,bstnode* &n)
    {
        if (n==NULL)
            return;
        else if (n->data>d)
            remove(d, n->left);
        else if (n->data<d)
            remove(d, n->right);
        else if (n->left!=NULL&&n->right!=NULL)
        {
            bstnode *temp = n->right;
            while (temp->left!=NULL)
                temp = temp->left;
            n->data = temp->data;
            remove(n->data,n->right);
        }
        else
        {
            bstnode *temp = n;
            n = (n->left != NULL) ? n->left : n->right;
            delete temp;
        }
    }
    void find (int x,bstnode *n) const
    {
        if (n==NULL)
        {
            cout << 'N' << endl;
            return;
        }
        if(n->data==x)
        {
            cout << 'Y' << endl;
            return;
        }
        else if (n->data<x)
            return find(x, n->right);
        else if (n->data>x)
            return find(x, n->left);
    }

public:
    bst() { root = NULL; }
    ~bst(){};

    void find(int x)
    {
        find(x, root);
    }

    void find_ith(int i)
    {
        int count = 0;
        bool flag = false;
        stack s;
        StNode current(root);
        s.push(current);
        while(!s.isempty())
        {
            current = s.pop();
            if (++current.TimesPop==2){
                count++;
                if (count == i)
                {
                    flag = true;
                    cout << current.node->data << endl;
                    return;
                }
                if (current.node->right!=NULL)
                    s.push(StNode(current.node->right));
            }
            else
            {
                s.push(current);
                if (current.node->left!=NULL)
                    s.push(StNode(current.node->left));
            }
        }
        if (!flag)
            cout << 'N' << endl;
    }

    void delete_greater_than(int x,bstnode *&n)
    {   
        if (n==NULL)
            return;
        if (n->data<=x)
            delete_greater_than(x, n->right);
        else if (n->data>x)
        {
            delete_greater_than(x, n->right);
            delete_greater_than(x, n->left);
            remove(n->data, n);
        }
    }

    void delete_less_than(int x,bstnode *&n)
    {
        if (n==NULL)
            return;
        delete_less_than(x, n->left);
        delete_less_than(x, n->right);
        if (n->data<x)
            remove(n->data, n);
    }

    void delete_interval(int a ,int b,bstnode *&n)
    {
        if (n==NULL)
            return;
        delete_interval(a, b, n->left);
        delete_interval(a, b, n->right);
        if (n->data>a && n->data<b)
            remove(n->data, n);
    }
};

int main()
{
    // Only iostreams are used here, so unsyncing from C stdio is safe and
    // speeds up reading up to 300000 operations.
    ios::sync_with_stdio(false);
    cin.tie(NULL);
    int n,num,a,b;
    char s[20];
    bst tree;
    cin >> n;
    for (int i = 0; i < n;i++)
    {
        // Bound the word to the 20-byte buffer (the longest command,
        // "delete_greater_than", needs exactly 20). Stop on end of input
        // instead of re-running the previous command.
        cin.width(sizeof s);
        if (!(cin >> s))
            break;
        if(strcmp(s,"insert")==0){
            cin >> num;
            tree.insert(num, tree.root);
        }
        else if (strcmp(s,"delete")==0)
        {
            cin >> num;
            tree.remove(num, tree.root);
        }
        else if (strcmp(s,"find")==0)
        {
            cin >> num;
            tree.find(num);
        }
        else if  (strcmp(s,"find_ith")==0)
        {
            cin >> num;
            tree.find_ith(num);
        }
        else if (strcmp(s,"delete_greater_than")==0)
        {
            cin >> num;
            tree.delete_greater_than(num, tree.root);
        }
        else if (strcmp(s,"delete_less_than")==0)
        {
            cin >> num;
            tree.delete_less_than(num, tree.root);
        }
        else if (strcmp(s,"delete_interval")==0)
        {
            cin >> a >> b;
            tree.delete_interval(a, b, tree.root);
        }
    }
    return 0;
}