「牛客」神J上树
题意简述
给一棵 n 个点的树,每条边有边权 wi,定义 dis(u,v) 为 u 到 v 简单路的长度。神J可以在树上移动,且每次只能从一个节点 u 跳向其子孙节点 v ,且代价为 u×dis(u,v) 。给出 m 个询问,问是否能从 s 跳到 t ,若能,求最小代价。
n,m≤3×105,1≤wi≤107
题解
判断 s 是不是 t 的祖先,用dfn序即可。
再观察以下,显然最小代价是要每一个点跳到第一个比它小的点,最后再一步跳到 t 。
我们不妨先考虑在序列上怎么做: 处理每个点后第一个比它小的点可以用单调栈扫一遍,然后这东西可以倍增,预处理出从某个点跳 2j 到达点及这一段的代价。 在查答案时从 s 开始跳,跳到不越过 t 的最后的 u ,再手动计算 u 到 t 的代价即可。
现在是树上问题,于是想到树链剖分,将树剖成若干重链后,每条重链都可以按照序列上的问题预处理。
可以发现 ,s 到 t 的路径是若干重链(或重链的一部分)组成的。那么分成几种情况:
- 
在某条重链里跳:根序列上差不多,尽量向下跳,注意不要跳出在这条链上的结束位置
 
- 
跨链的跳:假设上次跳到的点为 lst ,不在该链上。 那就要找到该链上第一个比 lst 小的点,若当前链顶 u>lst ,则从 u 不断向下跳(跳到最后一个>lst 的点后再多跳一位),可以证明这样找到的是该链第一个比 lst 小的点,然后计算一下 lst 到 u 的代价即可。注意不能跳出界,当然也可能找不到比 lst 小的点,那就保存 lst 直接到下一条链即可。
 
- 
最后还没到 t :lst 到 t 手动计算即可
 
这题写起来真的毒瘤 ,主要是链之间跳细节挺多,详见代码
CODE
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108
   | #include <bits/stdc++.h> #define fi first #define se second using namespace std; const int N = 300010; const int M = 20; typedef long long ll; typedef pair<int,int> pii; int head[N],pnt[N<<1],nxt[N<<1],E,wth[N<<1]; int n,m,dep[N],fa[N],top[N],siz[N],son[N],id[N],bot[N],rk[N],tim,rid[N]; int jump[N][M];  ll dis[N],ss[N][M];  bool vis[N]; inline int red() {     int x=0;char ch=getchar();     while(ch<'0'||ch>'9') ch=getchar();     while(ch>='0'&&ch<='9') x=(x<<1)+(x<<3)+(ch^48),ch=getchar();     return x; } void add(int u,int v,int w) {     E++;pnt[E]=v;nxt[E]=head[u];wth[E]=w;head[u]=E; }
  void dfs1(int u) {     siz[u]=1;top[u]=bot[u]=u;     for(int i=head[u];i;i=nxt[i]) {         int v=pnt[i];         if(v==fa[u]) continue;         dep[v]=dep[u]+1;fa[v]=u;dis[v]=dis[u]+wth[i];dfs1(v);         siz[u]+=siz[v];         if(siz[son[u]]<siz[v]) son[u]=v;     } } void dfs2(int u) {     id[u]=++tim;rk[tim]=u;     if(son[u]) {         top[son[u]]=top[u];         dfs2(son[u]);         bot[u]=bot[son[u]];     }     for(int i=head[u];i;i=nxt[i]) {         if(pnt[i]==fa[u]||pnt[i]==son[u]) continue;         dfs2(pnt[i]);     }     rid[u]=tim; }
  int sta[N],tp; void init() {     for(int i=1;i<=n;i++)if(!vis[i]) {         int u=bot[i];tp=0;         while(u!=fa[top[i]]) {             vis[u]=1;             while(tp&&sta[tp]>u) tp--;             if(tp) jump[u][0]=sta[tp],ss[u][0]=(dis[sta[tp]]-dis[u])*(ll)u;             sta[++tp]=u;             u=fa[u];         }     }     for(int j=1;j<M;j++) {         for(int i=1;i<=n;i++) {             jump[i][j]=jump[jump[i][j-1]][j-1];             ss[i][j]=ss[i][j-1]+ss[jump[i][j-1]][j-1];         }     } }
  ll solve(int u,int v) {     vector<pii> a;     while(top[v]!=top[u]) {          a.push_back(pii(top[v],v));         v=fa[top[v]];     }     a.push_back(pii(u,v));     reverse(a.begin(),a.end());     int lst=u;ll ans=0;     for(int i=0;i<(int)a.size();i++) {          u=a[i].fi;                           for(int j=M-1;j>=0;j--) if(jump[u][j]&&id[jump[u][j]]<=id[a[i].se]&&jump[u][j]>lst) u=jump[u][j];                  if(u>lst) u=jump[u][0];                  if(u!=0&&id[u]<=id[a[i].se]&&u<lst) ans+=(dis[u]-dis[lst])*(ll)lst,lst=u;                  for(int j=M-1;j>=0;j--) {             if(jump[u][j]&&id[jump[u][j]]<=id[a[i].se]&&jump[u][j]<lst) ans+=ss[u][j],lst=u=jump[u][j];         }     }          ans+=(dis[a.back().se]-dis[lst])*(ll)lst;     return ans; } int main() {     n=red(),m=red();     for(int i=1;i<n;i++) {         int u=red(),v=red(),w=red();         add(u,v,w);add(v,u,w);     }     dfs1(1);dfs2(1);init();     while(m--) {         int s=red(),t=red();         if(id[t]<id[s]||id[t]>rid[s]) {printf("-1\n");continue;}         ll ans=solve(s,t);         printf("%lld\n",ans);     } }
   |