Alioth_ 的博客

Alioth_ 的博客

[转载]重链剖分&长链剖分

posted on 2019-06-28 21:36:05 | under 知识点 |

这里

长链剖分就是把重儿子换成了长儿子 即最深的点 求的过程如下

void dfs1(int x)
    {
        for(int i=head[x];i;i=e[i].nxt)
        {
            int y=e[i].to;
            if(y==fa[x])continue;
            fa[y]=x;
            maxdep[y]=dep[y]=dep[x]+1;
            dfs1(y);
            maxdep[x]=max(maxdep[y],maxdep[x]);
            if(maxdep[y]>maxdep[son[x]])son[x]=y;
        }
    }

一些性质:

1、任意点的任意祖先所在长链长度一定大于等于这个点所在长链长度

2、所有长链长度之和就是总节点数

3、一个点到根的路径上经过的短边最多有$\sqrt{n}$条

然后长链剖分主要是优化$DP$的 可以把一维是深度的DP优化到$O(n)$

这种$DP$继承时一般都是移一位 加一个数($f[i][j]=f[son[i]][j+1]+val$)

过程类似于树上启发式合并 长儿子的答案$O(1)$继承,其他儿子的答案暴力合并到长儿子上 因为所有链长之和为$n$所以总复杂度为$O(n)$

主要有两种实现形式

一. 指针

主要适用于DP一个状态对应一个状态转移 即一个深度对应另一个深度

具体就是把$f[i][j]$($i$是节点 $j$是深度)换成一个指针数组$f$继承时直接让长儿子的指针指向自己的对应位置,$O(1)$的实现合并 然后对于其他儿子每人分配一段内存空间并递归下去,再合并上来

由于总长度为$n$所以$f$数组的大小为$O(n)$级别

具体看代码 CF1009F Dominant Indices

#include<bits/stdc++.h>
using namespace std;
const int maxn=1001000;

namespace Long_Chain_Partitioning{
    struct edge{
        int to,nxt;
    }e[maxn<<1];
    int head[maxn],tot;
    void add(int x,int y){
        e[++tot].to=y;
        e[tot].nxt=head[x];
        head[x]=tot;
    }
    int fa[maxn],dep[maxn],son[maxn],maxdep[maxn];
    void dfs1(int x)
    {
        for(int i=head[x];i;i=e[i].nxt)
        {
            int y=e[i].to;
            if(y==fa[x])continue;
            maxdep[y]=dep[y]=dep[x]+1;fa[y]=x;
            dfs1(y);
            maxdep[x]=max(maxdep[x],maxdep[y]);
            if(maxdep[y]>maxdep[son[x]])son[x]=y;
        }
    }
}
using namespace std;
using namespace Long_Chain_Partitioning;

int *f[maxn],ans[maxn],tmp[maxn],*id=tmp,n;

void dfs(int x)
{
    f[x][0]=1;//下面f[son[x]]=f[x]+1使f[x]继承时直接用指针将儿子的数组后移一位
    if(son[x])f[son[x]]=f[x]+1,dfs(son[x]),ans[x]=ans[son[x]]+1;//直接继承重儿子的答案
    for(int i=head[x];i;i=e[i].nxt) 
    {
        int y=e[i].to;
        if(y==fa[x]||y==son[x])continue;
        f[y]=id;//分配子树的内存 
        id+=maxdep[y]-dep[y]+1;//所有长链的长度之和为n 故为O(n) 
        dfs(y);
        for(int j=0;j<=maxdep[y]-dep[y];j++)//遍历给y分配的一段内存 暴力合并轻儿子 
        {
            f[x][j+1]+=f[y][j];//通过DP更新 
            if(f[x][j+1]>f[x][ans[x]]||(f[x][j+1]==f[x][ans[x]]&&j+1<ans[x]))
                ans[x]=j+1;
        }   
    }
    if(f[x][ans[x]]==1)ans[x]=0;//f[x][0]和ans同时满足条件且更小 
}

int main()
{
    scanf("%d",&n);
    for(int i=1;i<n;i++)
    {
        int x,y;
        scanf("%d%d",&x,&y);
        add(x,y);add(y,x);
    }
    dep[1]=1;
    dfs1(1);
    f[1]=id;id+=maxdep[1];//这里要注意分配第一段内存
    dfs(1);
    for(int i=1;i<=n;i++)cout<<ans[i]<<"\n";
}

二.长链剖分序与线段树

主要适用于一个深度对应一个区间的转移 即求某一深度区间的信息时

按照长链剖分序建树 每一条长链都是一段区间 然后建树

找答案时 因为所有子树的答案都已经合并到了长链上所以之后查询每个点时只要查询线段树上这个点往下的长链上的信息就是整棵子树中的信息了。

然后把轻儿子的答案合并到长链上即可

P4292 [WC2010]重建计划

// luogu-judger-enable-o2
#pragma GCC optimize(2)
#pragma GCC optimize(3)
#pragma GCC optimize(Ofast)
%:pragma GCC optimize("inline")
#include<bits/stdc++.h>
#define re register
using namespace std;
const int maxn=100100;
const double inf=1e15;
const double eps=1e-4;

namespace IO{
    const int BUF=65536;
    char buf[BUF+1];
    char *hed=buf,*tail=buf;
    inline char gc()
    {
        if(hed==tail)*(tail=(hed=buf)+fread(buf,1,BUF,stdin))=0;
        return *hed++;
    }
    template<typename T>
    void read(T &a)
    {
        a=0;int f=1;
        char ch=gc();
        while(ch<'0'||ch>'9'){
            if(ch=='-')f=-1;
            ch=gc();
        }
        while(ch>='0'&&ch<='9'){
            a=a*10+ch-'0';
            ch=gc();
        }
        a*=f;
    }
}

namespace Long_Chain_Partitioning{
    int tot,m,head[maxn],dfn[maxn],clk,son[maxn],dep[maxn],maxdep[maxn],fa[maxn];
    struct edge{
        int to,nxt;
        double dis,base;
    }e[maxn<<1];
    void add(int x,int y,double z)
    {
        e[++tot].to=y;
        e[tot].dis=z;
        e[tot].base=z;
        e[tot].nxt=head[x];
        head[x]=tot;
        m=tot;
    }
    void dfs1(int x)
    {
        for(re int i=head[x];i;i=e[i].nxt)
        {
            int y=e[i].to;
            if(y==fa[x])continue;
            fa[y]=x;
            maxdep[y]=dep[y]=dep[x]+1;
            dfs1(y);
            maxdep[x]=max(maxdep[y],maxdep[x]);
            if(maxdep[y]>maxdep[son[x]])son[x]=y;
        }
    }
    void dfs2(int x) 
    {
        dfn[x]=++clk;
        if(son[x])dfs2(son[x]);
        for(re int i=head[x];i;i=e[i].nxt)
        {
            int y=e[i].to;
            if(y==fa[x]||y==son[x])continue;
            dfs2(y);
        }
    }
}

namespace Segment_tree{
    #define ls l,mid,rt<<1
    #define rs mid+1,r,rt<<1|1
    double maxv[maxn<<2];
    void pushup(int rt){
        maxv[rt]=max(maxv[rt<<1],maxv[rt<<1|1]);
    }
    void build(int l,int r,int rt)
    {
        maxv[rt]=-inf;
        if(l==r)return ;
        int mid=l+r>>1;
        build(ls);
        build(rs);
        pushup(rt);
    }
    void update(int pos,double val,int l,int r,int rt)
    {
        if(l==r){
            maxv[rt]=max(val,maxv[rt]);return ;
        }
        int mid=l+r>>1;
        if(pos<=mid)update(pos,val,ls);
        else update(pos,val,rs);
        pushup(rt);
    }
    double query1(int L,int R,int l,int r,int rt)//区间max 
    {
        if(L>R)return -inf;
        if(L<=l&&r<=R)return maxv[rt];
        double ret=-inf,mid=l+r>>1;
        if(L<=mid)ret=max(ret,query1(L,R,ls));
        if(R>mid)ret=max(ret,query1(L,R,rs));
        return ret;
    }
    double query2(int pos,int l,int r,int rt)
    {
        if(l==r)return maxv[rt];
        int mid=l+r>>1;
        if(pos<=mid)return query2(pos,ls);
        else return query2(pos,rs);
    }
}

using namespace Long_Chain_Partitioning;
using namespace Segment_tree;
using namespace IO;

int L,U,n;
double dis[maxn],ans,val[maxn];

void change(double mid){
    for(re int i=1;i<=m;++i)e[i].dis=e[i].base-mid;
    memset(dis,0,sizeof(dis));
    ans=-inf;
    build(1,n,1);
}

void dp(int x)
{
    update(dfn[x],dis[x],1,n,1);
    for(re int i=head[x];i;i=e[i].nxt){
        int y=e[i].to;
        if(y==son[x]){
            dis[y]=dis[x]+e[i].dis;
            dp(y);break;
        }
    }
    for(re int i=head[x];i;i=e[i].nxt)
    {
        int y=e[i].to;
        if(y==fa[x]||y==son[x])continue;
        dis[y]=dis[x]+e[i].dis;
        dp(y);
        for(re int j=0;j<=maxdep[y]-dep[y];++j)
        {
            val[j]=query2(dfn[y]+j,1,n,1);//因为子树中的答案都记录在长链上 而且重链上建有线段树 所以可以直接查询 
            if(j<U)ans=max(ans,query1(max(dfn[x]+L-j-1,dfn[x]),min(dfn[x]+U-j-1,dfn[x]+maxdep[x]-dep[x]),1,n,1)+val[j]-dis[x]*2);//区间查询更新答案 
        }
        for(re int j=0;j<=maxdep[y]-dep[y];++j)//子树中的信息合并到x的长链上
            update(dfn[x]+j+1,val[j],1,n,1); 
    }
    ans=max(ans,query1(dfn[x]+L,min(dfn[x]+maxdep[x]-dep[x],dfn[x]+U),1,n,1)-dis[x]);//不经过x折返 在下子树中的最长链 
}

int main()
{
    read(n);read(L);read(U);
    for(re int i=1;i<n;++i)
    {
        int x,y,z;
        read(x);read(y);read(z);
        add(x,y,z);add(y,x,z);
    }
    dfs1(1);
    dfs2(1);
    double l=0,r=2000000;
    while(l+eps<r)
    {
        double mid=(l+r)/2;
        change(mid);
        dp(1);
        if(ans<0)r=mid;
        else l=mid;
    }
    printf("%.3lf",l);
}

还有一道题 给一棵树 求树上距离在$l\rightarrow r$之间的连通块数量之和除以总数量的值 可以想到差分 每次求树上$\le lim$的连通块的数量

用$DP$求 $f[x][d]$表示以$x$为最高点 在$x$子树中最大深度为$d$的连通块的个数

这样转移时十分自然 就是$f[x][max(d_1,d_2+1)]=\sum\limits _{d_1+d_2\le lim}f[x][d_1]\times f[y][d_2]$

最后答案就是每次$d_1+d_2\le lim$的数量之积之和

发现有一个$max$不好转移 就把$max$分类讨论一下

$f[x][d_1]+=\sum\limits_{d_1+d_2+1\le lim,d_1<d_2}f[x][d_1] \times f[y][d_2]$

$f[x][d_2+1]+=(\sum\limits_{d_1+d_2+1\le lim,d_1\le d_2}f[x][d_1])\times f[y][d_2]$

然后因为这个$DP$和深度有关 可以用长链剖分优化到$\Theta(n)$ 我们建出长链剖分上的线段树 然后发现第二个式子就是区间求和 再乘一下加上去就行了 第二个式子就是区间乘一个数 然后我们可以在合并一个新的子树时用一个$map$存一下要乘的区间 再用线段树处理一下就行了

#include<bits/stdc++.h>

using namespace std;

const int maxn=200000+999;
const int p=998244353;

namespace IO{
    const int BUF=65536;
    char buf[BUF+1];
    char *Head=buf,*tail=buf;
    inline char gc()
    {
        if(Head==tail)*(tail=(Head=buf)+fread(buf,1,BUF,stdin))=0;
        return *Head++;
    }
    template<typename T>
    void read(T &a)
    {
        a=0;int f=1;
        char ch=gc();
        while(ch<'0'||ch>'9'){
            if(ch=='-')f=-1;
            ch=gc();
        }
        while(ch>='0'&&ch<='9'){
            a=a*10+ch-'0';
            ch=gc();
        }
        a*=f;
    }
}

using namespace IO;

namespace Segment_tree{
    #define ls l,m,rt<<1
    #define rs m+1,r,rt<<1|1
    typedef long long ll;
    ll poww(ll a,ll b){ll ans=1;while(b){if(b&1)ans=ans%p*a%p;a=a%p*a%p;b>>=1;};return ans;}
    ll sumv[maxn<<2],addv[maxn<<2],mul[maxn<<2];
    inline void pushup(int rt){sumv[rt]=sumv[rt<<1]+sumv[rt<<1|1];}
    void pushdown(int rt,int l)
    {
        sumv[rt<<1]=(sumv[rt<<1]*mul[rt]%p+(addv[rt]*(l-(l>>1))%p))%p;
        sumv[rt<<1|1]=(sumv[rt<<1|1]*mul[rt]%p+(addv[rt]*(l>>1)))%p;
        mul[rt<<1]=mul[rt<<1]*mul[rt]%p;
        mul[rt<<1|1]=mul[rt<<1|1]*mul[rt]%p;
        addv[rt<<1]=((addv[rt<<1]*mul[rt]%p)+addv[rt])%p;
        addv[rt<<1|1]=((addv[rt<<1|1]*mul[rt]%p)+addv[rt])%p;
        addv[rt]=0;mul[rt]=1;
    }
    void build(int l,int r,int rt)
    {
        mul[rt]=1;sumv[rt]=0;
        if(l==r){return;}
        int m=(l+r)>>1;
        build(ls);
        build(rs);
    }
    void add(int L,int R,ll c,int l,int r,int rt)
    {
        if(L<=l&&R>=r){sumv[rt]=(sumv[rt]+c*(r-l+1))%p;addv[rt]=(addv[rt]+c)%p;return;}
        int m=(l+r)>>1;
        pushdown(rt,r-l+1);
        if(L<=m)add(L,R,c,ls);
        if(R>m)add(L,R,c,rs);
        pushup(rt);
    }
    void multi(int L,int R,int c,int l,int r,int rt)
    {
        if(L<=l&&R>=r){
            sumv[rt]=sumv[rt]*c%p;addv[rt]=addv[rt]*c%p;mul[rt]=mul[rt]*c%p;
            return;
        }
        pushdown(rt,r-l+1);
        int m=(l+r)>>1;
        if(L<=m)multi(L,R,c,ls);
        if(R>m)multi(L,R,c,rs);
        pushup(rt);
    }
    int query(int L,int R,int l,int r,int rt)
    {
        if(L<=l&&R>=r)return sumv[rt]%p;
        pushdown(rt,r-l+1);
        int m=(l+r)>>1;
        int tot=0;
        if(L<=m)tot=(tot+query(L,R,ls))%p;
        if(R>m)tot=(tot+query(L,R,rs))%p;
        pushup(rt);
        return tot;
    }
}

using namespace Segment_tree;

struct edge{
    int to,nxt;
}e[maxn<<1];
int head[maxn],tot;
void add(int x,int y){
    e[++tot].to=y;
    e[tot].nxt=head[x];
    head[x]=tot;
}

int ans,dep[maxn],maxdep[maxn],son[maxn],fa,dfn[maxn],clk,lim,n,l,r;

void dfs1(int x,int fa)
{
    if(x==1)dep[x]=1;
    for(int i=head[x];i;i=e[i].nxt)
    {
        int y=e[i].to;
        if(y==fa)continue;
        dep[y]=maxdep[y]=dep[x]+1;
        dfs1(y,x);
        maxdep[x]=max(maxdep[y],maxdep[x]);
        if(maxdep[y]>maxdep[son[x]])son[x]=y;
    }
}

void dfs2(int x,int fa)
{
    dfn[x]=++clk;
    if(son[x])dfs2(son[x],x);
    for(int i=head[x];i;i=e[i].nxt)
    {
        int y=e[i].to;
        if(y==son[x]||y==fa)continue;
        dfs2(y,x);  
    } 
}

void dp(int x,int fa){
    static map<int,int>mp;
    static ll up[maxn],f[maxn];
    add(dfn[x],dfn[x],1,1,n,1);
    if(son[x])dp(son[x],x);
    for(int i=head[x];i;i=e[i].nxt)
    {
        int y=e[i].to,sum=0;
        if(y==fa||y==son[x])continue;
        dp(y,x);
        mp.clear();
        mp[0]=1;mp[maxdep[x]-dep[x]+1]=0;
        for(int j=0;j<lim&&j<=maxdep[y]-dep[y];j++){//第二个式子的预处理 
            up[j]=query(dfn[x],dfn[x]+min(lim-j-1,j),1,n,1);
            f[j]=query(dfn[y]+j,dfn[y]+j,1,n,1);
        }
        for(int j=0;j<=maxdep[y]-dep[y]&&j<lim;j++)//第一个式子 
        if(j+j+1<lim)
        {
            (mp[j+1]+=f[j])%=p;
            (mp[min(lim-j,maxdep[x]-dep[x]+1)]+=p-f[j])%=p;
        }
        for(map<int,int>::iterator it=mp.begin(),it2=++mp.begin();it2!=mp.end();it++,it2++)
        {
            sum=(sum+it->second)%p;
            multi(dfn[x]+it->first,dfn[x]+it2->first-1,sum,1,n,1);
        }
        for(int j=0;j<=maxdep[y]-dep[y]&&j<lim;j++)
            add(dfn[x]+j+1,dfn[x]+j+1,1ll*up[j]*f[j]%p,1,n,1);//第二个式子 
    }
    ans=(ans+query(dfn[x],dfn[x]+min(lim,maxdep[x]-dep[x]),1,n,1))%p;
}

ll work(int k){
    lim=k;ans=0;build(1,n,1);
    dp(1,0);
    return ans;
}

int main()
{
    freopen("tree.in","r",stdin);
    freopen("tree.out","w",stdout);
    read(n);read(l);read(r);
    for(int i=1;i<n;i++){
        int x,y;
        read(x);read(y);
        add(x,y);add(y,x);
    }
    dfs1(1,0);
    dfs2(1,0);
    cout<<1ll*poww(work(n),p-2)*(work(r)-work(l-1)+p)%p;
}

还有一道和优化$DP$没有什么关系的用长链性质的题 BZOJ3252 攻略 代码也放上吧

#include<bits/stdc++.h>
#define int long long
const int maxn=200300;

using namespace std;

namespace Long_Chain_Partitioning{
    int n,k,num=1,val[maxn],w[maxn],dep[maxn],maxdep[maxn],son[maxn],head[maxn],fa[maxn],tot;
    struct edge{
        int to,nxt;
    }e[maxn<<1];
    void add(int x,int y)
    {
        e[++tot].to=y;
        e[tot].nxt=head[x];
        head[x]=tot;
    }
    void dfs1(int x)
    {
        for(int i=head[x];i;i=e[i].nxt)
        {
            int y=e[i].to;
            if(y==fa[x])continue;
            dep[y]=maxdep[y]=dep[x]+w[y];
            fa[y]=x;
            dfs1(y);
            maxdep[x]=max(maxdep[x],maxdep[y]);
            if(maxdep[y]>maxdep[son[x]])son[x]=y;
        }
    }
    void dfs2(int x)
    {
        val[num]+=w[x];
        if(son[x])dfs2(son[x]);
        for(int i=head[x];i;i=e[i].nxt)
        {
            int y=e[i].to;
            if(y==fa[x]||y==son[x])continue;
            ++num;dfs2(y);
        }
    }
}

using namespace Long_Chain_Partitioning;

signed main()
{
    scanf("%lld%lld",&n,&k);
    for(int i=1;i<=n;i++)scanf("%lld",&w[i]);
    for(int i=1;i<n;i++)
    {
        int x,y;
        scanf("%lld%lld",&x,&y);
        add(x,y);
        add(y,x);
    }
    dfs1(1);
    dfs2(1);
    sort(val,val+1+num);
    int ans=0;
    for(int i=num;i>=num-k+1;i--)
        ans+=val[i];
    cout<<ans;
}
//5 2
//4 3 2 1 1
//1 2
//1 5
//2 3
//2 4

好像还可以$O(1)$求一个点的$k$级祖先 不过应该没有什么屌用