LCA和RMQ算法

mac2024-08-14  72

 

RMQ(Range Minimum/Maximum Query),即区间最值查询,是指这样一个问题:对于长度为n的数列A,回答若干次询问RMQ(i,j),返回数列A中下标在区间[i,j]中的最小/大值。

本文介绍一种比较高效的ST算法解决这个问题。ST(Sparse Table)算法可以在O(nlogn)时间内进行预处理,然后在O(1)时间内回答每个查询。

1)预处理

设A[i]是要求区间最值的数列,F[i, j]表示从第i个数起连续2^j个数中的最大值。(DP的状态)

例如:

A数列为:3 2 4 5 6 8 1 2 9 7

F[1,0]表示第1个数起,长度为2^0=1的最大值,其实就是3这个数。同理 F[1,1] = max(3,2) = 3, F[1,2]=max(3,2,4,5) = 5,F[1,3] = max(3,2,4,5,6,8,1,2) = 8;

并且我们可以容易的看出F[i,0]就等于A[i]。(DP的初始值)

我们把F[i,j]平均分成两段(因为F[i,j]一定是偶数个数字),从 i 到i + 2 ^ (j - 1) - 1为一段,i + 2 ^ (j - 1)到i + 2 ^ j - 1为一段(长度都为2 ^ (j - 1))。于是我们得到了状态转移方程F[i, j]=max(F[i,j-1], F[i + 2^(j-1),j-1])。

2)查询

假如我们需要查询的区间为(i,j),那么我们需要找到覆盖这个闭区间(左边界取i,右边界取j)的最小幂(可以重复,比如查询1,2,3,4,5,我们可以查询1234和2345)。

因为这个区间的长度为j - i + 1,所以我们可以取k=log2( j - i + 1),则有:RMQ(i, j)=max{F[i , k], F[ j - 2 ^ k + 1, k]}。

举例说明,要求区间[1,5]的最大值,k = log2(5 - 1 + 1)= 2,即求max(F[1, 2],F[5 - 2 ^ 2 + 1, 2])=max(F[1, 2],F[2, 2]);

void ST(int n) { for (int i = 1; i <= n; i++) dp[i][0] = A[i]; for (int j = 1; (1 << j) <= n; j++) { for (int i = 1; i + (1 << j) - 1 <= n; i++) { dp[i][j] = max(dp[i][j - 1], dp[i + (1 << (j - 1))][j - 1]); } } } int RMQ(int l, int r) { int k = 0; while ((1 << (k + 1)) <= r - l + 1) k++; return max(dp[l][k], dp[r - (1 << k) + 1][k]); }

 

LCA(树上最近公共祖先)

      用上面的算法的话时间复杂度是 O(nlog⁡n) 预处理,O(1) 在线查询。

      首先引入dfs序。

  

  例如,上图这棵树的一个dfs序为 8,5,9,5,8,4,6,15,6,7,6,4,10,11,10,16,3,16,12,16,10,2,10,4,8,1,14,1,13,1,8

9-6的路径上所有的点在上面的数组中可以找到一个连续数列(9 5 8 4 6),其中出现的深度最小的节点就是他们的LCA。

a:从树的根开始,将树看成一个无向图进行深度优先遍历,记录下每次到达的顶点,第一个顶点为树根root,

    每经过一条边都记录它的端点,每条边都恰好经过两次.用数组ver记录结点。

b:记录first数组和deep数组,first数组记录在深度优先遍历时结点第一次出现的位置。deep数组记录结点的深度

void dfs(int u,int dep) { vis[u]=true; ver[++tot]=u; first[u]=tot; deep[tot]=dep; for(int i=head[u];i!=-1;i=edge[i].next) { int v=edge[i].to; if(!vis[v]) { dfs(v,dep+1); ver[++tot]=u; deep[tot]=dep; } } }

可以发现,我们通过dfs记录结点后,当我们要查询结点u,v时,我们可以在结点的数组中找到u结点第一次出现的位置first[u]   和v结点第一次出现的位置 first[v],而他们位置之间的结点便是u到v的DFS顺序,虽然其中可能包含u或v的后代,但其中深度最小的还是u和v的最近公共祖先。因此可以用ST表记录与

结点数组相对应的深度序列的区间最小值下标,将lca转化为RMQ问题。

void ST(int n) { for(int i=1;i<=n;i++) dp[i][0]=i; for(int j=1;(1<<j)<=n;j++) for(int i=1;i+(1<<j)-1<=n;i++) { int a=dp[i][j-1];int b=dp[i+(1<<(j-1))][j-1]; //记录其中结点深度最小的结点的位置 dp[i][j]=deep[a]<deep[b]?a:b; } }

寻找LCA(u,v)时,先寻找first[u],first[v],将[first[u],first[v]]间的最小值的deep找出,该值下标所对应的结点即为LCA(u,v)。

即当first[u]>first[v]时,LCA(T,u,v)=RMQ(deep,R[v],R[u]),否则LCA(T,u,v) = RMQ(deep,R[u],R[v]).

int RMQ(int l,int r) { int k=0; while(1<<(k+1)<=r-l+1) k++; int a=dp[l][k],b=dp[r-(1<<k)+1][k]; return deep[a]<deep[b]?a:b; } int LCA(int u,int v) { int x=first[u],y=first[v]; if(x>y)swap(x,y); int res=RMQ(x,y); return ver[res]; }

 总结:

#include<iostream> #include<cstring> #include<cstdio> using namespace std; const int MAX=10009; int T,n,a,b; int head[MAX],cnt=0; int tot=0; int dp[MAX*2][25]; //ST表 int deep[MAX*2]; //记录节点深度 int ver[MAX*2]; //记录节点编号 int first[MAX]; //记录点第一次出现的位置 bool vis[MAX]; bool isroot[MAX]; //判断根节点的数组 struct Edge{ int to,next; }edge[MAX*2]; inline void add(int u,int v) { edge[cnt].to=v; edge[cnt].next=head[u]; head[u]=cnt++; } void dfs(int u,int dep) { vis[u]=true; //访问过该节点 ver[++tot]=u; //将该节点记录在ver中 first[u]=tot; //记录结点u第一次出现的位置 deep[tot]=dep; //记录深度 for(int i=head[u];i!=-1;i=edge[i].next) { int v=edge[i].to; if(!vis[v]) { dfs(v,dep+1); ver[++tot]=u; deep[tot]=dep; } } } void ST(int n) { for(int i=1;i<=n;i++) //初始化 dp[i][0]=i; for(int j=1;(1<<j)<=n;j++) for(int i=1;i+(1<<j)-1<=n;i++) { int a=dp[i][j-1];int b=dp[i+(1<<(j-1))][j-1]; //记录其中结点序列深度最小的结点的编号 dp[i][j]=deep[a]<deep[b]?a:b; } } int RMQ(int l,int r) { int k=0; while(1<<(k+1)<=r-l+1) //求区间长度以二为底的对数 k++; int a=dp[l][k],b=dp[r-(1<<k)+1][k]; return deep[a]<deep[b]?a:b; } int LCA(int u,int v) { int x=first[u],y=first[v]; if(x>y)swap(x,y); int res=RMQ(x,y); return ver[res]; } void init() { memset(head,-1,sizeof(head)),cnt=0;tot=0; memset(isroot,true,sizeof(isroot)); memset(vis,false,sizeof(vis)); memset(dp,0,sizeof(dp)); memset(deep,0,sizeof(deep)); memset(first,0,sizeof(first)); memset(ver,0,sizeof(ver)); } int main() { scanf("%d",&T); while(T--) { scanf("%d",&n); init(); for(int i=1;i<n;i++) { scanf("%d%d",&a,&b); isroot[b]=false; add(a,b); add(b,a); } int root; for(int i=1;i<=n;i++) { if(isroot[i]) { root=i;break; } } dfs(root,1); ST(2*n-1); scanf("%d%d",&a,&b); printf("%d\n",LCA(a,b)); } return 0; }

以上转自 https://www.cnblogs.com/LjwCarrot/p/9971798.html

最新回复(0)