跳转至

Treap

模板
C++
namespace _Treap {
    /**
     * One slot of the node pool used by Treap.
     *
     * Children are stored as indices into the pool rather than pointers;
     * index 0 is the shared "no node" slot, whose size stays 0.
     */
    template<class T>
    struct Info {
        int L = 0, R = 0;   // pool indices of the left / right child, 0 = no child
        int rnd = 0;        // random priority; insert() rotates the smaller priority upwards
        int cnt = 0;        // number of copies of val represented by this node
        int size = 0;       // sum of cnt over the subtree rooted at this node
        T val{};            // key stored in this node
        Info() = default;
        /// Builds a leaf holding a single copy of `val` with a fresh random priority.
        Info(T val) : L(0), R(0), rnd(std::rand()), cnt(1), size(1), val(val) {}
    };

    /**
     * Rotation-based treap over a vector-backed node pool. Equal keys share
     * one node and are counted through Info::cnt.
     *
     *   insert(T x)  insert one copy of x
     *   erase(T x)   remove one copy of x; returns whether a copy was found
     *   rank(T x)    1 + number of stored elements strictly smaller than x
     *   kth(int x)   x-th smallest element (1-based); T{} when out of range
     *   prev(T x)    largest element  < x; value of the null slot when none
     *   next(T x)    smallest element > x; value of the null slot when none
     *
     * Slots released by erase() are not reused.
     */
    template<class T>
    struct Treap {
        std::vector<Info<T>> info;  // node pool; slot 0 is the null node
        int tot, root;              // tot: last slot handed out, root: pool index of the root

        Treap() {
            init(0);
        }
        /// Pre-sizes the pool for N nodes so that no reallocation happens for the first N keys.
        Treap(int N) {
            init(N);
        }
        /// Resets the treap to empty and makes room for N nodes plus the null slot.
        void init(int N) {
            tot = root = 0;
            info.resize(N + 1);
        }
        /// Recomputes the subtree size of p from its children and its own cnt.
        void push(int p) {
            info[p].size = info[info[p].L].size + info[info[p].R].size + info[p].cnt;
        }
        /// Left rotation: the right child of p becomes the new subtree root (written back to p).
        void LRotate(int &p) {
            const int up = info[p].R;
            info[p].R = info[up].L;
            info[up].L = p;
            info[up].size = info[p].size;  // the new top covers the same nodes the old top did
            push(p);
            p = up;
        }
        /// Right rotation: the left child of p becomes the new subtree root (written back to p).
        void RRotate(int &p) {
            const int up = info[p].L;
            info[p].L = info[up].R;
            info[up].R = p;
            info[up].size = info[p].size;  // the new top covers the same nodes the old top did
            push(p);
            p = up;
        }
        /// Takes the next free slot, stores a fresh leaf for x in it and writes its index to p.
        void build(int &p, T x) {
            p = ++tot;
            assert(p < (int) info.size());
            info[p] = Info<T>(x);
        }
        /// Inserts one copy of x, growing the pool first when it has no free slot left.
        void insert(T x) {
            // The recursive overload passes references to pool elements (info[p].L / info[p].R)
            // down the call chain, so the vector must not reallocate while it runs.
            // One insert creates at most one node, hence growing up front is sufficient.
            if (tot + 1 >= static_cast<int>(info.size())) {
                info.resize(std::max<std::size_t>(2 * info.size(), 2));
            }
            insert(root, x);
        }
        /**
         * Inserts one copy of x into the subtree whose root index is stored in p.
         * Requires a free pool slot (checked by the assert in build()); the
         * single-argument overload guarantees that.
         */
        void insert(int &p, T x) {
            if (p == 0) {
                build(p, x);
                return;
            }
            info[p].size++;
            if (info[p].val == x) {
                info[p].cnt++;
            } else if (info[p].val < x) {
                insert(info[p].R, x);
                if (info[info[p].R].rnd < info[p].rnd) {
                    LRotate(p);
                }
            } else {
                insert(info[p].L, x);
                if (info[info[p].L].rnd < info[p].rnd) {
                    RRotate(p);
                }
            }
        }
        /// Removes one copy of x. Returns false (and changes nothing) when x is not stored.
        bool erase(T x) {
            return erase(root, x);
        }
        /// Removes one copy of x from the subtree whose root index is stored in p.
        bool erase(int &p, T x) {
            if (p == 0) {
                return false;
            }
            if (info[p].val == x) {
                if (info[p].cnt > 1) {
                    info[p].cnt--;
                    info[p].size--;
                    return true;
                }
                if (info[p].L == 0 || info[p].R == 0) {
                    // At most one child is non-zero, so the sum is that child (or 0).
                    p = info[p].L + info[p].R;
                    return true;
                }
                // Two children: rotate the child with the smaller priority above the node,
                // then retry from the new top; the node to delete is now one level lower.
                if (info[info[p].L].rnd < info[info[p].R].rnd) {
                    RRotate(p);
                } else {
                    LRotate(p);
                }
                return erase(p, x);
            }
            const bool removed = info[p].val < x ? erase(info[p].R, x) : erase(info[p].L, x);
            if (removed) {
                info[p].size--;
            }
            return removed;
        }
        /// Returns 1 + the number of stored elements strictly smaller than x.
        int rank(T x) {
            return rank(root, x);
        }
        int rank(int p, T x) {
            // An empty subtree contributes the "+1" so that absent keys are ranked
            // the same way as present ones (number of smaller elements + 1).
            if (p == 0) return 1;
            if (info[p].val == x) {
                return info[info[p].L].size + 1;
            } else if (info[p].val < x) {
                return info[info[p].L].size + info[p].cnt + rank(info[p].R, x);
            } else {
                return rank(info[p].L, x);
            }
        }
        /// Returns the x-th smallest element (1-based), or T{} when x is out of range.
        T kth(int x) {
            return kth(root, x);
        }
        T kth(int p, int x) {
            if (p == 0) return T{};
            const int left = info[info[p].L].size;
            if (x <= left) {
                return kth(info[p].L, x);
            }
            if (x > left + info[p].cnt) {
                return kth(info[p].R, x - left - info[p].cnt);
            }
            return info[p].val;
        }
        /// Returns the largest element smaller than x, or the null slot's value when none exists.
        T prev(T x) {
            const int p = std::max(prev(root, x), 0);
            return info[p].val;
        }
        /// Pool index of the largest element < x inside subtree p, or -1 when there is none.
        int prev(int p, T x) {
            if (p == 0) return -1;
            if (info[p].val < x) {
                const int deeper = prev(info[p].R, x);
                return deeper != -1 ? deeper : p;
            }
            return prev(info[p].L, x);
        }
        /// Returns the smallest element greater than x, or the null slot's value when none exists.
        T next(T x) {
            // Clamp the "not found" result to the null slot instead of indexing info[-1].
            const int p = std::max(next(root, x), 0);
            return info[p].val;
        }
        /// Pool index of the smallest element > x inside subtree p, or -1 when there is none.
        int next(int p, T x) {
            if (p == 0) return -1;
            if (info[p].val > x) {
                const int deeper = next(info[p].L, x);
                return deeper != -1 ? deeper : p;
            }
            return next(info[p].R, x);
        }
        /// Direct access to the key stored in pool slot p.
        T& operator[](int p) {
            return info[p].val;
        }
    }; // Treap
}
using _Treap::Treap;
C++
namespace _Treap {
    /**
     * Heap-allocated treap node. Equal keys share one node and are counted
     * through cnt.
     */
    template<class T>
    struct Node {
        Node *child[2];      // child[0]: left subtree, child[1]: right subtree
        T val;               // key stored in this node
        int rnd, cnt, size;  // random priority, multiplicity, sum of cnt over the subtree
        Node(T val) : val(val) {
            cnt = size = 1;
            rnd = std::rand();
            child[0] = child[1] = nullptr;
        }
        /// Recomputes size from cnt and the sizes of the existing children.
        void push() {
            size = cnt;
            for (Node *c : child) {
                if (c != nullptr) {
                    size += c->size;
                }
            }
        }
    };

    /**
     * Split/merge treap.
     *
     *   insert(T x)    insert one copy of x
     *   erase(T x)     remove one copy of x (no-op when x is not stored)
     *   rank(T x)      1 + number of stored elements strictly smaller than x
     *   kth(int x)     x-th smallest element (1-based); T{} when out of range
     *   prev(T x)      largest element  < x; T{} when none exists
     *   next(T x)      smallest element > x; T{} when none exists
     *
     * Ownership: nodes are created with `new` in insert() and released with
     * `delete` only in erase(); the struct defines no destructor, so nodes
     * still in the tree are not freed when a Treap goes out of scope.
     */
    template<class T>
    struct Treap {
        Node<T> *root = nullptr;

        /// Splits `cur` into (keys <= key, keys > key).
        std::pair<Node<T>*, Node<T>*> split(Node<T> *cur, T key) {
            if (cur == nullptr) return {nullptr, nullptr};
            if (cur->val <= key) {
                auto rest = split(cur->child[1], key);
                cur->child[1] = rest.first;
                cur->push();
                return {cur, rest.second};
            }
            auto rest = split(cur->child[0], key);
            cur->child[0] = rest.second;
            cur->push();
            return {rest.first, cur};
        }
        /**
         * Splits `cur` into (keys < key, keys >= key).
         * Used instead of split(cur, key - 1), which overflows at the smallest
         * representable key and is wrong for non-integral T.
         */
        std::pair<Node<T>*, Node<T>*> splitLess(Node<T> *cur, const T &key) {
            if (cur == nullptr) return {nullptr, nullptr};
            // "val < key" written with <= so that T needs no operator beyond what split() uses.
            if (!(key <= cur->val)) {
                auto rest = splitLess(cur->child[1], key);
                cur->child[1] = rest.first;
                cur->push();
                return {cur, rest.second};
            }
            auto rest = splitLess(cur->child[0], key);
            cur->child[0] = rest.second;
            cur->push();
            return {rest.first, cur};
        }
        /**
         * Splits `cur` into (elements before the node covering position `rank`,
         * that node alone, elements after it). The middle pointer is null when
         * `rank` lies outside 1..size; the nodes then all end up in the outer parts.
         */
        std::tuple<Node<T>*, Node<T>*, Node<T>*> splitByRank(Node<T> *cur, int rank) {
            if (cur == nullptr) return {nullptr, nullptr, nullptr};
            const int ls_size = cur->child[0] == nullptr ? 0 : cur->child[0]->size;
            if (rank <= ls_size) {
                auto [L, mid, R] = splitByRank(cur->child[0], rank);
                cur->child[0] = R;
                cur->push();
                return {L, mid, cur};
            }
            if (rank <= ls_size + cur->cnt) {
                Node<T> *lt = cur->child[0], *rt = cur->child[1];
                cur->child[0] = cur->child[1] = nullptr;
                cur->push();  // the detached node no longer covers its former subtrees
                return {lt, cur, rt};
            }
            auto [L, mid, R] = splitByRank(cur->child[1], rank - ls_size - cur->cnt);
            cur->child[1] = L;
            cur->push();
            return {cur, mid, R};
        }
        /// Joins two treaps where every key of `a` precedes every key of `b`.
        Node<T>* merge(Node<T> *a, Node<T> *b) {
            if (a == nullptr) return b;
            if (b == nullptr) return a;
            if (a->rnd < b->rnd) {
                a->child[1] = merge(a->child[1], b);
                a->push();
                return a;
            }
            b->child[0] = merge(a, b->child[0]);
            b->push();
            return b;
        }
        /// Inserts one copy of x: bumps cnt when x is already stored, otherwise adds a node.
        void insert(T x) {
            auto temp = split(root, x);              // first: <= x, second: > x
            auto l_tr = splitLess(temp.first, x);    // first: <  x, second: == x
            Node<T> *same = l_tr.second;
            if (same == nullptr) {
                same = new Node<T>(x);
            } else {
                same->cnt++;
                same->push();
            }
            root = merge(merge(l_tr.first, same), temp.second);
        }
        /// Removes one copy of x. Leaves the treap unchanged when x is not stored.
        void erase(T x) {
            auto temp = split(root, x);              // first: <= x, second: > x
            auto l_tr = splitLess(temp.first, x);    // first: <  x, second: == x
            Node<T> *same = l_tr.second;
            if (same != nullptr) {
                if (same->cnt > 1) {
                    same->cnt--;
                    same->push();
                } else {
                    // insert() keeps a single node per key, so `same` has no children here.
                    delete same;
                    same = nullptr;
                }
            }
            root = merge(merge(l_tr.first, same), temp.second);
        }
        /// Returns 1 + the number of stored elements strictly smaller than val.
        int rank(T val) {
            return rank(root, val);
        }
        /// Same as rank(val) but on the treap referenced by `cur`, which is re-assembled in place.
        int rank(Node<T> *&cur, T val) {
            auto temp = splitLess(cur, val);
            const int res = (temp.first == nullptr ? 0 : temp.first->size) + 1;
            cur = merge(temp.first, temp.second);
            return res;
        }
        /// Returns the rank-th smallest element (1-based), or T{} when rank is out of range.
        T kth(int rank) {
            return kth(root, rank);
        }
        /// Same as kth(rank) but on the treap referenced by `cur`, which is re-assembled in place.
        T kth(Node<T>* &cur, int rank) {
            auto [L, mid, R] = splitByRank(cur, rank);
            T res = mid == nullptr ? T{} : mid->val;
            // Write the result back through `cur`: callers such as prev()/next() pass a
            // detached half of the tree and merge it again afterwards.
            cur = merge(merge(L, mid), R);
            return res;
        }
        /// Returns the largest element smaller than val, or T{} when there is none.
        T prev(T val) {
            auto temp = splitLess(root, val);
            T res = temp.first == nullptr ? T{} : kth(temp.first, temp.first->size);
            root = merge(temp.first, temp.second);
            return res;
        }
        /// Returns the smallest element greater than val, or T{} when there is none.
        T next(T val) {
            auto temp = split(root, val);
            T res = temp.second == nullptr ? T{} : kth(temp.second, 1);
            root = merge(temp.first, temp.second);
            return res;
        }
    }; // Treap
}
using _Treap::Treap;
例题

#105. 文艺平衡树 - LibreOJ

代码
C++
namespace _Treap {
    /**
     * Node of the sequence treap. `rev` is a pending "reverse this subtree"
     * flag that is applied lazily by check().
     */
    template<class T>
    struct Node {
        Node *child[2];      // child[0]: left subtree, child[1]: right subtree
        T val;               // value printed for this node
        int rnd, cnt, size;  // random priority, multiplicity, sum of cnt over the subtree
        bool rev;            // true while a reversal of this subtree is still pending
        Node(const T& val) : val(val), rnd(std::rand()), cnt(1), size(1), rev(false) {
            child[0] = child[1] = nullptr;
        }
        /// Recomputes size from cnt and the sizes of the existing children.
        void push() {
            size = cnt;
            for (Node *c : child) {
                if (c != nullptr) {
                    size += c->size;
                }
            }
        }
        /// Applies the pending reversal: swaps the children and hands the flag down to them.
        void down() {
            std::swap(child[0], child[1]);
            for (Node *c : child) {
                if (c != nullptr) {
                    c->rev ^= 1;
                }
            }
            rev = false;
        }
        /// Applies the pending reversal, if there is one.
        void check() {
            if (rev) {
                down();
            }
        }
    };

    /**
     * Implicit-position treap supporting lazy range reversal
     * (written for LibreOJ #105).
     *
     *   insert(T x)         see the note on insert()
     *   seg_rev(int L, R)   reverse positions L..R (1-based, inclusive)
     *   print()             write the sequence to std::cout, each value followed by a space
     *
     * Ownership: nodes are created with `new` and never deleted by this struct.
     */
    template<class T>
    struct Treap {
        Node<T> *root = nullptr;

        /**
         * Splits `cur` by position: the first part receives the leading `key`
         * nodes, the second part the rest. `key` is a count even though it is
         * typed as T. Every visited node counts as one position.
         */
        std::pair<Node<T>*, Node<T>*> split(Node<T> *cur, T key) {
            if (cur == nullptr) return {nullptr, nullptr};
            cur->check();  // children must be in their final order before descending
            const int sz = cur->child[0] == nullptr ? 0 : cur->child[0]->size;
            if (sz < key) {
                auto rest = split(cur->child[1], key - sz - 1);
                cur->child[1] = rest.first;
                cur->push();
                return {cur, rest.second};
            }
            auto rest = split(cur->child[0], key);
            cur->child[0] = rest.second;
            cur->push();
            return {rest.first, cur};
        }
        /// Concatenates two sequences: everything in `a` comes before everything in `b`.
        Node<T>* merge(Node<T> *a, Node<T> *b) {
            if (a == nullptr) return b;
            if (b == nullptr) return a;
            a->check();
            b->check();
            if (a->rnd < b->rnd) {
                a->child[1] = merge(a->child[1], b);
                a->push();
                return a;
            }
            b->child[0] = merge(a, b->child[0]);
            b->push();
            return b;
        }
        /**
         * Uses x both as the value and as the 1-based target position: when no
         * node sits at position x yet, a node holding x is created there (so
         * inserting 1, 2, ..., n in order builds the sequence 1..n). When
         * position x is already occupied, that node's cnt is incremented instead.
         * NOTE(review): split() counts every node as one position while size
         * includes cnt, so positions may stop matching after such an increment —
         * confirm callers only append.
         */
        void insert(T x) {
            auto temp = split(root, x);
            auto l_tr = split(temp.first, x - 1);
            Node<T> *at = l_tr.second;
            if (at == nullptr) {
                at = new Node<T>(x);
            } else {
                at->cnt++;
                at->push();
            }
            root = merge(merge(l_tr.first, at), temp.second);
        }
        /**
         * Reverses positions L..R (1-based, inclusive). The range is clipped to
         * the sequence; an empty range (R < L, or L past the end) is a no-op.
         */
        void seg_rev(int L, int R) {
            auto less = split(root, L - 1);
            auto more = split(less.second, R - L + 1);
            // more.first is null when the requested range selects nothing.
            if (more.first != nullptr) {
                more.first->rev ^= true;
            }
            root = merge(less.first, merge(more.first, more.second));
        }
        /// Prints the whole sequence in order to std::cout.
        void print() {
            print(root);
        }
        /// In-order print of the subtree `cur`, applying pending reversals on the way down.
        void print(Node<T>* cur) {
            if (cur == nullptr) {
                return;
            }
            cur->check();
            print(cur->child[0]);
            std::cout << cur->val << " ";
            print(cur->child[1]);
        }
    }; // Treap
}
using _Treap::Treap;