跳转至

整式递推

例题

P6115 【模板】整式递推 - 洛谷

C++
#include <bits/stdc++.h>
using namespace std;
using i64 = long long;
// Global whose initializer runs before main(): turns off C stdio synchronization and
// unties cin/cout so bulk stream I/O is fast. The stored value (always 0) is not
// referenced anywhere else in the code shown.
int __OneWan_2024 = []() -> int {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);
    return 0;
}();
using i64 = long long;
using Int = int;   // storage type for residues in [0, P)
using Long = i64;  // wide type for products of two residues
constexpr Long P = 998244353;    // NTT-friendly prime modulus: 119 * 2^23 + 1
constexpr int LEN2 = 20;         // log2 of the lookup-table / transform capacity used below
constexpr Long NTT_G = 3;        // primitive root of P
constexpr Long NTT_I = 86583718; // square root of P - 1, i.e. of -1 modulo P
// Computes a^b mod P (the global modulus) by binary exponentiation.
// The base is reduced into [0, P) first, so unreduced bases cannot overflow `a * a`
// and negative bases still yield a result in [0, P).
// b is expected to be non-negative; the loop stops (returning 1) instead of spinning
// forever on a negative exponent.
Long qpow(Long a, Long b) {
    a %= P;
    if (a < 0) {
        a += P;
    }
    Long res = 1;
    for (; b > 0; b >>= 1) {
        if (b & 1) res = res * a % P;
        a = a * a % P;
    }
    return res;
}
// Computes a^b mod P for an explicit modulus P (the parameter shadows the global P).
// Requires P > 0 and P * P to fit in Long. The base is reduced into [0, P) first, so
// unreduced or negative bases are handled. The result for b == 0 is 1 % P, which makes
// modulus 1 correctly give 0. b is expected to be non-negative.
Long qpow(Long a, Long b, const Long P) {
    a %= P;
    if (a < 0) {
        a += P;
    }
    Long res = 1 % P;
    for (; b > 0; b >>= 1) {
        if (b & 1) res = res * a % P;
        a = a * a % P;
    }
    return res;
}
// Modular addition under the global modulus P.
// Both operands are expected in [0, P); the result is then in [0, P).
Int add(Int a, const Int b) {
    const Int sum = a + b;
    return sum < P ? sum : Int(sum - P);
}
// Modular addition under an explicit modulus P (shadows the global P).
// Both operands are expected in [0, P); the result is then in [0, P).
Int add(Int a, const Int b, const Int P) {
    const Int sum = a + b;
    return sum < P ? sum : sum - P;
}
// Modular subtraction under the global modulus P.
// Both operands are expected in [0, P); the result is then in [0, P).
Int sub(Int a, const Int b) {
    const Int diff = a - b;
    return diff >= 0 ? diff : Int(diff + P);
}
// Modular subtraction under an explicit modulus P (shadows the global P).
// Both operands are expected in [0, P); the result is then in [0, P).
Int sub(Int a, const Int b, const Int P) {
    const Int diff = a - b;
    return diff >= 0 ? diff : diff + P;
}
// A std::vector<T> under another name, forwarding the common constructors.
// In this file it backs the NTT lookup tables.
// The iterator types are dependent names, so `typename` is required for this to
// compile before C++20.
template<typename T>
struct Poly : std::vector<T> {
    Poly() = default;
    Poly(int n) : std::vector<T>(n) {}
    Poly(int n, T x) : std::vector<T>(n, x) {}
    Poly(typename std::vector<T>::iterator s, typename std::vector<T>::iterator t) : std::vector<T>(s, t) {}
    Poly(typename std::vector<T>::const_iterator s, typename std::vector<T>::const_iterator t) : std::vector<T>(s, t) {}
    Poly(std::initializer_list<T> lst) : std::vector<T>(lst) {}
};
namespace Polynomial {
    // Returns the smallest power of two that is >= x, or 1 when x <= 1. Used to pick an
    // NTT length. x must not exceed 2^30, or the doubling would overflow int.
    // This is written as a loop because the previous __lg(x - 1) form hit
    // __builtin_clz(0), which is undefined, for x == 1.
    int norm(int x) {
        int len = 1;
        while (len < x) {
            len <<= 1;
        }
        return len;
    }
}
using namespace Polynomial;
namespace NTT {
    const Long g = NTT_G; // P 的原根
    const Long I = NTT_I; // 根号 P - 1
    Poly<Long> W;
    Poly<Long> __inv;
    Long inv(Long x) {
        if (__inv.empty()) {
            int len = 1 << LEN2;
            __inv.resize(len);
            __inv[0] = __inv[1] = 1;
            for (int i = 2 ; i < len ; i++) {
                __inv[i] = Long(P - P / i) * __inv[P % i] % P;
            }
        }
        if (x < (int) __inv.size()) {
            return __inv[x];
        }
        return qpow(x, P - 2);
    }
    Long inv(Long x, Long P) {
        return qpow(x, P - 2, P);
    }
    void DIF(Int *a, int n) {
        if (W.empty()) {
            int L = 1 << LEN2;
            W.resize(L);
            Int wn = qpow(g, P / L);
            W[L >> 1] = 1;
            for (int i = L / 2 + 1 ; i < L ; i++) {
                W[i] = Long(W[i - 1]) * wn % P;
            }
            for (int i = L / 2 - 1 ; i >= 1 ; i--) {
                W[i] = W[i << 1];
            }
        }
        for (int k = n >> 1 ; k ; k >>= 1) {
            for (int i = 0 ; i < n ; i += k << 1) {
                for (int j = 0 ; j < k ; j++) {
                    Int x = a[i + j], y = a[i + j + k];
                    a[i + j] = add(a[i + j], y);
                    a[i + j + k] = Long(sub(x, y)) * W[j + k] % P;
                }
            }
        }
    }
    void IDIT(Int *a, int n, bool flag = true) {
        for (int k = 1 ; k < n ; k <<= 1) {
            for (int i = 0 ; i < n ; i += k << 1) {
                for (int j = 0 ; j < k ; j++) {
                    Int x = a[i + j], y = Long(a[i + j + k]) * W[j + k] % P;
                    a[i + j] = add(x, y);
                    a[i + j + k] = sub(x, y);
                }
            }
        }
        if (flag) {
            Int inv = P - (P - 1) / n;
            for (int i = 0 ; i < n ; i++) {
                a[i] = Long(a[i]) * inv % P;
            }
        }
        reverse(a + 1, a + n);
    }
}
namespace recursive {
    int M, D; // degree and order; not read anywhere in the code shown (recursive() declares locals named M and D)

    // Dense n x n matrix over Z/PZ, stored row-major in v (entry (r, c) is v[r * n + c]).
    struct Matrix {
        int n;
        vector<Int> v;
        Matrix() = default;
        Matrix(int n) : n(n), v(n * n) {}
        // Schoolbook product x * y. The summation index j is the middle loop, so the
        // innermost loop walks a row of y and a row of res contiguously.
        friend Matrix operator*(const Matrix &x, const Matrix &y) {
            assert(x.n == y.n);
            int n = x.n;
            Matrix res(n);
            for (int i = 0 ; i < n ; i++) {
                int I = i * n;
                for (int j = 0 ; j < n ; j++) {
                    int J =  j * n;
                    for (int k = 0 ; k < n ; k++) {
                        res.v[I + k] = add(res.v[I + k], Long(x.v[I + j]) * y.v[J + k] % P);
                    }
                }
            }
            return res;
        }
        // n x n identity matrix.
        static Matrix E(int n) {
            Matrix e(n);
            for (int i = 0 ; i < n ; i++) {
                e.v[i * n + i] = 1;
            }
            return e;
        }
    };

    // Minimal non-owning (pointer, length) view over a contiguous array, like a tiny std::span.
    template<typename T>
    struct spn {
        size_t len;
        T *data;
        spn() = default;
        spn(T *data, size_t len) : len(len), data(data){}
        T& operator[](const int pos) {
            return data[pos];
        }
        size_t size() {
            return len;
        }
        // View of up to sublen elements starting at start, clamped to the end of this view.
        // NOTE(review): start must not exceed len, otherwise len - start wraps around.
        spn<T> subspan(size_t start, size_t sublen) {
            return spn<T>(data + start, min(len - start, sublen));
        }
    };

    const int N = 1 << LEN2; // capacity of the factorial tables and of lagrange()'s scratch buffers
    Int fact[N + 1], invfact[N + 1];

    // Fills fact[i] = i! and invfact[i] = 1 / i! (mod P) for 0 <= i <= N.
    // Must run before lagrange(), which reads invfact.
    void init() {
        fact[0] = 1;
        for (int i = 1 ; i <= N ; i++) {
            fact[i] = Long(fact[i - 1]) * i % P;
        }
        invfact[N] = qpow(fact[N], P - 2);
        for (int i = N ; i >= 1 ; i--) {
            invfact[i - 1] = Long(invfact[i]) * i % P;
        }
    }

    bool XXXXXXX = false; // not referenced anywhere in the code shown

    // Lagrange shift of sampled values.
    // f[i] (matrices) and g[i] (scalars) are the values of polynomials at the sample
    // points i = 0 .. n - 1, with n = f.size(). The function writes their values at
    // o, o + 1, ..., o + m - 1 into outf / outg, with m = outf.size().
    // Each matrix entry is interpolated independently.
    // Requested points that are themselves samples (o < n) are copied directly.
    // The rest use
    //   f(x) = prod_{j<n} (x - j) * sum_i a_i f_i / (x - i),  a_i = (-1)^(n-1-i) / (i! (n-1-i)!),
    // where the sum over i is a convolution evaluated with one NTT per entry.
    // Assumptions:
    //   - outg is at least as long as outf;
    //   - every outf[i] is already an md x md matrix;
    //   - 2n + m - 1 <= N;
    //   - the interpolated points do not coincide with samples modulo P.
    // Uses static scratch buffers, so it is not reentrant.
    void lagrange(spn<const Matrix> f, spn<const Int> g, spn<Matrix> outf, spn<Int> outg, int o) {
        static Int a[N], b[N], c[N], d[N];
        int n = f.size(), m = outf.size();
        while (o < n && m) {
            outf[0] = f[o];
            outg[0] = g[o];
            outf = outf.subspan(1, m - 1);
            outg = outg.subspan(1, m - 1);
            m--;
            o++;
        }
        if (m == 0) {
            return;
        }
        // Length of the full linear convolution of a (n terms) with b (n + m terms),
        // so the cyclic NTT product does not wrap around.
        int L = norm(2 * n + m - 1);
        // a[i] = (-1)^(n-1-i) / (i! (n-1-i)!): the Lagrange basis denominators.
        for (int i = 0 ; i < n ; i++) {
            a[i] = Long(invfact[i]) * invfact[n - 1 - i] % P;
            if ((n - 1 - i) & 1) {
                a[i] = sub(P, a[i]);
            }
        }
        // b[t] = 1 / (o - (n - 1) + t), so the convolution at index n - 1 + r is
        // sum_i a[i] f[i] / (o + r - i).
        for (int i = 0 ; i < n + m ; i++) {
            b[i] = NTT::inv(add(sub(o, n - 1), i));
        }
        fill(b + n + m, b + L, 0);
        // c[r] = prod_{j<n} (o + r - j), built incrementally:
        // c[r] = c[r - 1] * (o + r) / (o + r - n), where b[r - 1] = 1 / (o + r - n).
        c[0] = 1;
        for (int i = 0 ; i < n ; i++) {
            c[0] = Long(c[0]) * sub(o, i) % P;
        }
        for (int i = 1 ; i < m ; i++) {
            c[i] = Long(c[i - 1]) * b[i - 1] % P * add(o, i) % P;
        }
        NTT::DIF(b, L);
        // IDIT is called with flag == false below, so the 1/L factor is applied here.
        // It is hoisted out of the per-entry loops.
        const Long invL = NTT::inv(L);
        int md = f[0].n;
        for (int p = 0 ; p < md ; p++) {
            for (int q = 0 ; q < md ; q++) {
                for (int i = 0 ; i < n ; i++) {
                    d[i] = Long(a[i]) * f[i].v[p * md + q] % P;
                }
                fill(d + n, d + L, 0);
                NTT::DIF(d, L);
                for (int i = 0 ; i < L ; i++) {
                    d[i] = Long(d[i]) * b[i] % P;
                }
                NTT::IDIT(d, L, false);
                for (int i = 0 ; i < m ; i++) {
                    outf[i].v[p * md + q] = Long(d[n - 1 + i]) * c[i] % P * invL % P;
                }
            }
        }
        // Same interpolation for the scalar sequence g.
        for (int i = 0 ; i < n ; i++) {
            d[i] = Long(a[i]) * g[i] % P;
        }
        fill(d + n, d + L, 0);
        NTT::DIF(d, L);
        for (int i = 0 ; i < L ; i++) {
            d[i] = Long(d[i]) * b[i] % P;
        }
        NTT::IDIT(d, L, false);
        for (int i = 0 ; i < m ; i++) {
            outg[i] = Long(d[n - 1 + i]) * c[i] % P * invL % P;
        }
    }

    // Returns a_n mod P for the sequence defined by
    //   sum_{j=0}^{m} p[j](i) * a_{i-j} = 0   for every i >= m,
    // where a holds the first terms a_0 .. a_{m-1} (m = a.size()) and p[j][e] is the
    // coefficient of x^e in polynomial j. p must hold m + 1 polynomials. Assumes
    // p[0](i) != 0 (mod P) for every index i that is stepped through (its product is inverted).
    // Each step multiplies the row vector (a_i .. a_{i+m-1}) by get_m(i), which yields
    // p[0](m + i) * (a_{i+1} .. a_{i+m}). Steps are grouped into blocks of s:
    // F_d(x) = get_m(x) * ... * get_m(x + d - 1) is a matrix polynomial of degree d*k in x.
    // It is kept sampled at x = j * s and grown from d = 1 to d = s along the bits of s,
    // using lagrange() to extend and shift the samples.
    Int recursive(Long n, const vector<Int> &a, const vector<vector<Int>> &p) {
        int m = a.size();
        if (n < m) {
            return a[n];
        }
        // Degree of the coefficient polynomials, clamped to at least 1. With d = 0 the
        // block size below would divide by zero. Over-estimating the degree only adds
        // interpolation points and does not change the result.
        int k = max((int) p[0].size() - 1, 1);
        // Horner evaluation of a coefficient polynomial (p[i] is the coefficient of x^i).
        auto eval = [&](const vector<Int> &p, Int x) {
            Int res = 0;
            for (int i = (int) p.size() - 1 ; i >= 0 ; i--) {
                res = add(Long(res) * x % P, p[i]);
            }
            return res;
        };
        // Transition matrix of step i (recurrence index m + i), scaled by p[0](m + i).
        // The sub-diagonal shifts the window; the last column produces p[0](m + i) * a_{m+i}.
        auto get_m = [&](int i) -> Matrix {
            Matrix mi(m);
            Int res = eval(p[0], m + i);
            for (int j = 1 ; j < m ; j++) {
                mi.v[j * m + j - 1] = res;
            }
            for (int j = 0 ; j < m ; j++) {
                Int res = eval(p[m - j], m + i);
                mi.v[j * m + m - 1] = sub(0, res);
            }
            return mi;
        };
        Long np = n - m + 1; // number of transition steps needed to reach a_n
        // F_d(x) = get_m(x) * get_m(x + 1) * ... * get_m(x + d - 1), computed directly.
        auto get_f = [&](int d, Int x) -> Matrix {
            Matrix f = get_m(x);
            for (int i = 1 ; i < d ; i++) {
                f = f * get_m(add(x, i));
            }
            return f;
        };
        // G_d(x) = prod_{i<d} p[0](m + x + i): the scale factor accumulated by F_d(x).
        auto get_g = [&](int d, Int x) -> Int {
            Int g = eval(p[0], add(m, x));
            for (int i = 1 ; i < d ; i++) {
                g = Long(g) * eval(p[0], add(i, add(m, x))) % P;
            }
            return g;
        };
        // Block size s ~ sqrt(np / k). t full blocks of s steps are read off the samples.
        Int s = ceill(sqrtl((double) np / k));
        Int t = np / s;
        Int tl = s * k + 1;
        // f[j] / g[j] hold F_d(j * s) / G_d(j * s) for j = 0 .. d * k (d starts at 1).
        vector<Matrix> f(tl, m);
        vector<Int> g(tl);
        for (int i = 0 ; i <= k ; i++) {
            Int x = Long(i) * s % P;
            f[i] = get_m(x);
            g[i] = eval(p[0], add(m, x));
        }
        vector<Matrix> tf(tl, m);
        vector<Int> tg(tl);
        Int invS = NTT::inv(s);
        // Walk the bits of s below its top bit. Each iteration doubles d, then adds 1
        // when the bit is set, so d == s at the end.
        for (int i = __lg(s) - 1, d = 1 ; i >= 0 ; i--) {
            Int dk = d * k;
            // Extend the samples j = 0 .. dk to j = dk + 1 .. 2dk (F_d has degree dk in j).
            lagrange(
                spn<const Matrix>(f.data(), dk + 1), spn<const Int>(g.data(), dk + 1),
                spn<Matrix>(f.data() + dk + 1, dk), spn<Int>(g.data() + dk + 1, dk),
                dk + 1
            );
            // Evaluate F_d at j * s + d, i.e. at sample index j + d / s (mod P), for j = 0 .. 2dk.
            lagrange(
                spn<const Matrix>(f.data(), 2 * dk + 1), spn<const Int>(g.data(), 2 * dk + 1),
                spn<Matrix>(tf.data(), 2 * dk + 1), spn<Int>(tg.data(), 2 * dk + 1),
                Long(d) * invS % P
            );
            // F_{2d}(x) = F_d(x) * F_d(x + d), and likewise for G.
            for (int j = 0 ; j <= 2 * dk ; j++) {
                f[j] = f[j] * tf[j];
                g[j] = Long(g[j]) * tg[j] % P;
            }
            d *= 2;
            if (!((s >> i) & 1)) continue;
            // d -> d + 1: compute the k new sample points directly, then append one more
            // step to every sample.
            for (int j = d * k + 1 ; j <= (d + 1) * k ; j++) {
                Int x = Long(j) * s % P;
                f[j] = get_f(d, x);
                g[j] = get_g(d, x);
            }
            for (int j = 0 ; j <= (d + 1) * k ; j++) {
                Int x = Long(j) * s % P;
                f[j] = f[j] * get_m(add(d, x));
                g[j] = Long(g[j]) * eval(p[0], add(m, add(d, x))) % P;
            }
            d++;
        }
        // Combine the t full blocks, then take the remaining np - t * s steps one by one.
        Matrix M = Matrix::E(m);
        int D = 1;
        for (int i = 0 ; i < t ; i++) {
            M = M * f[i];
            D = Long(D) * g[i] % P;
        }
        for (Long i = Long(t) * s ; i < np ; i++) {
            M = M * get_m(i % P);
            D = Long(D) * eval(p[0], add(m, i % P)) % P;
        }
        // a_n is the last component of (a_0 .. a_{m-1}) * M, divided by the accumulated scale D.
        Int ans = 0;
        for (int i = 0 ; i < m ; i++) {
            ans = add(ans, Long(a[i]) * M.v[i * m + m - 1] % P);
        }
        ans = Long(ans) * NTT::inv(D) % P;
        return ans;
    }
}
int main() {
    int n, m, d;
    cin >> n >> m >> d;
    vector<Int> a(m);
    for (int i = 0 ; i < m ; i++) {
        cin >> a[i];
    }
    vector<vector<Int>> p(m + 1);
    for (int i = 0 ; i <= m ; i++) {
        p[i].resize(d + 1);
        for (int j = 0 ; j <= d ; j++) {
            cin >> p[i][j];
        }
    }
    recursive::init();
    cout << recursive::recursive(n, a, p) << "\n";
    return 0;
}