下面提供四种不同时间复杂度的做法，希望能帮到大家。
\(O(n^3)\):
// O(n^3) solution.
// Input: n s, then n-1 lines "u v w" describing a weighted tree on vertices
// 1..n (edge weights assumed non-negative). Output: the minimum, over all
// sub-paths of one fixed diameter whose length is <= s, of the largest
// distance from any vertex to that sub-path.
// Method: enumerate every admissible sub-path of the diameter and measure its
// eccentricity with a fresh O(n) traversal.
#include <algorithm>
#include <climits>
#include <iostream>
#include <vector>

namespace {

// One direction of an undirected weighted edge.
struct Edge {
    int to;
    long long w;
};

// Adjacency lists indexed by vertex number (1..n); index 0 is unused.
using Tree = std::vector<std::vector<Edge>>;

// Reads n-1 edges "u v w" from stdin into an n-vertex tree.
Tree read_tree(int n) {
    Tree g(n + 1);
    for (int i = 1; i < n; ++i) {
        int u = 0;
        int v = 0;
        long long w = 0;
        std::cin >> u >> v >> w;
        g[u].push_back({v, w});
        g[v].push_back({u, w});
    }
    return g;
}

// Traverses the tree from `root`, filling dist[v] (weighted distance to root)
// and parent[v] (next vertex towards root; 0 for root itself). Iterative so a
// long path-shaped tree cannot overflow the call stack.
void walk_from(const Tree& g, int root, std::vector<long long>& dist,
               std::vector<int>& parent) {
    dist.assign(g.size(), 0);
    parent.assign(g.size(), 0);
    std::vector<int> pending{root};
    while (!pending.empty()) {
        const int u = pending.back();
        pending.pop_back();
        for (const Edge& e : g[u]) {
            if (e.to == parent[u]) continue;
            parent[e.to] = u;
            dist[e.to] = dist[u] + e.w;
            pending.push_back(e.to);
        }
    }
}

// Smallest vertex number among those with the largest dist.
int farthest(const std::vector<long long>& dist) {
    return static_cast<int>(std::max_element(dist.begin() + 1, dist.end()) -
                            dist.begin());
}

// Double sweep: returns the diameter's vertices from one end (front) to the
// other (back). On return dist[] is measured from path.back(), so it is
// non-increasing along the returned path.
std::vector<int> find_diameter(const Tree& g, std::vector<long long>& dist) {
    std::vector<int> parent;
    walk_from(g, 1, dist, parent);
    const int start = farthest(dist);
    walk_from(g, start, dist, parent);
    std::vector<int> path;
    for (int v = farthest(dist); v != start; v = parent[v]) path.push_back(v);
    path.push_back(start);
    return path;
}

// Largest distance from any vertex to the sub-path path[lo..hi]. Every vertex
// off the sub-path hangs from exactly one of its vertices, so each is visited
// once: O(n).
long long eccentricity(const Tree& g, const std::vector<int>& path, int lo,
                       int hi) {
    std::vector<char> on_path(g.size(), 0);
    for (int k = lo; k <= hi; ++k) on_path[path[k]] = 1;

    struct Frame {
        int v;
        int from;
        long long d;
    };
    std::vector<Frame> pending;
    long long worst = 0;
    for (int k = lo; k <= hi; ++k) {
        pending.push_back({path[k], 0, 0});
        while (!pending.empty()) {
            const Frame f = pending.back();
            pending.pop_back();
            worst = std::max(worst, f.d);
            for (const Edge& e : g[f.v]) {
                if (e.to == f.from || on_path[e.to]) continue;
                pending.push_back({e.to, f.v, f.d + e.w});
            }
        }
    }
    return worst;
}

}  // namespace

int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int n = 0;
    long long s = 0;
    std::cin >> n >> s;
    if (n < 1) {  // empty tree: nothing to cover
        std::cout << 0;
        return 0;
    }
    const Tree g = read_tree(n);
    std::vector<long long> dist;
    const std::vector<int> path = find_diameter(g, dist);
    const int last = static_cast<int>(path.size()) - 1;

    long long best = LLONG_MAX;
    for (int i = 0; i <= last; ++i) {
        for (int j = i; j <= last; ++j) {
            // dist is non-increasing along the path, so once path[i..j] is
            // longer than s every larger j is too.
            if (dist[path[i]] - dist[path[j]] > s) break;
            best = std::min(best, eccentricity(g, path, i, j));
        }
    }
    std::cout << best;
    return 0;
}
\(O(n^2)\):
// O(n^2) solution.
// Input: n s, then n-1 lines "u v w" describing a weighted tree on vertices
// 1..n (edge weights assumed non-negative). Output: the minimum, over all
// sub-paths of one fixed diameter whose length is <= s, of the largest
// distance from any vertex to that sub-path.
// Method: for each left end only the longest admissible sub-path matters
// (enlarging the core never increases its eccentricity), so only O(n)
// candidates are evaluated, each in O(n).
#include <algorithm>
#include <climits>
#include <iostream>
#include <vector>

namespace {

// One direction of an undirected weighted edge.
struct Edge {
    int to;
    long long w;
};

// Adjacency lists indexed by vertex number (1..n); index 0 is unused.
using Tree = std::vector<std::vector<Edge>>;

// Reads n-1 edges "u v w" from stdin into an n-vertex tree.
Tree read_tree(int n) {
    Tree g(n + 1);
    for (int i = 1; i < n; ++i) {
        int u = 0;
        int v = 0;
        long long w = 0;
        std::cin >> u >> v >> w;
        g[u].push_back({v, w});
        g[v].push_back({u, w});
    }
    return g;
}

// Traverses the tree from `root`, filling dist[v] (weighted distance to root)
// and parent[v] (next vertex towards root; 0 for root itself). Iterative so a
// long path-shaped tree cannot overflow the call stack.
void walk_from(const Tree& g, int root, std::vector<long long>& dist,
               std::vector<int>& parent) {
    dist.assign(g.size(), 0);
    parent.assign(g.size(), 0);
    std::vector<int> pending{root};
    while (!pending.empty()) {
        const int u = pending.back();
        pending.pop_back();
        for (const Edge& e : g[u]) {
            if (e.to == parent[u]) continue;
            parent[e.to] = u;
            dist[e.to] = dist[u] + e.w;
            pending.push_back(e.to);
        }
    }
}

// Smallest vertex number among those with the largest dist.
int farthest(const std::vector<long long>& dist) {
    return static_cast<int>(std::max_element(dist.begin() + 1, dist.end()) -
                            dist.begin());
}

// Double sweep: returns the diameter's vertices from one end (front) to the
// other (back). On return dist[] is measured from path.back(), so it is
// non-increasing along the returned path.
std::vector<int> find_diameter(const Tree& g, std::vector<long long>& dist) {
    std::vector<int> parent;
    walk_from(g, 1, dist, parent);
    const int start = farthest(dist);
    walk_from(g, start, dist, parent);
    std::vector<int> path;
    for (int v = farthest(dist); v != start; v = parent[v]) path.push_back(v);
    path.push_back(start);
    return path;
}

// Largest distance from any vertex to the sub-path path[lo..hi]. Every vertex
// off the sub-path hangs from exactly one of its vertices, so each is visited
// once: O(n).
long long eccentricity(const Tree& g, const std::vector<int>& path, int lo,
                       int hi) {
    std::vector<char> on_path(g.size(), 0);
    for (int k = lo; k <= hi; ++k) on_path[path[k]] = 1;

    struct Frame {
        int v;
        int from;
        long long d;
    };
    std::vector<Frame> pending;
    long long worst = 0;
    for (int k = lo; k <= hi; ++k) {
        pending.push_back({path[k], 0, 0});
        while (!pending.empty()) {
            const Frame f = pending.back();
            pending.pop_back();
            worst = std::max(worst, f.d);
            for (const Edge& e : g[f.v]) {
                if (e.to == f.from || on_path[e.to]) continue;
                pending.push_back({e.to, f.v, f.d + e.w});
            }
        }
    }
    return worst;
}

}  // namespace

int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int n = 0;
    long long s = 0;
    std::cin >> n >> s;
    if (n < 1) {  // empty tree: nothing to cover
        std::cout << 0;
        return 0;
    }
    const Tree g = read_tree(n);
    std::vector<long long> dist;
    const std::vector<int> path = find_diameter(g, dist);
    const int last = static_cast<int>(path.size()) - 1;

    long long best = LLONG_MAX;
    int right = 0;
    for (int i = 0; i <= last; ++i) {
        // The farthest admissible right end never moves backwards as i grows
        // (a valid path[i-1..right] implies a valid path[i..right]), so it is
        // extended incrementally instead of being recomputed from i.
        right = std::max(right, i);
        while (right < last && dist[path[i]] - dist[path[right + 1]] <= s) {
            ++right;
        }
        best = std::min(best, eccentricity(g, path, i, right));
    }
    std::cout << best;
    return 0;
}
\(O(n\log \mathrm{sum})\)，其中 \(\mathrm{sum}\) 为所有边权之和（二分答案）:
// O(n log D) solution, D = diameter length (D <= sum of edge weights).
// Input: n s, then n-1 lines "u v w" describing a weighted tree on vertices
// 1..n (edge weights assumed non-negative). Output: the minimum, over all
// sub-paths of one fixed diameter whose length is <= s, of the largest
// distance from any vertex to that sub-path.
// Method: binary search the answer; each candidate limit is checked in O(n).
#include <algorithm>
#include <iostream>
#include <vector>

namespace {

// One direction of an undirected weighted edge.
struct Edge {
    int to;
    long long w;
};

// Adjacency lists indexed by vertex number (1..n); index 0 is unused.
using Tree = std::vector<std::vector<Edge>>;

// Reads n-1 edges "u v w" from stdin into an n-vertex tree.
Tree read_tree(int n) {
    Tree g(n + 1);
    for (int i = 1; i < n; ++i) {
        int u = 0;
        int v = 0;
        long long w = 0;
        std::cin >> u >> v >> w;
        g[u].push_back({v, w});
        g[v].push_back({u, w});
    }
    return g;
}

// Traverses the tree from `root`, filling dist[v] (weighted distance to root)
// and parent[v] (next vertex towards root; 0 for root itself). Iterative so a
// long path-shaped tree cannot overflow the call stack.
void walk_from(const Tree& g, int root, std::vector<long long>& dist,
               std::vector<int>& parent) {
    dist.assign(g.size(), 0);
    parent.assign(g.size(), 0);
    std::vector<int> pending{root};
    while (!pending.empty()) {
        const int u = pending.back();
        pending.pop_back();
        for (const Edge& e : g[u]) {
            if (e.to == parent[u]) continue;
            parent[e.to] = u;
            dist[e.to] = dist[u] + e.w;
            pending.push_back(e.to);
        }
    }
}

// Smallest vertex number among those with the largest dist.
int farthest(const std::vector<long long>& dist) {
    return static_cast<int>(std::max_element(dist.begin() + 1, dist.end()) -
                            dist.begin());
}

// Double sweep: returns the diameter's vertices from one end (front) to the
// other (back). On return dist[] is measured from path.back(), so it is
// non-increasing along the returned path.
std::vector<int> find_diameter(const Tree& g, std::vector<long long>& dist) {
    std::vector<int> parent;
    walk_from(g, 1, dist, parent);
    const int start = farthest(dist);
    walk_from(g, start, dist, parent);
    std::vector<int> path;
    for (int v = farthest(dist); v != start; v = parent[v]) path.push_back(v);
    path.push_back(start);
    return path;
}

// Largest distance from any vertex to the sub-path path[lo..hi]. Every vertex
// off the sub-path hangs from exactly one of its vertices, so each is visited
// once: O(n).
long long eccentricity(const Tree& g, const std::vector<int>& path, int lo,
                       int hi) {
    std::vector<char> on_path(g.size(), 0);
    for (int k = lo; k <= hi; ++k) on_path[path[k]] = 1;

    struct Frame {
        int v;
        int from;
        long long d;
    };
    std::vector<Frame> pending;
    long long worst = 0;
    for (int k = lo; k <= hi; ++k) {
        pending.push_back({path[k], 0, 0});
        while (!pending.empty()) {
            const Frame f = pending.back();
            pending.pop_back();
            worst = std::max(worst, f.d);
            for (const Edge& e : g[f.v]) {
                if (e.to == f.from || on_path[e.to]) continue;
                pending.push_back({e.to, f.v, f.d + e.w});
            }
        }
    }
    return worst;
}

// Is there a sub-path of length <= s whose eccentricity is <= limit?
// Both diameter ends must lie within `limit` of the core, which pins the
// shortest candidate core to path[left..right]. Vertices beyond those ends
// are automatically within `limit` (a branch off a diameter vertex is never
// deeper than that vertex's distance to either diameter end), so checking
// the eccentricity of path[left..right] is exact.
bool feasible(const Tree& g, const std::vector<int>& path,
              const std::vector<long long>& dist, long long s,
              long long limit) {
    const int last = static_cast<int>(path.size()) - 1;
    // Farthest vertex from path.front() that is still within `limit` of it.
    int left = 0;
    while (left < last && dist[path[0]] - dist[path[left + 1]] <= limit) ++left;
    // Farthest vertex from path.back() still within `limit` of it
    // (dist[path.back()] == 0).
    int right = last;
    while (right > 0 && dist[path[right - 1]] <= limit) --right;
    // The windows meet or cross: the single vertex path[right] is within
    // `limit` of both ends. Comparing indices (not distances) stays correct
    // even when zero-weight edges make distinct vertices equidistant.
    if (left >= right) return true;
    if (dist[path[left]] - dist[path[right]] > s) return false;
    return eccentricity(g, path, left, right) <= limit;
}

}  // namespace

int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int n = 0;
    long long s = 0;
    std::cin >> n >> s;
    if (n < 1) {  // empty tree: nothing to cover
        std::cout << 0;
        return 0;
    }
    const Tree g = read_tree(n);
    std::vector<long long> dist;
    const std::vector<int> path = find_diameter(g, dist);

    // Smallest feasible limit. The diameter length is always feasible
    // (left and right meet), so it is a valid upper bound.
    long long lo = 0;
    long long hi = dist[path.front()];
    while (lo <= hi) {
        const long long mid = lo + (hi - lo) / 2;
        if (feasible(g, path, dist, s, mid)) {
            hi = mid - 1;
        } else {
            lo = mid + 1;
        }
    }
    std::cout << lo;
    return 0;
}
\(O(n)\):
// O(n) solution.
// Input: n s, then n-1 lines "u v w" describing a weighted tree on vertices
// 1..n (edge weights assumed non-negative). Output: the minimum, over all
// sub-paths of one fixed diameter whose length is <= s, of the largest
// distance from any vertex to that sub-path.
// Method: for a core path[left..right] the eccentricity is
//   max(dist(front, path[left]), dist(path[right], back), deepest branch),
// and the deepest branch may be taken over the WHOLE diameter, because a
// branch off a vertex outside the core is never deeper than that vertex's
// distance to the nearer diameter end. Two pointers then scan all cores.
#include <algorithm>
#include <climits>
#include <iostream>
#include <vector>

namespace {

// One direction of an undirected weighted edge.
struct Edge {
    int to;
    long long w;
};

// Adjacency lists indexed by vertex number (1..n); index 0 is unused.
using Tree = std::vector<std::vector<Edge>>;

// Reads n-1 edges "u v w" from stdin into an n-vertex tree.
Tree read_tree(int n) {
    Tree g(n + 1);
    for (int i = 1; i < n; ++i) {
        int u = 0;
        int v = 0;
        long long w = 0;
        std::cin >> u >> v >> w;
        g[u].push_back({v, w});
        g[v].push_back({u, w});
    }
    return g;
}

// Traverses the tree from `root`, filling dist[v] (weighted distance to root)
// and parent[v] (next vertex towards root; 0 for root itself). Iterative so a
// long path-shaped tree cannot overflow the call stack.
void walk_from(const Tree& g, int root, std::vector<long long>& dist,
               std::vector<int>& parent) {
    dist.assign(g.size(), 0);
    parent.assign(g.size(), 0);
    std::vector<int> pending{root};
    while (!pending.empty()) {
        const int u = pending.back();
        pending.pop_back();
        for (const Edge& e : g[u]) {
            if (e.to == parent[u]) continue;
            parent[e.to] = u;
            dist[e.to] = dist[u] + e.w;
            pending.push_back(e.to);
        }
    }
}

// Smallest vertex number among those with the largest dist.
int farthest(const std::vector<long long>& dist) {
    return static_cast<int>(std::max_element(dist.begin() + 1, dist.end()) -
                            dist.begin());
}

// Double sweep: returns the diameter's vertices from one end (front) to the
// other (back). On return dist[] is measured from path.back(), so it is
// non-increasing along the returned path.
std::vector<int> find_diameter(const Tree& g, std::vector<long long>& dist) {
    std::vector<int> parent;
    walk_from(g, 1, dist, parent);
    const int start = farthest(dist);
    walk_from(g, start, dist, parent);
    std::vector<int> path;
    for (int v = farthest(dist); v != start; v = parent[v]) path.push_back(v);
    path.push_back(start);
    return path;
}

// Largest distance from any vertex to the sub-path path[lo..hi]. Every vertex
// off the sub-path hangs from exactly one of its vertices, so each is visited
// once: O(n).
long long eccentricity(const Tree& g, const std::vector<int>& path, int lo,
                       int hi) {
    std::vector<char> on_path(g.size(), 0);
    for (int k = lo; k <= hi; ++k) on_path[path[k]] = 1;

    struct Frame {
        int v;
        int from;
        long long d;
    };
    std::vector<Frame> pending;
    long long worst = 0;
    for (int k = lo; k <= hi; ++k) {
        pending.push_back({path[k], 0, 0});
        while (!pending.empty()) {
            const Frame f = pending.back();
            pending.pop_back();
            worst = std::max(worst, f.d);
            for (const Edge& e : g[f.v]) {
                if (e.to == f.from || on_path[e.to]) continue;
                pending.push_back({e.to, f.v, f.d + e.w});
            }
        }
    }
    return worst;
}

}  // namespace

int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    int n = 0;
    long long s = 0;
    std::cin >> n >> s;
    if (n < 1) {  // empty tree: nothing to cover
        std::cout << 0;
        return 0;
    }
    const Tree g = read_tree(n);
    std::vector<long long> dist;
    const std::vector<int> path = find_diameter(g, dist);
    const int last = static_cast<int>(path.size()) - 1;

    const long long total = dist[path.front()];  // diameter length
    // Deepest branch hanging off any diameter vertex (distance to the whole
    // diameter).
    const long long branch = eccentricity(g, path, 0, last);

    long long best = LLONG_MAX;
    for (int right = 0, left = 0; right <= last; ++right) {
        // Shrink from the left until path[left..right] is no longer than s.
        while (left < right && dist[path[left]] - dist[path[right]] > s) ++left;
        // total - dist[path[left]] = distance from path.front() to path[left];
        // dist[path[right]]        = distance from path[right] to path.back().
        best = std::min(best, std::max({branch, total - dist[path[left]],
                                        dist[path[right]]}));
    }
    std::cout << best;
    return 0;
}
还有一种基于单调队列的 \(O(n)\) 做法，我没有写代码，这里推荐 Sparky_14145 的题解。