Skip to content

14127: 【原4127】lca

题目

题目描述

author: Lmxyy 原OJ链接:https://acm.sjtu.edu.cn/OnlineJudge-old/problem/4127

Description

dhh最近学了lca,可高兴了。他现在可以在\(O(\log N)\)的时间里求两个点的lca了。

然而xyy却皮了一下说:“怕不是不支持换根。”

dhh想了想,觉得好像真的有点小解。但他嘴上却说:“哪里解了?”

xyy说:“那我来考考你,给你一棵\(N\)个点的无根树,每次询问一组\(u,v,w\),问以\(w\)为根的树,\(u\)和\(v\)这两个点的lca是什么?”

dhh一时脸红了。。。他答不上来,所以向你求助。

Input Format

第一行一个正整数\(N\),表示树的节点数。

接下来\(N-1\)行,每行两个整数\(a,b\),表示树中\(a\)点和\(b\)点之间有一条边相连。

接下来一行一个正整数\(Q\),表示询问的数目。

接下来\(Q\)行,每行三个整数\(u,v,w\)如题意。

Output Format

输出共\(Q\)行,每行一个数表示答案。

Sample Input

6
2 4
1 4
3 2
1 6
5 4
6
1 5 1
6 3 5
6 3 6
6 4 5
4 2 4
5 1 5

Sample Output

1
4
6
4
4
5

Data Range

对于\(40 \%\)的数据,\(N,Q \le 1000\)。

对于另外\(40 \%\)的数据,保证所有询问的\(w\)都是一样的。

对于\(100 \%\)的数据,\(N \le 100000\),\(Q \le 500000\),\(1 \le u,v,w \le N\)。

ligongzzz's solution

#include "cmath"
#include "cstdio"
#include "cstring"
#include "iostream"
#include "utility"
#include "vector"
using namespace std;

/// One cell of a vertex's singly linked adjacency list.
struct edge {
    int to{0};             // neighbouring vertex id
    edge* next{nullptr};   // following cell, or nullptr at the end of the list
};

/// Per-vertex data: rooted-tree links, depth, binary-lifting table and the
/// adjacency list (which starts with a dummy head cell).
struct tree_node {
    int parent{0};            // parent vertex; 0 when there is none (the root)
    int depth{0};             // number of edges from the root
    int parent_list[24]{};    // parent_list[i] = 2^i-th ancestor, 0 when not filled in
    edge* head{nullptr};      // dummy first cell of the adjacency list
    edge* rear{nullptr};      // last cell of the adjacency list (append point)
};

// Storage for every vertex, indexed by 1-based vertex id; 0 doubles as "no parent".
tree_node tree_data[100009];
// Vertex treated as the root by dfs() (main sets it to 1 before preprocessing).
int root_pos{0};

/// Roots the tree at root_pos and fills depth and the binary-lifting table
/// parent_list for every vertex reachable from `pos`.
///
/// Precondition: tree_data[pos].parent is already set (0 when pos is the root).
/// Afterwards, for each vertex v, parent_list[i] is the 2^i-th ancestor of v
/// whenever 2^i <= depth(v); higher entries are left at 0.
///
/// Two differences from a textbook recursive version:
///  * an explicit stack is used, so a path-shaped tree (depth up to N) cannot
///    overflow the call stack;
///  * the number of table levels is decided with exact integer shifts instead
///    of floor(log(depth) / log(2)), which can round down at powers of two and
///    leave an ancestor's entry unset.
void dfs(int pos) {
    constexpr int kLevels = static_cast<int>(
        sizeof(tree_data[0].parent_list) / sizeof(tree_data[0].parent_list[0]));

    std::vector<int> pending{pos};
    while (!pending.empty()) {
        const int cur = pending.back();
        pending.pop_back();
        auto& node = tree_data[cur];

        if (cur == root_pos) {
            node.depth = 0;
        } else {
            // The parent was processed before this vertex was pushed, so its
            // depth and table are final.
            node.depth = tree_data[node.parent].depth + 1;
            node.parent_list[0] = node.parent;
            for (int i = 1; i < kLevels && (1 << i) <= node.depth; ++i) {
                node.parent_list[i] = tree_data[node.parent_list[i - 1]].parent_list[i - 1];
            }
        }

        // Skip the dummy head cell; every neighbour except the parent is a child.
        for (auto p = node.head->next; p; p = p->next) {
            if (p->to != node.parent) {
                tree_data[p->to].parent = cur;
                pending.push_back(p->to);
            }
        }
    }
}

int LCA(int pos1, int pos2) {
    while (tree_data[pos1].depth < tree_data[pos2].depth) {
        int k = int(log(tree_data[pos2].depth - tree_data[pos1].depth) / log(2));
        pos2 = tree_data[pos2].parent_list[k];
    }
    while (tree_data[pos2].depth < tree_data[pos1].depth) {
        int k = int(log(tree_data[pos1].depth - tree_data[pos2].depth) / log(2));
        pos1 = tree_data[pos1].parent_list[k];
    }
    while (pos1 != pos2) {
        int k = int(log(tree_data[pos1].depth) / log(2));
        for (int i = k; i >= 0; --i) {
            if (i == 0 || tree_data[pos1].parent_list[i] != tree_data[pos2].parent_list[i]) {
                pos1 = tree_data[pos1].parent_list[i];
                pos2 = tree_data[pos2].parent_list[i];
            }
        }
    }
    return pos1;
}


int main() {
    ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);

    int n;
    cin >> n;

    for (int i = 1; i <= n; ++i) {
        tree_data[i].rear = tree_data[i].head = new edge;
    }

    for (int i = 0; i < n - 1; ++i) {
        int u, v;
        cin >> u >> v;

        tree_data[u].rear->next = new edge;
        tree_data[u].rear = tree_data[u].rear->next;
        tree_data[u].rear->to = v;

        tree_data[v].rear->next = new edge;
        tree_data[v].rear = tree_data[v].rear->next;
        tree_data[v].rear->to = u;
    }

    root_pos = 1;
    dfs(1);

    int q;
    cin >> q;

    for (; q > 0; --q) {
        int u, v, w;
        cin >> u >> v >> w;
        int lca1 = LCA(u, v), lca2 = LCA(u, w), lca3 = LCA(v, w);
        if (tree_data[lca1].depth < tree_data[lca2].depth) {
            if (tree_data[lca3].depth > tree_data[lca2].depth) {
                cout << lca3 << "\n";
            }
            else {
                cout << lca2 << "\n";
            }
        }
        else {
            if (tree_data[lca3].depth > tree_data[lca1].depth) {
                cout << lca3 << "\n";
            }
            else {
                cout << lca1 << "\n";
            }
        }
    }

    return 0;
}

zqy2018's solution

/*
    See the solution at https://github.com/zqy1018/sjtu_oj_solutions/blob/master/solutions/sjtu4127.md
*/
#include <cstdio>
constexpr int INF = 2000000000;  // large sentinel value; not referenced by this solution
using namespace std;
using ll = long long;            // 64-bit alias; likewise not referenced by this solution
/*
 * Reads the next decimal integer from stdin.
 * Any non-digit characters before it are skipped; each '-' seen while skipping
 * flips the sign (so "--5" reads as 5). Returns 0 if end of input is reached
 * before a digit.
 *
 * getchar()'s result is kept in an int so EOF stays distinct from every real
 * byte. The explicit EOF check prevents an endless loop on truncated input:
 * EOF compares below '0', so the skip loop would otherwise never exit.
 */
int read(){
    int f = 1, x = 0;
    int c = getchar();
    while (c < '0' || c > '9'){
        if (c == EOF) return 0;
        if (c == '-') f = -f;
        c = getchar();
    }
    while (c >= '0' && c <= '9') x = x * 10 + (c - '0'), c = getchar();
    return f * x;
}
// Limits from the statement: N <= 100000, Q <= 500000.
int n, q;
int ans[500005] = {0};   // final answer of query i at ans[i] (1-based)
int tans[1500005];       // pairwise LCAs of query i at tans[3i-2], tans[3i-1], tans[3i]
// Tree adjacency (forward star): each undirected edge is stored as two arcs.
int at[100005] = {0}, nxt[200005], to[200005], cnt = 0;
// Query adjacency: six arcs per query, so nxt2/to2 need 6Q+1 entries. at2 is the
// per-vertex list head and is indexed by vertex ids u, v, w <= N only, so N+5
// slots suffice (it was previously sized like nxt2, wasting ~11 MB).
int at2[100005] = {0}, nxt2[3000005], to2[3000005], cnt2 = 0;
int par[100005], dep[100005];   // union-find parent and depth, per vertex
bool vis[100005] = {0}, typ[500005] = {0};
/*
 * Returns the representative of x's set in the union-find forest par[], and
 * points every vertex on the walked path directly at that representative.
 * Iterative (two passes) so a long parent chain, up to N = 1e5 vertices on a
 * path-shaped tree, cannot overflow the call stack the way recursion could.
 */
int Find(int x){
    int root = x;
    while (par[root] != root) root = par[root];
    while (par[x] != root){
        int up = par[x];
        par[x] = root;
        x = up;
    }
    return root;
}
/*
 * Merges the sets holding xxxx and yyyy by hanging the root of xxxx's set under
 * the root of yyyy's set. dfs calls Union(child, parent), so the parent stays
 * the representative of the merged set.
 */
void Union(int xxxx, int yyyy){
    const int from = Find(xxxx);
    const int into = Find(yyyy);
    if (from != into) par[from] = into;
}
/*
 * Reads the tree and all queries.
 * Tree edges go into (at, nxt, to) as two arcs each. Query k contributes six
 * consecutive arcs to (at2, nxt2, to2), ids 6k-5 .. 6k, in this fixed order:
 *   u->v, v->u, w->u, u->w, v->w, w->v
 * dfs derives the query's result slot from the arc id, so this order must not
 * change.
 */
void init(){
    auto add_tree_arc = [](int from, int target){
        to[++cnt] = target, nxt[cnt] = at[from], at[from] = cnt;
    };
    auto add_query_arc = [](int from, int target){
        to2[++cnt2] = target, nxt2[cnt2] = at2[from], at2[from] = cnt2;
    };

    n = read();
    for (int e = 1; e < n; ++e){
        const int a = read();
        const int b = read();
        add_tree_arc(a, b);
        add_tree_arc(b, a);
    }

    q = read();
    for (int k = 1; k <= q; ++k){
        const int u = read();
        const int v = read();
        const int w = read();
        add_query_arc(u, v);
        add_query_arc(v, u);
        add_query_arc(w, u);
        add_query_arc(u, w);
        add_query_arc(v, w);
        add_query_arc(w, v);
    }
}
void dfs(int cur, int fa){
    par[cur] = cur;
    for (int i = at[cur]; i; i = nxt[i]){
        int v = to[i];
        if (v == fa) continue;
        dep[v] = dep[cur] + 1, dfs(v, cur), Union(v, cur);
    }
    vis[cur] = true;
    for (int i = at2[cur]; i; i = nxt2[i]){
        int v = to2[i];
        if (i % 6 == 0){
            // as a root, determine the type
            int u = to2[i - 3], ccnt = 0;
            if (Find(u) == cur) ++ccnt;
            if (Find(v) == cur) ++ccnt;
            if (ccnt == 2) typ[i / 6] = false;  // both in w
            else if (ccnt == 1) ans[i / 6] = cur;
            else typ[i / 6] = true;

            if (vis[v])
                tans[i / 2] = Find(v);
        }else {
            if (!vis[v]) continue;
            int iid = (i / 6), idx = i % 6;
            if (idx <= 2)
                tans[iid * 3 + 1] = Find(v);
            else if (idx <= 4)  
                tans[iid * 3 + 2] = Find(v);
            else
                tans[iid * 3 + 3] = Find(v);
        }
    }
}
/*
 * Runs the offline LCA pass from vertex 1. Then, for each query, picks the
 * deepest of its three pairwise LCAs (tans[3i-2], tans[3i-1], tans[3i]); the
 * earliest slot wins ties. Prints one answer per line.
 */
void solve(){
    dep[1] = 0;
    dfs(1, 0);
    for (int i = 1; i <= q; ++i){
        int best = tans[3 * i - 2];
        for (int slot = 3 * i - 1; slot <= 3 * i; ++slot){
            if (dep[tans[slot]] > dep[best]) best = tans[slot];
        }
        ans[i] = best;
    }
    for (int i = 1; i <= q; ++i) printf("%d\n", ans[i]);
}
/* Reads the whole input, answers every query offline, and prints the results. */
int main(){
    init();
    solve();
}