「牛客」排列计数机
题意
定义一个长为 k 的序列 A1,A2,…,Ak 的权值为:对于所有 1≤i≤k,max(A1,A2,…,Ai) 有多少种不同的取值。
给出一个 1 到 n 的排列 B1,B2,…,Bn,求 B 的所有非空子序列的权值的 m 次方之和。
答案对 109+7 取模。
数据范围: 1≤n≤105,1≤m≤20, 保证 B 是 1 到 n 的排列。
题解
这种序列上的问题多是DP。
暴力DP
首先先考虑一个比较暴力的DP
fv,j 表示最大值为 v ,序列的权值为 j 的方案数。按顺序枚举 i :
fAi,j=k∈[1,Ai−1]∑fk,j−1,j∈[2,i]fAi,1=1(1)
即:前面所有最大值小于 Ai 的序列末端加上 Ai ,会多一的权值; Ai 本身构成一种权值为 1 的序列。
fv,j=fv,j×2,v∈[Ai+1,n],j∈[1,i](2)
即:前面所有最大值小于 Ai 的序列末端加上 Ai ,权值不会改变,但方案数多了一倍。
统计答案:
ans=v=1∑nj=1∑nfv,j×jm
复杂度 O(n3)
正解
答案肯定形如 k1⋅1m+k2⋅2m+k3⋅3m+⋯+kn⋅nm 。其中 ki 表示权值为 i 的序列有多少个。
设 dpv,j 表示序列最大值为 v ,所有方案权值的 j 次方和 , (即 k1′⋅1j+k2′⋅2j+k3′⋅3j+⋯+kn′⋅nj )。并且用线段树来维护 dpv,j 。(具体得说:开 m 棵线段树维护 v 这维区间 j 次和)
参考上面提到的转移,对于转移 fv,j=fv,j×2,v∈[Ai+1,n],j∈[1,i] :
dpv,j←k1′×2⋅1j+k2′×2⋅2j+k3′×2⋅3j+⋯+kn′×2⋅nj=dpv,j×2v∈[Ai+1,n],j∈[0,m]
相当于在 m 颗线段树中把 [Ai+1,n] 的区间都乘二
对于转移 fAi,j=∑k∈[1,Ai−1]fk,j−1,j∈[2,i] ,意思就是所有最大值小于 Ai 的方案都可以更新 dpAi
先记 Sj=v∈[1,Ai−1]∑dpv,j 表示最大值小于 Ai 所有方案权值 j 次方和,按其含义也可以写成这样 Sj=k1⋅1j+k2⋅2j+⋯+kn⋅nj ,其中 ki 表构成权值为 i 的序列的方案数。现在用所有最大值小于 Ai 的方案更新 Ai ,每种方案序列的权值都会加一,相当于使 ki 表示权值为 i+1 的序列的方案数,即:
dpAi,j←k1⋅(1+1)j+k2⋅(2+1)j+⋯+kn⋅(n+1)j+1
(+1 是因为只有 Ai 也是一种序列方案且权值为 1),其中有形如 (x+1)j 的式子,根据二项式展开定理:
(x+1)j=(0j)xj+(1j)xj−1+⋯+(jj)x0(x+1)j=i=0∑j(j−ij)xi
可进行如下化简:
k1⋅(1+1)j+k2⋅(2+1)j+⋯+kn⋅(n+1)j+1=k1⋅i=0∑j(j−ij)1i+k2⋅i=0∑j(j−ij)2i+⋯+kn⋅i=0∑j(j−ij)ni=i=0∑j(j−ij)(k1⋅1i+k2⋅2i+⋯+kn⋅ni)=i=0∑j(j−ij)Si
总结一下,按顺序扫 Ai 。对于每个 Ai :
- m 颗线段树中 [Ai+1,n] 的部分全部乘二
- 先处理出 Sj=第j颗线段树中[1,Ai−1]的区间和 ,将第 j 棵树上的 [Ai,Ai] 改为 (k=0∑j(j−kj)Sk)+1
复杂度 O(nm2+nmlogn)
CODE
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73
| #include <bits/stdc++.h> using namespace std; typedef long long ll; const int N = 100010; const int M = 21; const ll mod = 1e9+7; int n,m,a[N]; ll sum[N<<2][M],tag[N<<2][M],C[M][M]; inline int red() { int r=0;char ch=getchar(); while(ch<'0'||ch>'9') ch=getchar(); while(ch>='0'&&ch<='9') r=(r<<1)+(r<<3)+(ch^48),ch=getchar(); return r; } void pushup(int x,int k) { sum[x][k]=(sum[x<<1][k]+sum[x<<1|1][k])%mod; } void updata(int x,int k,int vl) { (sum[x][k]*=vl)%=mod; (tag[x][k]*=vl)%=mod; } void pushdown(int x,int k) { updata(x<<1,k,tag[x][k]); updata(x<<1|1,k,tag[x][k]); tag[x][k]=1; } ll query(int p,int k,int x=1,int l=1,int r=n) { if(r<=p) return sum[x][k]; int mid=(l+r)>>1;pushdown(x,k); return (query(p,k,x<<1,l,mid)+(p>mid?query(p,k,x<<1|1,mid+1,r):0))%mod; } void multi(int p,int k,int x=1,int l=1,int r=n) { if(r<p)return; if(l>=p) {updata(x,k,2);return;} int mid=(l+r)>>1;pushdown(x,k); if(p<=mid) multi(p,k,x<<1,l,mid);multi(p,k,x<<1|1,mid+1,r); pushup(x,k); } void modify(int p,int k,int vl,int x=1,int l=1,int r=n) { if(l==r) {sum[x][k]=vl;return;} int mid=(l+r)>>1;pushdown(x,k); if(p<=mid) modify(p,k,vl,x<<1,l,mid);else modify(p,k,vl,x<<1|1,mid+1,r); pushup(x,k); } void init() { C[0][0]=1; for(int i=1;i<=m;i++) { C[i][0]=1; for(int j=1;j<=i;j++)C[i][j]=(C[i-1][j-1]+C[i-1][j])%mod; } for(int i=1;i<=(n<<2);i++) for(int j=0;j<=m;j++) tag[i][j]=1; } ll tmp[M],tp[M]; int main() { n=red(),m=red(); init(); for(int i=1;i<=n;i++) { a[i]=red(); memset(tp,0,sizeof(tp)); memset(tmp,0,sizeof(tmp)); for(int j=0;j<=m;j++) tp[j]=query(a[i],j); for(int j=0;j<=m;j++) { for(int k=0;k<=j;k++) { (tmp[j]+=C[j][j-k]*tp[k]%mod)%=mod; } } for(int j=0;j<=m;j++) { modify(a[i],j,tmp[j]+1); multi(a[i]+1,j); } } printf("%lld\n",query(n,m)); }
|