Skip to content

14244: 【原4244】社区送温暖

题目

题目描述

author: 泰玛什 原OJ链接:https://acm.sjtu.edu.cn/OnlineJudge-old/problem/4244

Background

最后一次机考,助教想给同学们送温暖,于是就有了这道题。

Description

给定一颗以1为根节点的,节点总数为n的有根树。每个点i上有非负点权\(a_i\)。现在定义\(s_i\),\(s_i\)表示从第i个节点出发,一直走到根节点经过的所有点的点权加和(包括节点i和根节点)。

现在给你所有点的父亲节点编号,以及部分节点的\(s_i\)值,保证所有深度为偶数的节点的\(s_i\)值一定给出(根节点深度为0)。保证至少存在一种合法\(a_1,a_2,...,a_n\),满足给定的部分\(s_i\)值。请问所有可行方案中最小的\(\sum_{i=1}^n a_i\)是多少?

Input Format

第一行有1个整数n。

第二行有n-1个整数,第i个整数\(p_{i+1}\)表示i+1号点的父节点编号。

第三行有n个整数,第i个表示\(s_i\)。若为-1,表示这个点的\(s_i\)可以为任意值。

Output Format

输出一个整数,表示最小可能的\(\sum_{i=1}^n a_i\),保证一定有合法解。

Sample Input 1

5
1 1 1 1
1 -1 -1 -1 -1

Sample Output 1

1

Sample Input 2

5
1 2 3 1
1 -1 2 -1 -1

Sample Output 2

2

Subtasks

  • 对于30%的数据,满足\(s_i \geq 0 \)。
  • 对于100%的数据,满足\(2\leq n \leq 10^5, 1 \leq p_i < i, -1 \leq s_i \leq 10^9 \)。

ligongzzz's solution

#include "iostream"
#include "vector"
using namespace std;

class node {
public:
    vector<int> child;
    long long s = 0;
    long long a = 0;
    int parent = 0;
};

long long ans = 0;

void update(int pos, vector<node>& node_list) {
    if (pos == 1) {
        node_list[pos].a = node_list[pos].s;
        ans += node_list[pos].a;
    }
    else if (node_list[pos].s >= 0) {
        node_list[pos].a = node_list[pos].s - node_list[node_list[pos].parent].s;
        ans += node_list[pos].a;
    }
    else {
        if (node_list[pos].child.empty()) {
            node_list[pos].a = 0;
            ans += node_list[pos].a;
            return;
        }
        node_list[pos].s = node_list[node_list[pos].child[0]].s;
        for (auto p : node_list[pos].child) {
            node_list[pos].s = node_list[p].s < node_list[pos].s ? node_list[p].s : node_list[pos].s;
        }
        node_list[pos].a = node_list[pos].s - node_list[node_list[pos].parent].s;
        ans += node_list[pos].a;
    }
    for (auto p : node_list[pos].child) {
        update(p, node_list);
    }
}

int main() {
    ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);

    int n;
    cin >> n;
    vector<node> node_list(n + 1);

    for (int i = 2; i <= n; ++i) {
        int temp;
        cin >> temp;
        node_list[i].parent = temp;
        node_list[temp].child.push_back(i);
    }
    for (int i = 1; i <= n; ++i) {
        cin >> node_list[i].s;
    }

    update(1, node_list);
    cout << ans;

    return 0;
}