[树链剖分][SDOI2011]染色

发布于 2018-09-14  450 次阅读


题目大意

给定一棵有n个节点的无根树和m个操作,操作有2类:
1、将节点a到节点b路径上所有点都染成颜色c;
2、询问节点a到节点b路径上的颜色段数量(连续相同颜色被认为是同一段),如“112221”由3段组成:“11”、“222”和“1”。
请你写一个程序依次完成这m个操作。

我太菜了,写数据结构都会re。
一道裸的树链剖分,考点不是树链剖分,大概是线段树合并。

## 题解:

首先一眼看出是树剖裸题,没有任何奇技淫巧。
发现线段树合并需要考虑。
考虑合并两个区间会发生的事情,
设区间1[l1,r1],区间2[l2,r2];
如果合并时r1和l2的颜色一样,显然颜色数要-1.
否则直接合并。
所以要记录的值就知道了
区间右端点颜色right_color,区间左端点颜色left_color,
区间颜色段数val.
发现会有这种情况(又要放图片,没内存了)!

假装红色的是重链,其他的是轻链,有一个询问问你紫点到蓝点的颜色数,根据树剖的写法,我们会先算两个轻链在算重链,这时我们就需要记轻链和重链的连接处的颜色,判断答案是否-1,我们可以记录每个点的左右儿子颜色,开一个结构体,类似动态开点的写法,但我刚开始没考虑到,写了一半才发现,就只好$log(n)$来找颜色了。

写完后和kingsann讨论(就是友链里的),发现他暴力判lca什么的,详细可以私信他。

#include <cstdio>
#include <algorithm>

#define ls (p<<1)
#define rs (p<<1|1)
using namespace std;
const int N=1e6+5;

inline int read() { char k=0;char sb;sb=getchar(); for(;sb<'0'||sb>'9';k=sb,sb=getchar()); int x=0;for(;sb>='0'&&sb<='9';sb=getchar())x=x*10+sb-'0'; if(k=='-')x=0-x;return x; }
struct data{int v,nxt;}edge[N<<1];
int alist[N],cnt;
inline void add(int u,int v){edge[++cnt]=(data){v,alist[u]},alist[u]=cnt;}
int col[N];
int n,m;

int siz[N],dep[N],fa[N],son[N],top[N],dfn[N],num[N],tot;
inline void dfs1(int x)
{
    siz[x]=1;dep[x]=dep[fa[x]]+1;
    for(int i=alist[x];i;i=edge[i].nxt)
    {
        int v=edge[i].v;
        if(v==fa[x]) continue;
        fa[v]=x;dfs1(v);siz[x]+=siz[v];
        if(son[x]==-1||siz[v]>siz[son[x]]) son[x]=v;
    }
}
inline void dfs2(int x)
{
    dfn[x]=++tot;num[tot]=x;
    if(!top[x]) top[x]=x;
    if(son[x]) top[son[x]]=top[x],dfs2(son[x]);
    for(int i=alist[x];i;i=edge[i].nxt)
    {
        int v=edge[i].v;
        if(v==son[x]||v==fa[x]) continue;
        dfs2(v);
    }
}

int color_right[N<<2],color_left[N<<2],tag[N<<2],val[N<<2];
int las_rcol,las_lcol,las_col;
inline void pushdown(int p)
{
    if(tag[p]==-1) return;
    tag[ls]=tag[rs]=tag[p];
    val[ls]=val[rs]=1;
    color_right[ls]=color_left[ls]=color_right[rs]=color_left[rs]=tag[p];
    tag[p]=-1;

}
inline void updata(int p)
{
    color_right[p]=color_right[rs];color_left[p]=color_left[ls];
    val[p]=val[ls]+val[rs]-(color_right[ls]==color_left[rs]);
}
inline void build(int p,int l,int r)
{
    tag[p]=-1;
    if(l==r)
    {
        val[p]=1;
        color_right[p]=color_left[p]=col[num[l]];
        return;
    }
    int mid=(l+r)>>1;
    build(ls,l,mid);build(rs,mid+1,r);
    updata(p);
}
inline int query(int p,int l,int r,int dl,int dr)
{
    if(dl<=l&&dr>=r)
    {
        return val[p];
    }
    pushdown(p);
    int mid=(l+r)>>1;
    int res=0;
    if(dl<=mid) res+=query(ls,l,mid,dl,dr);
    if(dr>mid) res+=query(rs,mid+1,r,dl,dr);
    if(dl<=mid&&dr>mid&&color_right[ls]==color_left[rs]) res--;
    updata(p);
    return res;
}
inline void setchange(int p,int l,int r,int dl,int dr,int z)
{
    if(dl<=l&&dr>=r)
    {
        tag[p]=color_right[p]=color_left[p]=z;
        val[p]=1;
        return ;
    }
    pushdown(p);
    int mid=(l+r)>>1;
    if(dl<=mid) setchange(ls,l,mid,dl,dr,z);
    if(dr>mid) setchange(rs,mid+1,r,dl,dr,z);
    updata(p);
}
inline void be()
{
    dfs1(1),dfs2(1);build(1,1,n);
}
inline int find(int p,int l,int r,int pos)
{
    if(l==r) return color_left[p];
    pushdown(p);
    int mid=(l+r)>>1;
    if(pos<=mid) return find(ls,l,mid,pos);
    return find(rs,mid+1,r,pos);
}
inline int ask(int x,int y)
{
    int res=0;
    while(top[x]!=top[y])
    {
        if(dep[top[x]]<dep[top[y]]) swap(x,y);
        res+=query(1,1,n,dfn[top[x]],dfn[x]);
        int c1=find(1,1,n,dfn[top[x]]);
        int c2=find(1,1,n,dfn[fa[top[x]]]);
        if(c1==c2) res--;
        x=fa[top[x]];
    }
    las_col=0;
    if(dep[y]<dep[x]) swap(x,y);
    res+=query(1,1,n,dfn[x],dfn[y]);
    return res;
}
inline void change(int x,int y,int z)
{
    while(top[x]!=top[y])
    {
        if(dep[top[x]]<dep[top[y]]) swap(x,y);
        setchange(1,1,n,dfn[top[x]],dfn[x],z);
        x=fa[top[x]];
    }
    if(dep[x]>dep[y]) swap(x,y);
    setchange(1,1,n,dfn[x],dfn[y],z);
}
signed main()
{
//  freopen("2.in","r",stdin);
//  freopen("1.out","w",stdout);
    n=read();m=read();
    for(int i=1;i<=n;++i) col[i]=read();
    for(int i=1,x,y;i<n;++i)
    {
        x=read();y=read();
        add(x,y),add(y,x);
    }
    be();
    for(int i=1;i<=m;++i)
    {
        char c[10];int x,y,z;
        scanf("%s",&c);
        if(c[0]=='Q')
        {
            x=read();y=read();
            printf("%d\n",ask(x,y));
        }
        else 
        {
            x=read();y=read();z=read();
            change(x,y,z);
        }
    }
}
0.00 avg. rating (0% score) - 0 votes

一沙一世界,一花一天堂。君掌盛无边,刹那成永恒。