Skip to content

14195: 【原4195】Forest

题目

题目描述

author: 侯不会 原OJ链接:https://acm.sjtu.edu.cn/OnlineJudge-old/problem/4195

Description

经过一学期的勤奋学习,侯不会终于掌握了程序设计的知识,他现在会做这道题,但是他想考考你。
一棵有 \(n\) 个点的树,每个点都有权值。一条简单路径的权值为路径上所有点的权值之和,一棵树的直径为树上权值最大的简单路径。
现在,不断地删去树上的边,每删一次就会多一棵树,请计算出任意时刻所有树的直径的乘积。
这个数可能很大,你只需输出对 \(10^9+7\) 取模之后的结果。

Input Format

第一行,一个正整数 \(n\)。
下一行 \(n\) 个整数 \(a_i\),表示顶点的权值。
之后的 \(n-1\) 行,每行两个整数 \(u_i\) 和 \(v_i\),表示 \(u_i\) 和 \(v_i\)之间的有一条边,编号为 \(i\)。
再之后 \(n-1\) 行,每一行一个整数 \(k_j\),表示第 \(j\) 条被删除的边的编号。

Output Format

输出共 \(n\) 行,第 \(i\) 行输出删除 \(i-1\) 条边后的结果。

Sample Output

3  
1 2 3  
1 2   
1 3   
2  
1

Sample Output

6  
9  
6

Data Range

对于100%的数据,\(n <= 10^5, a_i <= 10^4\)

zqy2018's solution

/*
    See the solution at https://github.com/zqy1018/sjtu_oj_solutions/blob/master/solutions/sjtu4195.md
*/
#include <bits/stdc++.h>
#define INF 2000000000
#define M 1000000007
using namespace std;
typedef long long ll;
int read(){
    int f = 1, x = 0;
    char c = getchar();
    while(c < '0' || c > '9'){if(c == '-') f = -f; c = getchar();}
    while(c >= '0' && c <= '9')x = x * 10 + c - '0', c = getchar();
    return f * x; 
}
int n, a[100005], id[100005];
int to[200005], at[100005], nxt[200005], cnt = 0;
int ans[100005];
int zjdd[100005][3];        // 直径端点不会表达
int fa[100005][20] = {0}, dep[100005] = {0};
int sum[100005] = {0};
int fa2[100005];
int Find(int x){
    return (fa2[x] == x ? x: (fa2[x] = Find(fa2[x])));
}
void Union(int x, int y){
    int u = Find(x), v = Find(y);
    if (u == v) return ;
    fa2[v] = u;
}
int poww(int a, int b){
    int res = 1;
    while (b > 0){
        if (b & 1) res = 1ll * res * a % M;
        a = 1ll * a * a % M, b >>= 1;
    }
    return res;
}
void dfs(int cur, int ffa){
    for (int i = at[cur]; i; i = nxt[i]){
        int v = to[i];
        if (v == ffa) continue;
        fa[v][0] = cur, sum[v] = sum[cur] + a[v];
        dep[v] = dep[cur] + 1;
        dfs(v, cur);
    }
}
int lca(int x, int y){
    if (dep[x] != dep[y]){
        if (dep[x] < dep[y]) swap(x, y);
        int diff = dep[x] - dep[y];
        for (int t = 1, p = 0; diff > 0; t <<= 1, ++p)
            if (t & diff)
                x = fa[x][p], diff -= t; 
    }
    if (x == y) return x;
    for (int j = 19; j >= 0; --j)
        if (fa[x][j] != fa[y][j])
            x = fa[x][j], y = fa[y][j];
    return fa[x][0]; 
}
int dis(int x, int y){
    int llca = lca(x, y);
    return sum[x] + sum[y] - 2 * sum[llca] + a[llca];
}
void init(){
    n = read();
    for (int i = 1; i <= n; ++i)
        a[i] = read();
    for (int i = 1; i < n; ++i){
        int u = read(), v = read();
        to[++cnt] = v, nxt[cnt] = at[u], at[u] = cnt;
        to[++cnt] = u, nxt[cnt] = at[v], at[v] = cnt;
    }
    for (int i = 1; i < n; ++i)
        id[i] = read();

    sum[1] = a[1], dep[1] = 1, dfs(1, 0);
    for (int j = 1; j < 20; ++j){
        bool flag = false;
        for (int i = 1; i <= n; ++i)
            fa[i][j] = fa[fa[i][j - 1]][j - 1], 
            flag = (flag || fa[i][j]);
        if (!flag) break;
    }
}
void solve(){
    int aans = 1;
    for (int i = 1; i <= n; ++i)
        zjdd[i][0] = zjdd[i][1] = i, zjdd[i][2] = a[i], 
        fa2[i] = i, 
        aans = 1ll * aans * a[i] % M;

    ans[n] = aans;
    for (int i = n - 1; i >= 1; --i){
        int u = Find(to[id[i] << 1]), v = Find(to[(id[i] << 1) - 1]);
        int res = INT_MIN, ans1, ans2;
        for (int uu = 0; uu < 2; ++uu)
            for (int vv = 0; vv < 2; ++vv){
                int tmp = dis(zjdd[u][uu], zjdd[v][vv]);
                // cout << " " << tmp << " " <<  zjdd[u][uu] << " " << zjdd[v][vv] << endl;
                if (tmp > res)
                    res = tmp, ans1 = zjdd[u][uu], ans2 = zjdd[v][vv];
            }
        if (zjdd[u][2] > res) res = zjdd[u][2], ans1 = zjdd[u][0], ans2 = zjdd[u][1];
        if (zjdd[v][2] > res) res = zjdd[v][2], ans1 = zjdd[v][0], ans2 = zjdd[v][1];
        // cout << res << endl;
        aans = 1ll * aans * poww(zjdd[u][2], M - 2) % M;
        aans = 1ll * aans * poww(zjdd[v][2], M - 2) % M;
        Union(u, v);
        zjdd[u][0] = ans1, zjdd[u][1] = ans2, zjdd[u][2] = res;
        aans = 1ll * aans * res % M;
        ans[i] = aans;
    }
    for (int i = 1; i <= n; ++i)
        printf("%d\n", ans[i]);
}
int main(){
    init();
    solve();
    return 0;
}