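// LCT直接上 — solve it directly with a link-cut tree (LCT).
// 记录子树的大小 — each splay node records the size of its subtree.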
// Link-cut tree (LCT) solution. Position i jumps to Next[i] = min(i + k_i, n + 1),
// where node n + 1 is a virtual sink meaning "flown off the end". Every i is linked
// to Next[i] (always > i), so the jumps form a tree rooted at n + 1 and the number of
// jumps starting at i equals the depth of i in that tree.
// NOTE(review): the I/O file names (2002.in/2002.out) suggest BZOJ 2002 — confirm.
#include<cstdio>
#include<algorithm>
#include<vector>
using namespace std;

int n,m,Next[200005];   // Next[i]: where position i jumps to (n + 1 = off the end)

// One splay-tree node of the link-cut tree; root[i] is the node for position i.
struct Node{
    int siz,rev;        // siz: node count of this splay subtree; rev: pending child-swap tag
    Node *ch[2],*fa;    // ch: splay children; fa: splay parent, or path-parent when a splay root
    Node(Node *fat);
    void Pushup();
    void Pushdown();
}*root[200005],*Null;   // Null: shared sentinel with siz 0 that stands in for "no node"

// Create a single-node splay tree whose parent pointer is `fat`.
Node::Node(Node *fat){
    ch[0]=ch[1]=Null;
    fa=fat;
    siz=1;
    rev=0;
}

// Recompute siz from the children (the Null sentinel contributes nothing).
void Node::Pushup(){
    siz=1;
    if(ch[0]!=Null)siz+=ch[0]->siz;
    if(ch[1]!=Null)siz+=ch[1]->siz;
}

// Apply a pending reversal: swap this node's children and pass the tag down to them.
void Node::Pushdown(){
    if(rev){
        if(ch[0]!=Null)ch[0]->rev^=1;
        if(ch[1]!=Null)ch[1]->rev^=1;
        swap(ch[0],ch[1]);
        rev=0;
    }
}

// Nonzero when x is not the root of its splay tree, i.e. x->fa is a real splay
// parent rather than a path-parent pointer into another splay tree.
int Notroot(Node *x){
    return (x->fa->ch[0]==x)||(x->fa->ch[1]==x);
}

// Push pending reversal tags down from the root of x's splay tree to x, top-down,
// so every node on that path has correctly oriented children before rotating.
// Iterative on purpose: a splay path can be as long as the number of nodes, and
// one recursive call per ancestor could overflow the call stack.
void Prepare(Node *x){
    static vector<Node*> path;
    path.clear();
    path.push_back(x);
    while(Notroot(x)){
        x=x->fa;
        path.push_back(x);
    }
    for(size_t i=path.size();i-->0;){
        path[i]->Pushdown();   // topmost ancestor first, x last
    }
}

// Rotate x above its parent y.
// kind == 1: x is y's left child (right rotation); kind == 0: x is y's right child.
void Rotate(Node *x,int kind){
    Node *y=x->fa,*z=y->fa;
    y->ch[!kind]=x->ch[kind];
    if(x->ch[kind]!=Null)x->ch[kind]->fa=y;
    x->fa=z;
    if(Notroot(y))z->ch[z->ch[1]==y]=x;   // only relink z's child if y was not a splay root
    y->fa=x;
    x->ch[kind]=y;
    y->Pushup();
    x->Pushup();
}

// Splay x to the root of its splay tree (zig / zig-zig / zig-zag cases).
void Splay(Node *x){
    Prepare(x);
    while(Notroot(x)){
        Node *y=x->fa,*z=y->fa;
        if(!Notroot(y)){
            Rotate(x,y->ch[0]==x);                          // zig
        }
        else if(y->ch[1]==x && z->ch[1]==y){Rotate(y,0);Rotate(x,0);}   // zig-zig (right)
        else if(y->ch[1]==x && z->ch[0]==y){Rotate(x,0);Rotate(x,1);}   // zig-zag
        else if(y->ch[0]==x && z->ch[0]==y){Rotate(y,1);Rotate(x,1);}   // zig-zig (left)
        else {Rotate(x,1);Rotate(x,0);}                                 // zig-zag
    }
}

// Make the path from the represented root down to x a single preferred path,
// so x and all its ancestors end up in one splay tree.
void Access(Node *x){
    for(Node *y=Null;x!=Null;y=x,x=x->fa){
        Splay(x);
        x->ch[1]=y;
        x->Pushup();
    }
}

// Make x the root of its represented tree by reversing the root-to-x path.
void Makeroot(Node *x){
    Access(x);
    Splay(x);
    x->rev^=1;
}

// Hang x's tree under y. Assumes x and y are currently in different trees.
void Link(Node *x,Node *y){
    Makeroot(x);
    x->fa=y;
}

// Remove the edge between x and y. Assumes they are adjacent (as Next[x] and x
// always are in Change); if y's left child is not x, nothing is cut.
void Cut(Node *x,Node *y){
    Makeroot(x);
    Access(y);
    Splay(y);
    if(y->ch[0]==x){
        y->ch[0]=Null;
        x->fa=Null;
        y->Pushup();   // y lost its left subtree; refresh siz so it never goes stale
    }
}

// Return the root of x's represented tree (the leftmost node of x's splay tree
// after Access).
Node *Find(Node *x){
    Access(x);
    Splay(x);
    x->Pushdown();
    while(x->ch[0]!=Null){
        x=x->ch[0];
        x->Pushdown();   // resolve pending reversals before reading ch[0] again
    }
    Splay(x);            // splay the found root to keep the amortized bounds
    return x;
}

// Map a position index to its LCT node.
Node *ToLct(int x){
    return root[x];
}

// Build the shared sentinel: it is its own parent and child, and has siz 0.
Node *GetNull(){
    Node *u=new Node(Null);
    u->fa=u->ch[0]=u->ch[1]=u;
    u->siz=0;
    u->rev=0;
    return u;
}

// Destination of a jump of length k from position i, clamped to the sink n + 1.
// Same result as min(n + 1, i + k), but without computing i + k when k is huge
// (which could overflow a signed int).
int JumpTarget(int i,int k){
    return k>=n+1-i ? n+1 : i+k;
}

// Number of jumps from position x until it reaches the sink n + 1: after rooting
// at n + 1 and accessing x, x's left splay subtree holds exactly its proper
// ancestors, and each ancestor edge is one jump.
int Jump(int x){
    Node *u=ToLct(x);
    Makeroot(ToLct(n+1));
    Access(u);
    Splay(u);
    return u->ch[0]->siz;
}

// Set position x's jump length to y: detach x from its old target and link it
// to the new one.
void Change(int x,int y){
    int t=JumpTarget(x,y);
    Cut(ToLct(x),ToLct(Next[x]));
    Link(ToLct(x),ToLct(t));
    Next[x]=t;
}

int main(){
    // NOTE(review): redirects stdin/stdout to local files; drop these two lines
    // if the judge supplies input on stdin.
    freopen("2002.in","r",stdin);
    freopen("2002.out","w",stdout);
    Null=GetNull();
    scanf("%d",&n);
    root[n+1]=new Node(Null);   // sink node: "flown off the end"
    for(int i=1;i<=n;i++){
        int tp;
        scanf("%d",&tp);
        root[i]=new Node(Null);
        Next[i]=JumpTarget(i,tp);
    }
    for(int i=1;i<=n;i++){
        Link(ToLct(i),ToLct(Next[i]));
    }
    scanf("%d",&m);
    for(int i=1;i<=m;i++){
        int opt,x,y;
        scanf("%d",&opt);
        // Input positions are 0-based; nodes are 1-based, hence x + 1.
        if(opt==1){scanf("%d",&x);printf("%d\n",Jump(x+1));}
        if(opt==2){scanf("%d %d",&x,&y);Change(x+1,y);}
    }
    return 0;
}