Skip to content

14321: 【原4321】图上点对统计

题目

题目描述

author: smallling 原OJ链接:https://acm.sjtu.edu.cn/OnlineJudge-old/problem/4321

Description

给定一张$n$个点$m$条边的无向带权图,保证图连通。定义一条路径的花费为这条路径上所有边权的最大值。统计满足下述条件的无序节点对$(u, v)$($u \neq v$)的数量:在$u$与$v$之间的所有路径中,花费的最小值恰好为$k$。

Input Format

第一行三个正整数$n,m,k$。 接下来$m$行,每行三个正整数$x,y,v$,表示节点$x,y$之间有一条权值为$v$的边(节点编号为$1$到$n$)。

$1 \leq n \leq 10^5$

$1 \leq m \leq 5 \times 10^5$

$1 \leq v, k \leq 10^9$

Output Format

一行一个数字表示答案

Sample Input

5 4 3
1 2 1
1 3 2
1 4 3
1 5 4

Sample Output

3

Hint

满足要求的点对为$(1, 4), (3, 4), (2, 4)$

ligongzzz's solution

#include <iostream>
#include <vector>
#include <algorithm>
using namespace std;

// A vertex of an adjacency-list graph.
struct node {
    // Label of the component this vertex belongs to (assigned by dfs1).
    int code = 0;
    // Incident edges as (neighbour index, edge weight).
    std::vector<std::pair<int, int>> edges;
};

// nl: the input graph, vertices 0..n-1.
std::vector<node> nl;
// nlf: graph whose vertices are the components labelled by dfs1.
std::vector<node> nlf;
// Visitation marks, reused by both traversals.
std::vector<bool> visited;
// forest_num[c] = number of input vertices labelled with component code c.
std::vector<int> forest_num;

// Labels every vertex reachable from `pos` through edges of weight strictly less than k
// with `code`, marks them visited, and returns how many vertices were labelled.
// Uses an explicit stack instead of recursion: with n up to 1e5 a path-shaped component
// would otherwise recurse 1e5 frames deep and can overflow the call stack.
int dfs1(int pos, int k, int code) {
    std::vector<int> pending(1, pos);
    nl[pos].code = code;
    visited[pos] = true;
    int count = 0;
    while (!pending.empty()) {
        const int cur = pending.back();
        pending.pop_back();
        ++count;
        for (const auto& edge : nl[cur].edges) {
            // Mark on push so each vertex enters the stack (and is counted) at most once.
            if (edge.second < k && !visited[edge.first]) {
                visited[edge.first] = true;
                nl[edge.first].code = code;
                pending.push_back(edge.first);
            }
        }
    }
    return count;
}

// Counts the input-graph vertex pairs, within the nlf component containing `pos`, that lie
// in two *different* dfs1 components. Such a pair is connected using edges of weight <= k
// but not using edges of weight < k, so its minimax path cost is exactly k.
//
// num accumulates the sizes of the nlf vertices visited so far; every newly reached
// component pairs with all of them. Both num and the result are long long: with n up to
// 1e5 the count can reach about 5e9, which does not fit in int.
// Iterative (explicit stack) so a long chain of components cannot overflow the call stack.
long long dfs2(int pos, long long& num) {
    std::vector<int> pending(1, pos);
    visited[pos] = true;
    num += forest_num[pos];
    long long pairs = 0;
    while (!pending.empty()) {
        const int cur = pending.back();
        pending.pop_back();
        for (const auto& edge : nlf[cur].edges) {
            const int next = edge.first;
            if (!visited[next]) {
                visited[next] = true;
                pairs += forest_num[next] * num;  // promoted to long long via num
                num += forest_num[next];
                pending.push_back(next);
            }
        }
    }
    return pairs;
}

// Reads "n m k" and m weighted edges (1-indexed endpoints), then prints the number of
// unordered vertex pairs whose minimax path cost is exactly k:
//   1. label the components formed by edges of weight < k (dfs1);
//   2. connect those components with the weight-k edges (nlf);
//   3. within each connected group, count pairs split across different components (dfs2).
int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int n, m, k;
    std::cin >> n >> m >> k;

    nl.resize(n);
    // Endpoints of every edge whose weight is exactly k.
    std::vector<std::pair<int, int>> k_edge;
    for (int i = 0; i < m; ++i) {
        int x, y, v;
        std::cin >> x >> y >> v;
        --x, --y;  // input vertices are 1-indexed
        nl[x].edges.emplace_back(y, v);
        nl[y].edges.emplace_back(x, v);
        if (v == k) {
            k_edge.emplace_back(x, y);
        }
    }

    // Step 1: components over edges of weight < k; forest_num[c] is the size of component c.
    int code = 0;
    visited.assign(n, false);
    for (int i = 0; i < n; ++i) {
        if (!visited[i]) {
            forest_num.push_back(dfs1(i, k, code));
            ++code;
        }
    }

    // Step 2: contract each component to one vertex and link them with the weight-k edges.
    nlf.resize(code);
    for (const auto& p : k_edge) {
        const int a = nl[p.first].code;
        const int b = nl[p.second].code;
        if (a == b) {
            continue;  // endpoints already joined by lighter edges; adds no new pairs
        }
        nlf[a].edges.emplace_back(b, k);
        nlf[b].edges.emplace_back(a, k);
    }

    // Step 3: sum cross-component pairs within each connected group of the contracted graph.
    visited.assign(code, false);
    long long ans = 0;
    for (int i = 0; i < code; ++i) {
        if (!visited[i]) {
            long long num = 0;
            ans += dfs2(i, num);
        }
    }
    std::cout << ans << '\n';

    return 0;
}

zqy2018's solution

#include <bits/stdc++.h>
// REP(temp, a, b): loop an int `temp` upward from a to b inclusive.
// Note: end_val is re-evaluated on every iteration.
#define REP(temp, init_val, end_val) for (int temp = init_val; temp <= end_val; ++temp)
using namespace std;
// 64-bit integer alias, used for the answer: with n up to 1e5 the pair count exceeds int range.
typedef long long ll;
// Alias for a pair of ints; not referenced by the code visible in this file.
typedef pair<int, int> intpair;
// Reads the next decimal integer from stdin.
// Non-digit characters before the number are skipped; each '-' seen while skipping flips
// the sign. Reading stops at the first non-digit after the number, which is consumed.
// Returns 0 if end of input is reached before any digit, rather than spinning forever.
int read(){
    int f = 1, x = 0;
    int c = getchar();  // int, not char, so EOF stays distinguishable from real bytes
    while (c < '0' || c > '9'){
        if (c == EOF) return 0;
        if (c == '-') f = -f;
        c = getchar();
    }
    while (c >= '0' && c <= '9') x = x * 10 + (c - '0'), c = getchar();
    return f * x;
}
// Problem parameters: vertex count, edge count, and the target path cost.
int n, m, k;
// Edge list, 1-indexed: e[i] = (weight, (u, v)).
std::pair<int, std::pair<int, int> > e[500005];
// Disjoint-set forest over vertices 1..n: parent pointers and per-root component sizes.
int fa[100005];
int siz[100005];
// Edge indices 1..m, to be sorted by weight.
int id[500005];
// Number of vertex pairs whose minimax path cost is exactly k.
long long ans = 0;

// Returns the root of x's tree, re-pointing every node on the way directly at that root.
int Find(int x){
    int root = x;
    while (fa[root] != root) root = fa[root];
    while (x != root) {
        const int parent = fa[x];
        fa[x] = root;
        x = parent;
    }
    return root;
}

// Merges the sets containing x and y, attaching the smaller tree under the larger one
// (on equal sizes y's root becomes the parent). Returns 1 on a merge, 0 if already joined.
int Union(int x, int y){
    const int rx = Find(x);
    const int ry = Find(y);
    if (rx == ry) return 0;
    const int parent = siz[rx] > siz[ry] ? rx : ry;
    const int child = (parent == rx) ? ry : rx;
    fa[child] = parent;
    siz[parent] += siz[child];
    return 1;
}

// Strict "lighter edge first" ordering on edge indices, for std::sort.
bool cmp(int i, int j){
    const int wi = e[i].first;
    const int wj = e[j].first;
    return wi < wj;
}
void init(){
    n = read(), m = read(), k = read();
    REP(i, 1, m){
        int u = read(), v = read(), w = read();
        e[i].second.first = u, 
        e[i].second.second = v, 
        e[i].first = w; 
        id[i] = i;
    }
    sort(id + 1, id + m + 1, cmp);
    REP(i, 1, n)
        fa[i] = i, siz[i] = 1;
    int lft = n;
    for (int i = 1; i <= m && lft > 1; ++i){
        int w = e[id[i]].first;
        int u = e[id[i]].second.first, v = e[id[i]].second.second;
        if (Find(u) != Find(v)){
            int t1 = siz[fa[u]], t2 = siz[fa[v]];
            if (w == k){
                ans += 1ll * t1 * t2;
            }
            --lft, Union(u, v);
        }
    }
}
// Prints the pair count accumulated by init(), followed by a newline.
void solve(){
    const long long result = ans;
    std::printf("%lld\n", result);
}
// Entry point: read the input and compute the answer (init), then print it (solve).
int main(){
    init();
    solve();
}