# 14321: 【原4321】图上点对统计

### 题目描述

author: smallling 原OJ链接：https://acm.sjtu.edu.cn/OnlineJudge-old/problem/4321

## Input Format

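The first line contains three integers $n$, $m$, $k$. Each of the following $m$ lines contains three integers $x$, $y$, $v$, describing an undirected edge of weight $v$ between vertices $x$ and $y$ (vertices are numbered from $1$ to $n$).

$1 \leq n \leq 10^5$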

$1 \leq m \leq 5 \times 10^5$

$1 \leq v, k \leq 10^9$

## Sample Input

5 4 3
1 2 1
1 3 2
1 4 3
1 5 4


## Sample Output

3


## ligongzzz's solution

#include <iostream>
#include <vector>
#include <algorithm>
using namespace std;

// One vertex of an adjacency-list graph.
struct node {
    // Id of the component (formed by edges lighter than k) this vertex lies in.
    int code;
    // Incident edges stored as (neighbour index, weight).
    vector<pair<int, int>> edges;
};

// Original input graph, one entry per vertex (0-indexed).
vector<node> nl;
// Contracted graph: one vertex per light-edge component, linked by weight-k edges.
vector<node> nlf;
// Visit marks; reused for the traversal of both graphs.
vector<bool> visited;
// forest_num[c] = number of original vertices inside component c.
vector<int> forest_num;

// Labels every vertex reachable from `pos` through edges of weight strictly
// less than `k` with component id `code`, marks each of them visited, and
// returns how many vertices were labelled (including `pos`).
//
// Uses an explicit stack instead of recursion: a chain of light edges can be
// up to n = 1e5 vertices long, which is deep enough to overflow the call stack.
int dfs1(int pos, int k, int code) {
    vector<int> pending{pos};
    nl[pos].code = code;
    visited[pos] = true;

    int component_size = 0;
    while (!pending.empty()) {
        const int cur = pending.back();
        pending.pop_back();
        ++component_size;
        for (const auto& edge : nl[cur].edges) {
            // Mark on push so each vertex enters the stack at most once.
            if (edge.second < k && !visited[edge.first]) {
                visited[edge.first] = true;
                nl[edge.first].code = code;
                pending.push_back(edge.first);
            }
        }
    }
    return component_size;
}

// Traverses the contracted graph component containing `pos` and counts the
// pairs of original vertices that lie in different light-edge components but
// in the same contracted component, i.e. pairs connected using edges of
// weight <= k but not using edges of weight < k.
//
// `num` accumulates the number of original vertices visited so far; the
// caller passes 0, and on return it holds the component's vertex total. Each
// newly visited contracted vertex c contributes forest_num[c] * num pairs,
// which sums to (S^2 - sum s_i^2) / 2 over the component.
//
// Returns long long: the count can reach about n^2 / 2 (~5e9 for n = 1e5),
// which overflows int. Iterative to avoid deep recursion on long chains.
long long dfs2(int pos, int& num) {
    long long pairs = 0;
    vector<int> pending{pos};
    visited[pos] = true;
    while (!pending.empty()) {
        const int cur = pending.back();
        pending.pop_back();
        pairs += 1LL * forest_num[cur] * num;
        num += forest_num[cur];
        for (const auto& edge : nlf[cur].edges) {
            if (!visited[edge.first]) {
                visited[edge.first] = true;
                pending.push_back(edge.first);
            }
        }
    }
    return pairs;
}

// Reads "n m k" and m undirected edges "x y v" (1-indexed), then prints the
// number of vertex pairs connected through weight-k edges but not through
// lighter edges alone.
int main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    int n, m, k;
    if (!(cin >> n >> m >> k)) {
        return 0;
    }

    // Build the original graph (0-indexed) and remember the weight-k edges.
    nl.resize(n);
    vector<pair<int, int>> k_edge;
    for (int i = 0; i < m; ++i) {
        int x, y, v;
        cin >> x >> y >> v;
        --x, --y;
        nl[x].edges.emplace_back(y, v);
        nl[y].edges.emplace_back(x, v);
        if (v == k) {
            k_edge.emplace_back(x, y);
        }
    }

    // Phase 1: label the components formed by edges lighter than k.
    int code = 0;
    visited.assign(n, false);
    for (int i = 0; i < n; ++i) {
        if (!visited[i]) {
            forest_num.push_back(dfs1(i, k, code));
            ++code;
        }
    }

    // Phase 2: contract each component to one vertex, joined by weight-k edges.
    nlf.resize(code);
    for (const auto& p : k_edge) {
        const int a = nl[p.first].code;
        const int b = nl[p.second].code;
        nlf[a].edges.emplace_back(b, k);
        nlf[b].edges.emplace_back(a, k);
    }

    // Phase 3: count cross-component pairs inside every contracted component.
    visited.assign(code, false);
    long long ans = 0;
    for (int i = 0; i < code; ++i) {
        if (!visited[i]) {
            int num = 0;
            ans += dfs2(i, num);
        }
    }
    cout << ans << '\n';

    return 0;
}


## zqy2018's solution

#include <bits/stdc++.h>
// Inclusive ascending loop: `temp` runs from init_val up to end_val.
#define REP(temp, init_val, end_val) for (int temp = init_val; temp <= end_val; ++temp)
using namespace std;
using ll = long long;
using intpair = pair<int, int>;
// Reads one (optionally negative) decimal integer from stdin, skipping any
// non-digit characters before it. Returns 0 if input ends first.
//
// NOTE(review): the function header was missing from this listing (only the
// body survived as stray top-level statements); restored as read(), which
// init() uses to parse the input. `c` is an int and EOF is checked so that
// truncated input cannot spin forever in the skip loop.
inline int read(){
    int f = 1, x = 0;
    int c = getchar();
    while (c != EOF && (c < '0' || c > '9')){
        if (c == '-') f = -f;
        c = getchar();
    }
    while (c >= '0' && c <= '9'){
        x = x * 10 + c - '0';
        c = getchar();
    }
    return f * x;
}

int n, m, k;
// Edge i (1-indexed): (weight, (endpoint u, endpoint v)).
pair<int, pair<int, int> > e[500005];
// Disjoint-set forest over vertices 1..n: parent links and set sizes.
int fa[100005], siz[100005];
// id[1..m]: edge indices, sorted by weight.
int id[500005];
// Answer computed by init().
long long ans = 0;

// Root of x's set, compressing the path on the way back.
int Find(int x){
    return (x == fa[x] ? x : (fa[x] = Find(fa[x])));
}

// Merges the sets of x and y (union by size).
// Returns 1 if they were separate, 0 if already in the same set.
int Union(int x, int y){
    int u = Find(x), v = Find(y);
    if (u == v) return 0;
    if (siz[u] > siz[v]) fa[v] = u, siz[u] += siz[v];
    else fa[u] = v, siz[v] += siz[u];
    return 1;
}

// Orders edge indices by ascending weight.
bool cmp(int i, int j){
    return e[i].first < e[j].first;
}

// Kruskal sweep over e[1..edge_count] in ascending weight order on vertices
// 1..node_count. When an edge of weight `target` merges two sets of sizes a
// and b, exactly a * b pairs become connected for the first time at that
// weight, so they are added to the result. Reinitialises id[], fa[], siz[].
long long count_bottleneck_pairs(int node_count, int edge_count, int target){
    for (int i = 1; i <= edge_count; ++i) id[i] = i;
    sort(id + 1, id + edge_count + 1, cmp);
    for (int i = 1; i <= node_count; ++i) fa[i] = i, siz[i] = 1;

    long long result = 0;
    int components = node_count;
    for (int i = 1; i <= edge_count && components > 1; ++i){
        const int w = e[id[i]].first;
        // Edges are sorted, so no later edge can have weight `target`.
        if (w > target) break;
        const int ru = Find(e[id[i]].second.first);
        const int rv = Find(e[id[i]].second.second);
        if (ru == rv) continue;
        if (w == target) result += 1LL * siz[ru] * siz[rv];
        Union(ru, rv);
        --components;
    }
    return result;
}

// Reads "n m k" followed by m edges "u v w" (1-indexed vertices) and stores
// the pair count in `ans`.
// NOTE(review): the original listing lost the read() calls here (u, v, w were
// undefined and n, m, k were never read); restored.
void init(){
    n = read(), m = read(), k = read();
    for (int i = 1; i <= m; ++i){
        const int u = read(), v = read(), w = read();
        e[i].second.first = u;
        e[i].second.second = v;
        e[i].first = w;
    }
    ans = count_bottleneck_pairs(n, m, k);
}
// Prints the answer computed by init(), followed by a newline.
void solve(){
    const long long result = ans;
    printf("%lld\n", result);
}
// Entry point: read the input and compute the answer, then print it.
int main(){
    init();
    solve();
}