跳转至

康托展开

定义

求一个 \(n\) 排列的排位

正康托展开

模版
C++
using i64 = long long;
const i64 P = 998244353;
i64 fact[1000005]; // 需要预处理阶乘
i64 cantor(const vector<int> &p) {
    int n = p.size();
    i64 rk = 1;
    vector<int> cnt(n + 1);
    for (int i = 0 ; i < n ; i++) {
        cnt[p[i]]++;
    }
    for (int i = 0 ; i < n ; i++) {
        for (int j = 0 ; j < p[i] ; j++) {
            if (cnt[j] == 0) continue;
            rk = (rk + fact[n - 1 - i]) % P;
        }
        cnt[p[i]]--;
    }
    return rk;
}
C++
using i64 = long long;
const i64 P = 998244353;
i64 fact[1000005]; // 需要预处理阶乘
i64 qpow(i64 a, i64 b = P - 2, i64 res = 1) {
    while (b) {
        if (b & 1) res = res * a % P;
        a = a * a % P;
        b >>= 1;
    }
    return res;
}
i64 cantor(const vector<int> &p) {
    int n = p.size();
    i64 rk = 1;
    vector<int> cnt(n + 1);
    for (int i = 0 ; i < n ; i++) {
        cnt[p[i]]++;
    }
    i64 cur = 1;
    for (int i = 1 ; i <= n ; i++) {
        cur = cur * cnt[i] % P;
    }
    for (int i = 0 ; i < n ; i++) {
        for (int j = 0 ; j < p[i] ; j++) {
            if (cnt[j] == 0) continue;
            cur = cur * qpow(fact[cnt[j]]);
            cnt[j]--;
            cur = cur * fact[cnt[j]];
            rk = (rk + fact[n - 1 - i] * qpow(cur) % P) % P;
            cur = cur * qpow(fact[cnt[j]]);
            cnt[j]++;
            cur = cur * fact[cnt[j]];
        }
        cur = cur * qpow(fact[cnt[p[i]]]);
        cnt[p[i]]--;
        cur = cur * fact[cnt[p[i]]];
    }
    return rk;
}
C++
int lowbit(const int x) {
    return x & -x;
}
template<class T>
struct FenwickTree {
    vector<T> sum;
    int size;
    FenwickTree() {}
    FenwickTree(int n) {
        resize(n);
    }
    void resize(int n) {
        sum.resize(n + 1);
        size = n;
    }
    void clear() {
        sum.resize(0);
        sum.resize(size + 1);
    }
    T query(int x) {
        T res = 0;
        while (x) {
            res += sum[x];
            x -= lowbit(x);
        }
        return res;
    }
    T query(int L, int R) {
        return query(R) - query(L - 1);
    }
    void add(int x, T k) {
        while (x <= size) {
            sum[x] += k;
            x += lowbit(x);
        }
    }
}; // FenwickTree
using i64 = long long;
const i64 P = 998244353;
i64 fact[1000005]; // 需要预处理阶乘
i64 cantor(const vector<int> &p) {
    int n = p.size();
    i64 rk = 1;
    FenwickTree<int> cnt(n + 1);
    for (int i = 0 ; i < n ; i++) {
        cnt.add(p[i], 1);
    }
    for (int i = 0 ; i < n ; i++) {
        rk = (rk + fact[n - 1 - i] * cnt.query(p[i] - 1) % P) % P;
        cnt.add(p[i], -1);
    }
    return rk;
}

逆康托展开

模版(均为不可重复)
C++
using i64 = long long;
const i64 P = 998244353;
i64 fact[1000005]; // 需要预处理阶乘
vector<int> decantor(vector<int> p, i64 rk) {
    rk--;
    int n = p.size();
    vector<int> a(n), b(n);
    sort(begin(p), end(p));
    for (int i = 0 ; i < n ; i++) {
        a[i] = rk / fact[n - i - 1];
        rk %= fact[n - i - 1];
    }
    for (int i = 0 ; i < n ; i++) {
        b[i] = p[a[i]];
        p.erase(lower_bound(begin(p), end(p), b[i]));
    }
    return b;
}
C++
using i64 = long long;
const i64 P = 998244353;
i64 fact[1000005]; // 需要预处理阶乘
vector<int> decantor(const vector<int> &p, i64 rk) {
    rk--;
    int n = p.size();
    Treap<int> treap(n + 1);
    vector<int> a(n), b(n);
    for (int i = 0 ; i < n ; i++) {
        treap.insert(p[i]);
    }
    for (int i = 0 ; i < n ; i++) {
        a[i] = rk / fact[n - i - 1];
        rk %= fact[n - i - 1];
    }
    for (int i = 0 ; i < n ; i++) {
        b[i] = treap.kth(a[i] + 1);
        treap.erase(b[i]);
    }
    return b;
}
例题

P3014 [USACO11FEB] Cow Line S - 洛谷

查找第一个大于等于的排列

代码
C++
// 查找第一个大于等于 sum 的 f(x) 的排列
// cnt[i] 为 f(x) 中 i 的个数
vector<int> getUpper(i64 sum) {
    vector<int> num;
    i64 temp = sum;
    while (temp) {
        num.push_back(temp % 10);
        temp /= 10;
    }
    reverse(begin(num), end(num));
    array<int, 10> cur = cnt;
    int pos = -1; // 最长连续能变大的数位的下标
    for (int i = 0 ; i <= (int) num.size() ; i++) {
        if (i < (int) num.size()) {
            bool flag = false;
            for (int j = num[i] + 1 ; j <= 9 ; j++) {
                if (cur[j]) { // f(x) 排列中有比第 i 位大的数
                    flag = true;
                    break;
                }
            }
            if (flag) {
                pos = i; // 记录下标
            }
            // f(x) 排列中必须两个数都有才能替换
            if (cur[num[i]] == 0) break;
            cur[num[i]]--;
        } else {
            pos = i;
        }
    }
    if (pos == -1) return {};
    vector<int> res(num.size());
    cur = cnt;
    for (int i = 0 ; i < (int) num.size() ; i++) {
        if (i < pos) {
            res[i] = num[i];
        } else if (i == pos) {
            for (int j = num[i] + 1 ; j <= 9 ; j++) {
                if (cur[j]) {
                    res[i] = j;
                    break;
                }
            }
        } else {
            for (int j = 0 ; j <= 9 ; j++) {
                if (cur[j]) {
                    res[i] = j;
                    break;
                }
            }
        }
        cur[res[i]]--;
    }
    return res;
}

查找第一个小于等于的排列

代码
C++
// 查找第一个小于等于 sum 的 f(x) 的排列 
// cnt[i] 为 f(x) 中 i 的个数
vector<int> getLower(i64 sum) {
    vector<int> num;
    i64 temp = sum;
    while (temp) {
        num.push_back(temp % 10);
        temp /= 10;
    }
    reverse(begin(num), end(num));
    array<int, 10> cur = cnt;
    int pos = -1; // 最长连续能变小的数位的下标
    for (int i = 0 ; i <= (int) num.size() ; i++) {
        if (i < (int) num.size()) {
            bool flag = false;
            for (int j = num[i] - 1 ; j >= 0 ; j--) {
                if (cur[j]) { // f(x) 排列中有比第 i 位小的数
                    flag = true;
                    break;
                }
            }
            if (flag) {
                pos = i; // 记录下标
            }
            // f(x) 排列中必须两个数都有才能替换
            if (cur[num[i]] == 0) break;
            cur[num[i]]--;
        } else {
            pos = i;
        }
    }
    if (pos == -1) return {};
    vector<int> res(num.size());
    cur = cnt;
    for (int i = 0 ; i < (int) num.size() ; i++) {
        if (i < pos) {
            res[i] = num[i];
        } else if (i == pos) {
            for (int j = num[i] - 1 ; j >= 0 ; j--) {
                if (cur[j]) {
                    res[i] = j;
                    break;
                }
            }
        } else {
            for (int j = 9 ; j >= 0 ; j--) {
                if (cur[j]) {
                    res[i] = j;
                    break;
                }
            }
        }
        cur[res[i]]--;
    }
    return res;
}