AC代码
#include<bits/stdc++.h> using namespace std; typedef long long LL; const int N=3e2+5; const int inf=0x3f3f3f3f; #define ls (i<<1) #define rs (i<<1|1) #define fi first #define se second LL read() { LL x=0,t=1; char ch=getchar(); while(!isdigit(ch)){ if(ch=='-')t=-1; ch=getchar(); } while(isdigit(ch)){ x=(x<<1)+(x<<3)+(ch^48); ch=getchar(); } return x*t; } struct edge { int from,to,next; edge(){} edge(int ff,int tt,int nn) { from=ff; to=tt; next=nn; } }; edge e[N<<1]; int f[N][N],head[N],tot,n,m,a[N],num[N]; void add(int from,int to) { e[++tot]=edge(from,to,head[from]); head[from]=tot; } void dfs(int u,int pre) { f[u][1]=a[u]; for(int i=head[u];i;i=e[i].next) { int v=e[i].to; if(v==pre) continue; dfs(v,u); for(int j=m;j>0;j--) for(int k=j-1;k>=0;k--) ///如果要取儿子,那么自己至少要取 f[u][j]=max(f[u][j],f[u][j-k]+f[v][k]); } } int main() { n=read(); m=read(); for(int i=1;i<=n;i++) { int x=read(); a[i]=read(); add(x,i); } m++;///多了0号点 dfs(0,-1); printf("%d\n",f[0][m]); return 0; }由于加入了0号点,那么除了0以外,每一个点都有一个对应的与父亲相连的边 , 所以也可以将 点权 转化成对应的 边权 (与父亲的连边)
#include<bits/stdc++.h> using namespace std; typedef long long LL; const int N=3e2+5; const int inf=0x3f3f3f3f; #define ls (i<<1) #define rs (i<<1|1) #define fi first #define se second LL read() { LL x=0,t=1; char ch=getchar(); while(!isdigit(ch)){ if(ch=='-')t=-1; ch=getchar(); } while(isdigit(ch)){ x=(x<<1)+(x<<3)+(ch^48); ch=getchar(); } return x*t; } struct edge { int from,to,next; edge(){} edge(int ff,int tt,int nn) { from=ff; to=tt; next=nn; } }; edge e[N<<1]; int f[N][N],head[N],tot,n,m,a[N],num[N]; void add(int from,int to) { e[++tot]=edge(from,to,head[from]); head[from]=tot; } void dfs(int u,int pre) { num[u]=1; for(int i=head[u];i;i=e[i].next) { int v=e[i].to; if(v==pre) continue; dfs(v,u); num[u]+=num[v]; } for(int i=head[u];i;i=e[i].next) { int v=e[i].to; if(v==pre) continue; for(int j=min(m,num[u]);j>=0;j--)//(取min是为了优化常数,不是必须的) for(int k=min(j-1,num[v]);k>=0;k--) f[u][j]=max(f[u][j],f[u][j-k-1]+f[v][k]+a[v]); } } int main() { n=read(); m=read(); for(int i=1;i<=n;i++) { int x=read(); a[i]=read(); add(x,i); } dfs(0,-1); printf("%d\n",f[0][m]); return 0; }