跳转至

最近公共祖先

倍增

模板

预处理 \(O(n \log n)\)，查询 \(O(\log n)\)

代码
C++
#include <bits/stdc++.h>
using namespace std;
// 2024 OneWan
// Capacity for node ids. Node 0 is used as the "no parent" sentinel by
// main (dfs(rt, 0)), so real nodes are presumably numbered from 1.
constexpr int N = 500005;
// Adjacency lists of the undirected tree (each edge stored in both directions).
vector<int> adj[N];
// fa[u][i]: the 2^i-th ancestor of u, or 0 once the jump passes the root.
int fa[N][22];
// deep[u]: depth of u; the sentinel 0 has depth 0, so the root has depth 1.
int deep[N];
// Treats p as u's parent and, for every node reachable from u without going
// through p, records its depth (deep) and its binary-lifting row
// fa[v][0..21] (fa[v][i] = 2^i-th ancestor, 0 once past the root).
//
// Uses an explicit stack instead of recursion: a path-shaped tree with close
// to N nodes would otherwise need that many nested call frames and can
// overflow the call stack.
void dfs(int u, int p) {
    vector<pair<int, int>> stk;  // (node, parent) pairs still to be processed
    stk.emplace_back(u, p);
    while (!stk.empty()) {
        const int v = stk.back().first;
        const int par = stk.back().second;
        stk.pop_back();
        fa[v][0] = par;
        deep[v] = deep[par] + 1;
        // v is pushed only after its parent was processed, so every ancestor's
        // fa row is already complete when these jumps are composed.
        for (int i = 1 ; i <= 21 ; i++) {
            fa[v][i] = fa[fa[v][i - 1]][i - 1];
        }
        for (auto &to : adj[v]) {
            if (to == par) continue;
            stk.emplace_back(to, v);
        }
    }
}
// Lowest common ancestor of x and y via binary lifting, in O(log n) jumps.
// Requires dfs(root, 0) to have filled fa[][] and deep[].
int LCA(int x, int y) {
    // Make x the deeper of the two nodes.
    if (deep[x] < deep[y]) swap(x, y);
    // Lift x up to y's depth, taking the largest jumps first.
    for (int k = 21 ; k >= 0 ; --k) {
        const int up = fa[x][k];
        if (deep[up] >= deep[y]) x = up;
    }
    if (x == y) return x;
    // Lift both while their ancestors differ; they end just below the LCA.
    for (int k = 21 ; k >= 0 ; --k) {
        if (fa[x][k] != fa[y][k]) {
            x = fa[x][k];
            y = fa[y][k];
        }
    }
    return fa[x][0];
}
// Reads n (nodes), m (queries) and the root rt, then n - 1 undirected edges,
// then answers m LCA queries, one per output line.
int main() {
    // Fast stream I/O: no C-stdio synchronisation, cin not tied to cout.
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);
    int n, m, rt;
    cin >> n >> m >> rt;
    for (int e = 1 ; e < n ; ++e) {
        int a, b;
        cin >> a >> b;
        adj[a].push_back(b);
        adj[b].push_back(a);
    }
    // 0 serves as the root's virtual parent (depth 0).
    dfs(rt, 0);
    for (int q = 0 ; q < m ; ++q) {
        int x, y;
        cin >> x >> y;
        cout << LCA(x, y) << "\n";
    }
    return 0;
}

预处理 \(O(n \log n)\)，查询 \(O(1)\)（欧拉序 + ST 表）

代码
C++
#include <bits/stdc++.h>
using namespace std;
// 2024 OneWan
// Capacity for node ids; node 0 is the "no parent" sentinel passed by main.
constexpr int N = 500005;
// Adjacency lists of the undirected tree.
vector<int> adj[N];
// fa[i][j]: shallowest node among que[i .. i + 2^j - 1] (sparse table over the tour).
int fa[N << 1][22];
// dfn[u]: index of u's first appearance in the Euler tour que.
int dfn[N];
// deep[u]: depth of u; the sentinel 0 has depth 0, so the root has depth 1.
int deep[N];
// Current length of the Euler tour.
int tot;
// Euler tour, stored in que[1..tot] (at most 2n - 1 entries).
int que[N << 1];
// lg[i] = floor(log2(i)) for i >= 1; lg[0] = -1.
int lg[N << 1];
void dfs(int u, int p) {
    dfn[u] = ++tot;
    que[tot] = u;
    deep[u] = deep[p] + 1;
    for (auto &to : adj[u]) {
        if (to == p) continue;
        dfs(to, u);
        que[++tot] = u;
    }
}
// Builds the log table and a sparse table over the Euler tour que[1..tot]:
// fa[i][j] becomes the shallowest node among que[i .. i + 2^j - 1].
// Must be called after dfs() has produced the tour.
void initLCA() {
    lg[0] = -1;
    for (int i = 1 ; i <= tot ; i++) {
        fa[i][0] = que[i];
        lg[i] = lg[i >> 1] + 1;  // floor(log2(i))
    }
    for (int j = 1 ; j <= 21 ; j++) {
        const int half = 1 << (j - 1);
        // The window [i, i + 2^j - 1] must fit in [1, tot]. The last valid
        // start is tot - 2^j + 1, so it is included in the range.
        for (int i = 1 ; i + (1 << j) - 1 <= tot ; i++) {
            const int x = fa[i][j - 1];
            const int y = fa[i + half][j - 1];
            fa[i][j] = deep[x] < deep[y] ? x : y;
        }
    }
}
int LCA(int x, int y) {
    if (dfn[x] > dfn[y]) swap(x, y);
    x = dfn[x];
    y = dfn[y];
    int k = lg[y - x + 1];
    x = fa[x][k];
    y = fa[y - (1 << k) + 1][k];
    return deep[x] < deep[y] ? x : y;
}
// Reads n (nodes), m (queries) and the root rt, then n - 1 undirected edges,
// builds the Euler tour and sparse table, and answers m LCA queries.
int main() {
    // Fast stream I/O: no C-stdio synchronisation, cin not tied to cout.
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);
    int n, m, rt;
    cin >> n >> m >> rt;
    for (int e = 1 ; e < n ; ++e) {
        int a, b;
        cin >> a >> b;
        adj[a].push_back(b);
        adj[b].push_back(a);
    }
    // 0 serves as the root's virtual parent (depth 0).
    dfs(rt, 0);
    initLCA();
    for (int q = 0 ; q < m ; ++q) {
        int x, y;
        cin >> x >> y;
        cout << LCA(x, y) << "\n";
    }
    return 0;
}
例题