BZOJ1036/树的统计

Posted on By 二价氢

看到这个题目,首先想到的便是线段树,然后发现是棵树,傻了QAQ

所以就要树的分治

上面这张图表明了树的分治过程,图上用粗线标记的子结点称“重儿子”,其余的结点称作“轻儿子”(亲儿子?!),这里“重儿子”表示的是size域极大的子结点

图上用粗线标记的边称作“重边”,由若干条重边连成的边叫“重链”

这里有两个重要的性质:

  1. 如果(v,u)为轻边,则
  2. 从根到某一点的路径上轻链、重链的个数都不大于

上面第二条性质告诉了我们,这个算法的时间复杂度为对数级

显然可以dfs求出size,dep,pre,top,son,id,但是大数据可能出现爆栈的情况,于是我们可以用3次循环来求。

第一次正向循环求出pre,dep,再逆向循环求出size,再正向循环求出top,son,id。然后就可以将权值建立到线段树中

建完树就要开始我们核心算法了,比如更新路径,设时,若,更新的权值,然后.当时,再更新u到v的权值即可,操作均为logn的。

题解:http://blog.csdn.net/mlzmlz95/article/details/8993673

AC-Code:

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <vector>
using namespace std;
const int MAXN=60005;
vector<int> T[MAXN];
int val[MAXN],pre[MAXN],siz[MAXN],son[MAXN],top[MAXN],id[MAXN],w[MAXN],tot;
int q[MAXN],deep[MAXN],ql,qr;
bool vis[MAXN];
int MAX[MAXN*2],SUM[MAXN*2];
typedef vector<int>::iterator it;
int n,m,a,b;
inline int getlc(int x)
{
    return x*2;
}
inline int getrc(int x)
{
    return x*2+1;
}
void update(int x)
{
    MAX[x]=max(MAX[getlc(x)],MAX[getrc(x)]);
    SUM[x]=SUM[getlc(x)]+SUM[getrc(x)];
}
void build(int rt,int l,int r)
{
    if (l==r)
    {
        MAX[rt]=SUM[rt]=w[l];
        return;
    }
    int mid=(l+r)/2;
    build(getlc(rt),l,mid);
    build(getrc(rt),mid+1,r);
    update(rt);
}
void edit(int x,int l,int r,int u,int val)
{
    if (l==r)
    {
        MAX[x]=SUM[x]=val;
        return;
    }
    int mid=(l+r)/2;
    if (u<=mid)
        edit(getlc(x),l,mid,u,val);
    else
        edit(getrc(x),mid+1,r,u,val);
    update(x);
}
int QUERY_A(int x,int l,int r,int ql,int qr)
{
    if (ql<=l&&qr>=r)
        return MAX[x];
    int mid=(l+r)/2;
    int tmp1=-99999999,tmp2=-99999999;
    if (ql<=mid)
        tmp1=QUERY_A(getlc(x),l,mid,ql,qr);
    if (qr>mid)
        tmp2=QUERY_A(getrc(x),mid+1,r,ql,qr);
    return max(tmp1,tmp2);
}
int QUERY_B(int x,int l,int r,int ql,int qr)
{
    if (ql<=l&&qr>=r)
        return SUM[x];
    int mid=(l+r)/2;
    int tmp1=0,tmp2=0;
    if (ql<=mid)
        tmp1=QUERY_B(getlc(x),l,mid,ql,qr);
    if (qr>mid)
        tmp2=QUERY_B(getrc(x),mid+1,r,ql,qr);
    return tmp1+tmp2;
}
int GetMax(int a,int b)
{
    int tmp=-99999999;
    while (top[a]!=top[b])
    {
        if (deep[top[a]]<deep[top[b]])
            swap(a,b);
        tmp=max(tmp,QUERY_A(1,1,n,id[top[a]],id[a]));
        a=pre[top[a]];
    }
    if (id[a]>id[b])
        swap(a,b);
    tmp=max(tmp,QUERY_A(1,1,n,id[a],id[b]));
    return tmp;
}
int GetSum(int a,int b)
{
    int tmp=0;
    while (top[a]!=top[b])
    {
        if (deep[top[a]]<deep[top[b]])
            swap(a,b);
        tmp+=QUERY_B(1,1,n,id[top[a]],id[a]);
        a=pre[top[a]];
    }
    if (id[a]>id[b])
        swap(a,b);
    tmp+=QUERY_B(1,1,n,id[a],id[b]);
    return tmp;
}
int main()
{
    char buf[20];
    scanf("%d",&n);
    for (int i=1;i<n;i++)
    {
        scanf("%d%d",&a,&b);
        T[a].push_back(b);
        T[b].push_back(a);
    }
    for (int i=1;i<=n;i++)
        scanf("%d",&val[i]);
    vis[deep[1]=q[0]=1]=true;
    it k;
    for (;ql<=qr;ql++)
        for (k=T[q[ql]].begin();k!=T[q[ql]].end();k++)
            if (!vis[*k])
            {
                vis[*k]=true;
                deep[q[++qr]=(*k)]=deep[q[ql]]+1;
                pre[*k]=q[ql];
            }
    for (int i=qr;i>=0;i--)
    {
        siz[pre[q[i]]]+=++siz[q[i]];
        if (siz[q[i]]>siz[son[pre[q[i]]]])
            son[pre[q[i]]]=q[i];
    }
    for (int i=0;i<=qr;i++)
        if (!top[q[i]])
            for (int j=q[i];j;j=son[j])
            {
                top[j]=q[i];
                w[id[j]=++tot]=val[j];
            }
	build(1,1,n);
    scanf("%d",&m);
    while (m--)
    {
        scanf("%s%d%d",buf,&a,&b);
        if (buf[1]=='M')
            printf("%d\n",GetMax(a,b));
        else if (buf[1]=='H')
            edit(1,1,n,id[a],b);
        else
            printf("%d\n",GetSum(a,b));
    }
    return 0;
}