XJ contest1587「树」
牛客「路径计数机」
树上差分,LCA
题意
给你一棵 n 个点的树和两个整数 p,q,求满足以下条件的四元组 (a,b,c,d) 的个数:
- 1≤a,b,c,d≤n
- 点 a 到点 b 的距离为 p
- 点 c 到点 d 的距离为 q
- 不存在一个点,它既在 a 到 b 的路径上,又在 c 在 d 的路径上
注:数据范围
测试点编号 |
n |
p,q |
特殊性质 |
1 |
≤30 |
≤30 |
无 |
2 |
≤50 |
≤50 |
无 |
3,4 |
≤200 |
≤200 |
无 |
5 |
≤3000 |
=2 |
无 |
6 |
≤3000 |
≤3000 |
树是一条链 |
7 |
≤3000 |
≤3000 |
树随机生成 |
8,9,10 |
≤3000 |
≤3000 |
无 |
对于所有数据,满足 1≤n,p,q≤3000 。
思路
暴力骗分(比赛时人傻,只想到骗分)
首先能想到 O(n4) 暴力,枚举 a,c 然后两个DFS暴力找路径。对于树是一条链也是一道简单数学题,O(n) 即可,XJ的数据有点水,这样就能过 1~4,6 共五个点。
正解
看数据范围 n≤3000 ,能猜测到这大概是一个 O(n2) 的算法,可以考虑枚举 a,b 。它的贡献就等于 (所有长度为 q 的路径的数量-不合法的数量) 。前者O(n2) 枚举即可,考虑如何求后者: 设a,b的LCA 为u;c,d的LCA为 v
分成两类求:
第一种v在u的子树内:可以发现当且仅当v在a->b 这条路径上才有交(若v不在a->b这条路径上,那比v更深的点显然也不会在a->b的路径上)

Tarjan预处理每对点的LCA,复杂度 O(n2):可以用 f[u] 表示以 u 为 LCA 的长度为 q 的路径有几条。预处理 f[u] 直接 O(n2) 枚举 c,d 若路径长度为 q,f[LCA(c,d)]++ 即可。
对于一对 (a,b), v 在 u 的子树内不合法的数量即为 a 到 b 路径上的 f 之和。
第二种 v 在 u 的子树外:若有交显然会经过 u 到其父亲的那条边, 我们只要再算一个 g[u] 表示经过 u 到其父亲那条边的长度为 q 的路径数。这东西跟 f 差不多,枚举 c,d ,若路径长度为q ,则 c 到 d 的路径上所有边的 g 加一,可以树上差分。

So: 一对合法 (a,b) 的贡献为:
Sumf−g[LCA(a,b)]−u∈path(a,b)∑f[u]
Sumf 为所有 f 之和。最后那一坨 f 树上差分就行了。
总复杂度 O(n2)
code
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72
| #include <bits/stdc++.h> using namespace std; const int N = 3005; typedef long long ll; int head[N],nxt[N<<1],pnt[N<<1],E; int f[N],g[N],Sf,ts[N]; int faa[N],lca[N][N],dep[N],fa[N],n,p,q; bool vis[N];ll ans; void add(int u,int v) {E++;pnt[E]=v;nxt[E]=head[u];head[u]=E;} int findfa(int u) {return u==faa[u]?u:faa[u]=findfa(faa[u]);} void ut(int u,int v) {faa[findfa(u)]=findfa(v);} int dist(int u,int v) {return dep[u]+dep[v]-2*dep[lca[u][v]];} void tarjan(int u) { vis[u]=1; for(int i=head[u];i;i=nxt[i]) { int v=pnt[i]; if(vis[v]) continue; dep[v]=dep[u]+1;fa[v]=u; tarjan(v);ut(v,u); } for(int v=1;v<=n;v++) if(v!=u){ if(vis[v]) lca[v][u]=lca[u][v]=findfa(v); } } void dfs(int u) { Sf+=f[u]; for(int i=head[u];i;i=nxt[i]) { int v=pnt[i]; if(v==fa[u]) continue; dfs(v);g[u]+=g[v]; } } void init() { for(int i=1;i<=n;i++) faa[i]=i; for(int i=1;i<=n;i++) lca[i][i]=i; tarjan(1); for(int i=1;i<=n;i++) for(int j=1;j<=n;j++) if(dist(i,j)==q) { f[lca[i][j]]++; g[i]++,g[j]++; g[lca[i][j]]-=2; } dfs(1); } void solve(int u) { for(int i=head[u];i;i=nxt[i]) { int v=pnt[i]; if(v==fa[u]) continue; solve(v); ts[u]+=ts[v]; } ans-=1ll*ts[u]*f[u]; } int main() { scanf("%d%d%d",&n,&p,&q); for(int i=1;i<n;i++) { int u,v; scanf("%d%d",&u,&v); add(u,v),add(v,u); } init(); for(int i=1;i<=n;i++) { for(int j=1;j<=n;j++) if(dist(i,j)==p) { ans+=Sf-g[lca[i][j]]; ts[i]++,ts[j]++; ts[lca[i][j]]--; ts[fa[lca[i][j]]]--; } } solve(1); printf("%lld\n",ans); }
|