1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
|
// Reads a tree and prints min over nodes x of sum_j w[j] * dist(x, j).
//
// Input: n, then n lines "weight left right", where left/right are 1-based
// indices of node i's children and 0 means "no child". Edges are treated as
// undirected, so the node in input position 1 does not have to be the root.
//
// Method: one pass from node 0 computes f(0) = sum w[j] * depth[j] and the
// subtree weight sums siz[]. Then rerooting gives, for each child v of u,
//   f(v) = f(u) + (total - siz[v]) - siz[v]
// (v's subtree moves one step closer, everything else one step farther).
// Traversal is iterative (BFS order) so deep chains cannot overflow the stack.
void solve() {
    int n = 0;
    std::cin >> n;
    if (n <= 0) {
        // Empty tree: the sum over no nodes is 0; avoids indexing node 0.
        std::cout << 0 << "\n";
        return;
    }

    std::vector<std::vector<int>> adj(n);
    std::vector<long long> w(n);
    for (int i = 0; i < n; i++) {
        int left = 0, right = 0;
        std::cin >> w[i] >> left >> right;
        for (int child : {left - 1, right - 1}) {
            // 0-based index 0 is a valid child (node 1); only -1 means "none".
            if (child >= 0 && child < n) {
                adj[i].push_back(child);
                adj[child].push_back(i);
            }
        }
    }

    // BFS from node 0: every parent appears in `order` before its children.
    std::vector<int> order;
    order.reserve(n);
    std::vector<int> parent(n, -1);
    std::vector<char> seen(n, 0);
    long long f0 = 0;  // sum of w[j] * depth[j] with node 0 as root
    std::vector<int> depth(n, 0);
    order.push_back(0);
    seen[0] = 1;
    for (std::size_t head = 0; head < order.size(); head++) {
        const int u = order[head];
        f0 += w[u] * depth[u];
        for (int v : adj[u]) {
            // `seen` (not just v != parent) also guards against duplicate
            // edges in malformed input, which would otherwise double-count.
            if (!seen[v]) {
                seen[v] = 1;
                parent[v] = u;
                depth[v] = depth[u] + 1;
                order.push_back(v);
            }
        }
    }

    // Subtree weight sums: children are folded into parents in reverse BFS order.
    std::vector<long long> siz(w);
    for (auto it = order.rbegin(); it != order.rend(); ++it) {
        const int u = *it;
        if (parent[u] >= 0) {
            siz[parent[u]] += siz[u];
        }
    }

    // Reroot in BFS order so f[parent] is always ready before f[child].
    std::vector<long long> f(n, 0);
    f[0] = f0;
    long long best = f0;
    for (std::size_t k = 1; k < order.size(); k++) {
        const int v = order[k];
        f[v] = f[parent[v]] + siz[0] - 2 * siz[v];
        best = std::min(best, f[v]);
    }
    std::cout << best << "\n";
}
|