[ZJOI2019]语言 【线段树合并】

mac2024-10-24  55

LOJ3046


SOL

转换一下题意,对于每一个点我们维护可以到达的集合 S S S a n s = ∑ ∣ S i ∣ ans=\sum{|S_i|} ans=Si

然后每次区间覆盖相当于向一条链上的每个点的集合,集体覆盖上一些点。

线段树合并+树剖 O ( n l o g 2 n ) O(nlog^2n) O(nlog2n)

对于每一个点,我们用一颗动态开点的线段树维护可以到达的点集,当前点的贡献就线段树的总长。

对于每次操作,覆盖的点在一条链上,所以用树剖变成 l o g n logn logn个区间覆盖。

相当于要对链上的每一个点的线段树做logn次区间覆盖。 我们可以差分优化,即在 x , y x,y x,y + 1 +1 +1,区间覆盖, l c a , f a [ l c a ] lca,fa[lca] lca,fa[lca] − 1 -1 1,取消区间覆盖,对于子树线段树合并。

线段树合并 O ( n l o g n ) O(nlogn) O(nlogn),总共修改个数 n l o g n nlogn nlogn,修改复杂度 O ( n l o g 2 n ) O(nlog^2n) O(nlog2n)。 总复杂度 O ( n l o g 2 n ) O(nlog^2n) O(nlog2n)

线段树合并+转换贡献统计方式 O ( n l o g n ) O(nlogn) O(nlogn) 参照Sooke大佬的blog

对于点集大小的统计方式,我们不用树剖,而是记录每条路径的两个端点。因为两点之间路径唯一,所以可以用两端点还原出原路径。

对于每个点,最后我们会得到 2 k 2k 2k个端点,他们两两之间所覆盖的路径总长即为答案。

怎么求?我们把这些点按照dfs序从大到小排序,对于 i i i,贡献为 d e p [ i ] − d e p [ l c a ( i , i − 1 ) ] dep[i]-dep[lca(i,i-1)] dep[i]dep[lca(i,i1)]。 (注意最后要减去 d e p [ l c a ( f i r s t , l a s t ) ] dep[lca(first,last)] dep[lca(first,last)],即所有点的lca的深度,再加上lca本身)

如下图:

要支持修改、合并,我们可以用线段树分治求解。

每次单点修改, p u s h u p pushup pushup时把左区间dfs序最靠右的 A A A和右区间dfs序最靠左的 B B B l c a lca lca d e p dep dep减去即可。

总复杂度: O ( n l o g n ) O(nlogn) O(nlogn)


CODE

#include<bits/stdc++.h> using namespace std; #define sf scanf #define ri register int #define in red() #define gc getchar() #define cs const #define ll long long inline int red(){ int num=0,f=1;char c=gc; for(;!isdigit(c);c=gc)if(c=='-')f=-1; for(;isdigit(c);c=gc)num=(num<<1)+(num<<3)+(c^48); return num*f; } cs int N=1e5+10,M=2e6+10; vector<int> g[N]; int dep[N],dfn[N],tot=0,rev[N],fa[N],n,m; void dfs(int u){ dep[u]=dep[fa[u]]+1; dfn[u]=++tot; rev[tot]=u; for(ri i=g[u].size()-1;i>=0;--i){ int v=g[u][i]; if(v==fa[u])continue; fa[v]=u; dfs(v); } } namespace ST{ int dfn[N],tot=0,st[25][N<<1],Log[N<<1]; void dfs(int u){ st[0][++tot]=u;dfn[u]=tot; for(ri i=g[u].size()-1;i>=0;--i){ int v=g[u][i]; if(v==fa[u])continue; dfs(v); st[0][++tot]=u; } } inline int _min(int a,int b){return dfn[a]<dfn[b] ? a: b;} inline void init(){ dfs(1); for(ri i=2;i<=tot;++i)Log[i]=Log[i>>1]+1; for(ri i=1;i<=Log[tot];++i){ for(ri j=1;j+(1<<i)-1<=tot;++j){ st[i][j]=_min(st[i-1][j],st[i-1][j+(1<<(i-1))]); } } } inline int lca(int x,int y){ if(!x||!y)return 0; int fx=min(dfn[x],dfn[y]),fy=max(dfn[x],dfn[y]),k=Log[fy-fx+1]; return _min(st[k][fx],st[k][fy-(1<<k)+1]); } } using ST::lca; namespace SGT{ //线段树及合并 #define lc(x) ch[x][0] #define rc(x) ch[x][1] int ch[M][2],l[M],r[M],num[M],stk[M],top=0,tot=0,id[M]; ll sum[M]; inline void erase(int p){ stk[++top]=p; l[p]=r[p]=lc(p)=rc(p)=sum[p]=num[p]=id[p]=0; } inline int get(){ return top ? stk[top--] : ++tot; } inline void pushnow(int p,int k){ if(!p)return; int pre=num[p]; num[p]+=k; if(pre&&!num[p]){ l[p]=r[p]=sum[p]=0; } if(!pre&&num[p]){ l[p]=r[p]=id[p]; sum[p]=dep[id[p]]; } } inline void pushup(int p){ sum[p]=sum[lc(p)]+sum[rc(p)]-dep[lca(r[lc(p)],l[rc(p)])]; l[p]=l[lc(p)] ? l[lc(p)] : l[rc(p)]; r[p]=r[rc(p)] ? r[rc(p)] : r[lc(p)]; } inline void upt(int &p,int l,int r,int pos,int k){ if(!p)p=get(); if(l==r){ id[p]=rev[pos]; return pushnow(p,k); } int mid=(l+r)>>1; if(pos<=mid)upt(lc(p),l,mid,pos,k); else upt(rc(p),mid+1,r,pos,k); pushup(p); } inline ll query(int p){ return sum[p]-dep[lca(l[p],r[p])]+1; } inline void merge(int &p,int x,int y){ if(!x||!y)return p=x+y,void(); p=x; if(id[x])return num[x]+=num[y],erase(y),void(); merge(lc(p),lc(x),lc(y)); merge(rc(p),rc(x),rc(y)); pushup(p); erase(y); } } typedef pair<int,int> pi; vector<pi> h[N]; #define fi first #define se second int rt[N]; ll ans; void dfs2(int u){ for(ri i=g[u].size()-1;i>=0;--i){ int v=g[u][i]; if(v==fa[u])continue; dfs2(v); SGT::merge(rt[u],rt[u],rt[v]); } for(ri i=h[u].size()-1;i>=0;--i){ SGT::upt(rt[u],1,n,h[u][i].fi,h[u][i].se); } ll now=SGT::query(rt[u]); if(now)ans+=now-1; } signed main(){ n=in;m=in; for(ri i=1;i<n;++i){ int u=in,v=in; g[u].push_back(v); g[v].push_back(u); } dfs(1); ST::init(); for(ri i=1;i<=m;++i){ int x=in,y=in; int L=lca(x,y); h[x].push_back(pi(dfn[x],1));h[x].push_back(pi(dfn[y],1)); h[y].push_back(pi(dfn[x],1));h[y].push_back(pi(dfn[y],1)); h[L].push_back(pi(dfn[x],-1));h[L].push_back(pi(dfn[y],-1)); h[fa[L]].push_back(pi(dfn[x],-1));h[fa[L]].push_back(pi(dfn[y],-1)); } dfs2(1); cout<<ans/2; return 0; }
最新回复(0)