[EZOI][1215NOI模拟赛]树(树形DP+贪心+枚举)

§ 1 题意

给出三棵节点分别为 $n1,\ n2,\ n3$ 的树,再连两条边可构成一棵新树。

求可能形成的新树中,所有点对距离和的最大值。

$n1,\ n2,\ n3\leq 100000$。


§ 2 分析

对于其中一棵大小为 $n$ 的树,令 $dis_i$ 表示点 $i$ 到其他所有点的距离和,考虑如何快速求解 $dis_i$。

钦定点 $1$ 为根,不难发现 $dis_1=\sum\limits_{i\in V,\ i\neq 1}siz_i$,且有 $dis_i=dis_{fa_i}+n-2\,siz_i$。

所以可以先自底向上树形 DP 求出 $siz_i$,再自顶向下树形 DP 求出 $dis_i$。

考虑连一条边合并两棵树,要使所有点对距离和最大,一定贪心连接两棵树各自 $dis_i$ 最大的点。

所以尝试三种可能的合并方法,重新计算新生成的三棵树的 $dis_i$,与剩下的树合并取最大答案即可。

总时间复杂度为 $O(n1+n2+n3)$。


§ 3 代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<cmath>
#include<iostream>
#include<algorithm>
using namespace std;
typedef long long ll;

inline void getint(int &num){
register int ch, neg = 0;
while(!isdigit(ch = getchar())) if(ch == '-') neg = 1;
num = ch & 15;
while(isdigit(ch = getchar())) num = num * 10 + (ch & 15);
if(neg) num = -num;
}

int n1, n2, n3;

struct Edge {int np; Edge *nxt;};

struct Tree{
int n, tope, siz[200005], mx; ll dis[200005];
Edge E[400005], *V[200005];

inline void addedge(const int &u, const int &v) {E[++tope].np = v, E[tope].nxt = V[u], V[u] = E + tope;}

void dfs1(int u, int fa){
siz[u] = 1;
for(register Edge *ne = V[u]; ne; ne = ne->nxt)
if(ne->np != fa) dfs1(ne->np, u), siz[u] += siz[ne->np];
if(u != 1) dis[1] += siz[u];
}

void dfs2(int u, int fa){
for(register Edge *ne = V[u]; ne; ne = ne->nxt)
if(ne->np != fa) dis[ne->np] = dis[u] + n - 2 * siz[ne->np], dfs2(ne->np, u);
}

inline void getmax(){
mx = 1;
for(register int i = 2; i <= n; i++)
if(dis[i] > dis[mx]) mx = i;
}
} T[6];

int main(){
getint(n1), getint(n2), getint(n3), T[0].n = n1, T[1].n = n2, T[2].n = n3;
T[3].n = n1 + n2, T[4].n = n1 + n3, T[5].n = n2 + n3;
for(register int i = 1; i < n1; i++){
register int u, v; getint(u), getint(v);
T[0].addedge(u, v), T[0].addedge(v, u);
T[3].addedge(u, v), T[3].addedge(v, u);
T[4].addedge(u, v), T[4].addedge(v, u);
}
for(register int i = 1; i < n2; i++){
register int u, v; getint(u), getint(v);
T[1].addedge(u, v), T[1].addedge(v, u);
T[3].addedge(u + n1, v + n1), T[3].addedge(v + n1, u + n1);
T[5].addedge(u, v), T[5].addedge(v, u);
}
for(register int i = 1; i < n3; i++){
register int u, v; getint(u), getint(v);
T[2].addedge(u, v), T[2].addedge(v, u);
T[4].addedge(u + n1, v + n1), T[4].addedge(v + n1, u + n1);
T[5].addedge(u + n2, v + n2), T[5].addedge(v + n2, u + n2);
}
for(register int i = 0; i <= 2; i++) T[i].dfs1(1, 0), T[i].dfs2(1, 0), T[i].getmax();
T[3].addedge(T[0].mx, T[1].mx + n1), T[3].addedge(T[1].mx + n1, T[0].mx);
T[4].addedge(T[0].mx, T[2].mx + n1), T[4].addedge(T[2].mx + n1, T[0].mx);
T[5].addedge(T[1].mx, T[2].mx + n2), T[5].addedge(T[2].mx + n2, T[1].mx);
for(register int i = 3; i <= 5; i++) T[i].dfs1(1, 0), T[i].dfs2(1, 0), T[i].getmax();
ll ans = 0;
for(register int i = 3, j = 2; i <= 5; i++, j--){
register ll cur = 0;
for(register int k = 1; k <= T[i].n; k++) cur += T[i].dis[k];
for(register int k = 1; k <= T[j].n; k++) cur += T[j].dis[k];
(cur >>= 1) += T[i].n * (ll)T[j].dis[T[j].mx] + T[i].dis[T[i].mx] * (ll)T[j].n + T[i].n * (ll)T[j].n;
ans = max(ans, cur);
}
printf("%lld\n", ans);
return 0;
}