【题目描述】:
给定一棵无根树,边权都是1,请去掉一条边并加上一条新边,定义直径为最远的两个点的距离,请输出所有可能的新树的直径的最小值和最大值
【题目解法】:
poi2015最后一坑终于被我填上啦!
为啥波兰人那么喜欢方案...
考虑枚举删掉的边
那么最大值显然就是两颗树的直径拼起来
最小值是两棵树的直径中点拼起来
证明:显然对于一棵树,它能贡献的最大值的最小值一定是直径的一半,如果他向下扩展最大值都不到直径的一半,那么我们找到两条最深的链,把他们拼起来就是直径,然后长度是小于直径的,因此矛盾,所以中点拼起来已经是最优的
显然考虑求出对于每一条树边,断掉他后上半部分和下半部分的直径
首先定义三元组(x,y,from,val)
如果对应的是直径,那么x y对应直径的端点,from对应来自哪个儿子,val对应长度,
如果对应的是链,那么from对应儿子编号,y对应链的末端编号,val代表长度
考虑我们从下往上dp,统计出:
1.往下走长度前三大的链
2.儿子子树中长度前两大的直径
3.这个点所在子树对应的直径
然后我们从上往下dp,统计出:
1.连着某个点向上走的最长长度,这个可以用fa节点的信息更新
2.某个点断开他与fa的连边后,上半部分的直径,我们分情况考虑
1.不经过fa,显然可以有断开fa和爷爷边后上半部分的直径,以及fa的儿子(不包括该儿子)中的直径更新
2.经过fa,可以有两个儿子中的深度拼起来,或者最大的儿子和该点连着向上的长度更新
最后扫一遍得到答案后,O(N)扫一遍链的中点
复杂度O(N)
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <algorithm>
#include <cmath>
using namespace std;
const int N=500005;
const int M=1000005;
const int INF=32083208;
struct edge {int x,next;} b[M];
struct node {int x,y,val,from;};
int n,tot,a[N],fa[N],dep[N];
inline long long getint()
{
long long x=0; char c=getchar(); bool flag=false;
while ((c!='-')&&((c<'0')||(c>'9'))) c=getchar();
if (c=='-') flag=true,c=getchar();
while ((c>='0')&&(c<='9')) x=x*10+(long long)(c-'0'),c=getchar();
if (flag) return -x; else return x;
}
inline void addedge(int x,int y)
{
++tot;
b[tot].x=y;
b[tot].next=a[x];
a[x]=tot;
}
void init()
{
n=getint();
for (int i=1; i<n; i++)
{
int x=getint(),y=getint();
addedge(x,y); addedge(y,x);
}
}
namespace down
{
struct Node
{
node a1,a2,a3; node b1,b2,ans;
inline void init(int x)
{
a1.x=a1.y=a2.x=a2.y=a3.x=a3.y=x;
b1.x=b1.y=b2.x=b2.y=x; ans.x=ans.y=x;
a1.from=a2.from=a3.from=x;
b1.from=b2.from=ans.from=x;
}
} f[N];
inline void updata_Path(Node &a,node b)
{
if (b.val>=a.b1.val) a.b2=a.b1,a.b1=b;
else if (b.val>=a.b2.val) a.b2=b;
}
inline void updata_Chain(Node &a,node b)
{
if (b.val>=a.a1.val) a.a3=a.a2,a.a2=a.a1,a.a1=b;
else if (b.val>=a.a2.val) a.a3=a.a2,a.a2=b;
else if (b.val>=a.a3.val) a.a3=b;
}
inline void updata(node &a,node b)
{
if (b.val<=a.val) return; a=b;
}
void dfs(int x)
{
f[x].init(x);
for (int p=a[x];p;p=b[p].next)
{
int pp=b[p].x; if (pp==fa[x]) continue; fa[pp]=x; dep[pp]=dep[x]+1; dfs(pp);
updata_Chain(f[x],(node){pp,f[pp].a1.y,f[pp].a1.val+1,pp});
updata_Path(f[x],(node){f[pp].ans.x,f[pp].ans.y,f[pp].ans.val,pp});
}
updata(f[x].ans,f[x].b1);
updata(f[x].ans,(node){f[x].a1.y,f[x].a2.y,f[x].a1.val+f[x].a2.val,0});
}
void solve()
{
dep[1]=1; dfs(1);
}
}
namespace up
{
struct Node
{
node a,b;
inline void init(int x)
{
a.x=a.y=x; a.val=0;
b.x=b.y=x; b.val=0;
}
} f[N];
inline void updata(node &a,node b)
{
if (b.val<a.val) return; a=b;
}
inline void updata(Node &a,node b)
{
if (b.val>=a.a.val) a.b=a.a,a.a=b;
else if (b.val>=a.b.val) a.b=b;
}
void dfs(int x)
{
f[x].init(x); int fa1=fa[x];
if (fa1!=0)
{
updata(f[x].a,(node){fa1,f[fa1].a.y,f[fa1].a.val+1,0});
if (down::f[fa1].a1.from!=x) updata(f[x].a,(node){fa1,down::f[fa1].a1.y,down::f[fa1].a1.val+1,0});
else updata(f[x].a,(node){fa1,down::f[fa1].a2.y,down::f[fa1].a2.val+1,0});
updata(f[x].b,f[fa1].b);
if (down::f[fa1].b1.from!=x) updata(f[x].b,down::f[fa1].b1); else updata(f[x].b,down::f[fa1].b2);
Node tmp; tmp.init(0); updata(tmp,f[fa1].a);
if (down::f[fa1].a1.from!=x) updata(tmp,down::f[fa1].a1);
if (down::f[fa1].a2.from!=x) updata(tmp,down::f[fa1].a2);
if (down::f[fa1].a3.from!=x) updata(tmp,down::f[fa1].a3);
updata(f[x].b,(node){tmp.a.y,tmp.b.y,tmp.a.val+tmp.b.val,0});
}
for (int p=a[x];p;p=b[p].next) if (b[p].x!=fa[x]) dfs(b[p].x);
}
void solve()
{
dfs(1);
}
}
namespace GetMin
{
int a[N];
int LCA(int x,int y)
{
for (;x!=y;x=fa[x]) if (dep[x]<dep[y]) swap(x,y);
return x;
}
int getmid(int x,int y)
{
int k=LCA(x,y),tot=dep[x]+dep[y]-2*dep[k]+1;
int first=1,last=tot;
for (;x!=k;x=fa[x]) a[first++]=x;
for (;y!=k;y=fa[y]) a[last--]=y;
a[first]=k; return a[(tot+1)/2];
}
void solve()
{
node ans,cut; ans.val=INF; int pos;
for (int i=2; i<=n; i++)
{
node now;
now.x=up::f[i].b.x; now.y=down::f[i].ans.x;
now.val=(up::f[i].b.val+1)/2+(down::f[i].ans.val+1)/2+1;
now.val=max(now.val,max(up::f[i].b.val,down::f[i].ans.val));
if (now.val<ans.val)
{
ans.val=now.val; pos=i;
ans.x=up::f[i].b.x; ans.y=up::f[i].b.y;
cut.x=down::f[i].ans.x; cut.y=down::f[i].ans.y;
}
}
int x=getmid(ans.x,ans.y),y=getmid(cut.x,cut.y);
printf("%d %d %d %d %d\n",ans.val,pos,fa[pos],x,y);
}
}
namespace GetMax
{
void solve()
{
node ans,cut; ans.val=0;
for (int i=2; i<=n; i++)
{
node now;
now.x=up::f[i].b.x; now.y=down::f[i].ans.x;
now.val=up::f[i].b.val+down::f[i].ans.val+1;
if (now.val>ans.val) ans=now,cut.x=i,cut.y=fa[i];
}
printf("%d %d %d %d %d\n",ans.val,cut.x,cut.y,ans.x,ans.y);
}
}
int main()
{
init();
down::solve();
up::solve();
GetMin::solve();
GetMax::solve();
return 0;
}