抱歉,您的浏览器无法访问本站
本页面需要浏览器支持(启用)JavaScript
了解详情 >

本文正在施工中,有缺失的内容请谅解。

本文对简单的树形问题进行了讲解。

有关树的问题在考试时非常常见,所以必须熟练掌握。本文介绍的问题都比较简单,不涉及什么高难的问题。

树的性质与遍历

我们知道,一棵树有 nn 个点,n1n-1 条边,且一定是连通的。有几种特殊的树:

  • 链:树退化成链式结构。
  • “菊花图”:树的深度恰好为 22

[Luogu P5908] 猫猫和企鹅

可以简单的使用树的深度优先遍历来解决问题。

查看代码
#include <iostream>
#include <cstdio>
#include <vector>

using namespace std;

int n, d, ans = 0, dis[100005];
vector <int> G[100005];

void dfs(int o, int f) {
    if (dis[o] <= d && o != 1) ++ans;
    for (int i = 0; i < G[o].size(); ++i) {
        int &y = G[o][i];
        if (y == f) continue;
        dis[y] = dis[o] + 1;
        dfs(y, o);
    }
}

int main(void) {
    scanf("%d%d", &n, &d);
    for (int i = 1; i < n; ++i) {
        int u, v;
        scanf("%d%d", &u, &v);
        G[u].push_back(v);
        G[v].push_back(u);
    }
    dfs(1, 0);
    cout << ans << endl;
    return 0;
}

边权转点权。实际上在树形问题中,边权非常不好处理,所以我们可以在深度优先遍历的时候将边权全放给儿子。大概像这样:

d[y] = w; // 转移边权为儿子的点权
dfs(y, x);

树的直径

指的是树上的最长路径,可以通过两次 DFS 求出。第一次 DFS 从任意节点开始遍历,走到最远的地方,然后从这个地方开始第二次 DFS,走到最远的地方。这两个最远的地方连接起来就是树的直径。

模板,代码如下:

查看代码
#include <iostream>
#include <cstdio>
#include <vector>
#include <cstring>

using namespace std;

int n, maxx = 0;
bool v[10005];
int dis[10005];
vector <int> G[10005];

void dfs(int o, int fa) {
    for (int i = 0; i < G[o].size(); ++i) {
        if (G[o][i] == fa) continue;
        dis[G[o][i]] = dis[o] + 1;
        if (dis[G[o][i]] > dis[maxx]) maxx = G[o][i];
        dfs(G[o][i], o);
    }
}

int main(void) {
    scanf("%d", &n);
    for (int i = 1; i < n; ++i) {
        int u, v; scanf("%d%d", &u, &v);
        G[u].push_back(v); G[v].push_back(u);
    } dfs(1, -1);
    memset(v, 0, sizeof(v));
    dis[maxx] = 0;
    dfs(maxx, -1);
    printf("%d\n", dis[maxx]);
    return 0;
}

树的直径有一个显然的性质:直径的某个端点到所有点的距离的最小值一定是所有点中最大的。

树的重心

对于树上的每一个点,计算其所有子树中最大的子树节点数,使得这个值最小的点就是这棵树的重心。树的重心有以下性质:

以树的重心为根时,所有子树的大小都不超过整棵树大小的一半。

使用反证法。设当前的重心为 uu,与 uu 相连的子树 vv 的大小超过了整棵树的一半,那么将 vv 替换为树的重心,显然这时 uu 的子树不超过整棵树大小的一半,而 vv 的子树大小减小了 11,一定比 uu 作为重心更好。

树中所有点到某个点的距离和中,到重心的距离和是最小的;如果有两个重心,那么到它们的距离和一样。

因为如果移动了,增加的距离一定大于等于减少的距离。

把两棵树通过一条边相连得到一棵新的树,那么新的树的重心在连接原来两棵树的重心的路径上。

如果不在这条路径上,那么只有那个节点的子树的代价会减小,其它的都会增加,得不偿失。

在一棵树上添加或删除一个叶子,那么它的重心最多只移动一条边的距离。

增加或减少一个叶子,只能使最大的子树恰好比一半多 11,重心只移动 11 即可。


现在我们来看如何求出树的重心。我们假定 11 为根节点,然后设 size[x]size[x] 代表 xx 的子树大小。我们定义 max_part 为当前 dfs 到的节点中,最大的子树大小。它的孩子们的子树大小在 dfs 时就可以统计,而剩下的一棵子树就是它父亲对应的子树,这就是 nsize[x]n-size[x]。这样只需要调用一次 dfs,时间复杂度为 O(n)O(n)

int n, pos, ans = 1e7; // pos 为重心,ans 为重心对应的最大子树
int s[105];
vector <int> G[105];

void dfs(int x, int fa)
{
    s[x] = 1;
    int max_part = 0;
    for (int i = 0; i < G[x].size(); ++i) {
        int y = G[x][i];
        if (y == fa) continue; // 想逃回父亲,直接枪毙
        dfs(y, x);
        s[x] += s[y]; // 父节点的子树大小加上子节点的
        max_part = max(max_part, s[y]); // 更新 max_part
    }
    max_part = max(max_part, n - s[x]); // 最后一棵子树是父亲节点对应的子树(这里的子树是指以 x 为根的情况)
    if (max_part < ans) // 答案更优就更新
    {
        ans = max_part;
        pos = x;
    }
}

学过树形 DP 的读者应该可以发现这个东西类似于换根 DP,但又不太一样。

最近公共祖先(LCA)

LCA 是指点集的 LCA,为了方便,我们记某点集 A={ui1in}A=\{u_i|1\leqslant i\leqslant n\} 的最近公共祖先为 LCA(u1,u2,,un)LCA(u_1,u_2,\ldots,u_n)LCA(A)LCA(A)。含义是离它们最近的一个点,是它们所有点的祖先。

LCA 有以下性质:

  1. LCA(u)=uLCA(u)=u
  2. LCA(u,v)=uLCA(u,v)=u 的充要条件是 uuvv 的祖先;
  3. 如果 uu 不为 vv 的祖先并且 vv 不为 uu 的祖先,那么 u,vu,v 分别处于 LCA(u,v)LCA(u,v) 的两棵不同子树中;
  4. 给定一棵二叉树,前序遍历中,LCA(S)LCA(S) 出现在所有 SS 中元素之前,后序遍历中 LCA(S)LCA(S) 则出现在所有 SS 中元素之后;
  5. 两点集并的最近公共祖先为两点集分别的最近公共祖先的最近公共祖先,即 LCA(AB)=LCA(LCA(A),LCA(B))LCA(A \cup B) = LCA(LCA(A),LCA(B))
  6. 两点的最近公共祖先必定处在树上两点间的最短路上,且 dist(u,v)=h(u)+h(v)2h(LCA(u,v))dist(u,v)=h(u)+h(v)-2h(LCA(u,v)),其中 h(x)h(x)xx 到树根的距离。

这些性质都比较显然,在此不做证明。


现在我们来讨论 LCA 的求法。

LCA 有多种求法,不同情况要用不同的方法。
可以在 模板 进行测试。

向上标记法

比如我们现在要求 LCA(u,v)LCA(u,v),我们可以先让 uuvv 向上跳到同一深度,然后让它们一起往上调,一定可以找到它们的 LCA。

int n, m, root;
vector <int> G[500005];
bool v[500005];

struct node {
    int p, fa, dep;
}T[500005];

void dfs(int o, int deep) {
    v[o] = 1; T[o].dep = deep;
    for (int i = 0; i < G[o].size(); ++i)
        if (!v[G[o][i]]) {
            T[G[o][i]].fa = o;
            dfs(G[o][i], deep + 1);
        }
}

int LCA(int x, int y) {
    if (T[x].dep < T[y].dep) swap(x, y);
    while (T[x].dep > T[y].dep) x = T[x].fa; // 跳到同一深度
    if (x == y) return x; // 此处特判可以略去,但习惯写上
    while (x != y) x = T[x].fa, y = T[y].fa; // 一起往上跳
    return x;
}

int main(void) {
    n = read(), m = read(), root = read();
    for (int i = 1; i < n; ++i) {
        int x = read(), y = read();
        G[x].push_back(y); G[y].push_back(x);
    }
    T[root].fa = -1;
    dfs(root, 1); // 构造树
    
    while (m--) {
        int a = read(), b = read();
        cout << LCA(a, b) << endl;
    }
    return 0;
}

树上倍增法

以上做法极慢,最常用的快速求 LCA 的方法是树上倍增法。设 f[x,k]f[x,k] 表示 xx2k2^k 辈祖先。若该节点不存在,则令 f[x,k]=1f[x,k]=-1(不设为 00 的原因是有的题需要设一个编号为 00 的虚拟节点)。那么有 f[x,k]=f[f[x,k1]][k1]f[x,k]=f[f[x,k-1]][k-1],接下来的思路跟向上标记法大致相同。

在求解 LCA 时,我们先让它们都跳到同一深度。如果此时两个点已经相等,那么这个点就是 LCA(此步不能略去,原因接下来会说明)。然后我们尝试着让它们一起往上跳,如果跳完值还不相等,那一定跳。最后再跳一步即可。

查看代码
#include <iostream>
#include <cstdio>
#include <vector>

using namespace std;

inline int read(void) {
    int x = 0, c = getchar();
    while (!isdigit(c)) c = getchar();
    while (isdigit(c)) x = (x<<3) + (x<<1) + (c^48), c = getchar();
    return x;
}

int n, m, root;
int dep[500005];
int lg[500005];
int f[500005][20];
vector <int> G[500005];

void dfs(int o, int fa)
{
    f[o][0] = fa; // 根据定义
    dep[o] = dep[fa] + 1; // 深度为父亲 +1
    for (int i = 1; i <= lg[n]; ++i) // 跳出树的值都会变成 -1
        f[o][i] = f[f[o][i - 1]][i - 1];
    for (int i = 0; i < G[o].size(); ++i)
        if (G[o][i] != fa) dfs(G[o][i], o); // 如果不往父亲回,就以 G[o][i] 为儿子,o 为父亲 dfs
}

int LCA(int x, int y)
{
    if (dep[x] < dep[y]) swap(x, y);
    for (int i = lg[n]; i >= 0; --i) // 从一个最大可能值开始枚举,这样做的正确性基于二进制拆分
        if (dep[f[x][i]] >= dep[y]) x = f[x][i]; // 如果跳这么大深度依然比 y 大,那只能跳
    if (x == y) return x; // 此步不能省去,否则已经是 LCA,最后 return 时还会再跳一次
    for (int i = lg[n]; i >= 0; --i)
        if (f[x][i] != f[y][i]) x = f[x][i], y = f[y][i]; // 如果跳这么大都不相等,此时必须要跳
    return f[x][0]; // 最后再跳一步便一定是 LCA
}

int main(void)
{
    scanf("%d%d%d", &n, &m, &root);
    lg[1] = 0;
    for (int i = 2; i <= n; ++i) lg[i] = lg[i >> 1] + 1;
    for (int i = 1; i < n; ++i)
    {
        int u, v;
        scanf("%d%d", &u, &v);
        G[u].push_back(v);
        G[v].push_back(u);
    }
    dfs(root, -1); // 让 -1 作为根节点的“父亲“,使得 f 数组中跳出树的都变成 -1
    while (m--)
    {
        int x, y;
        scanf("%d%d", &x, &y);
        printf("%d\n", LCA(x, y));
    }
    return 0;
}

除了求解 LCA 问题,很多树上问题都会用到树上倍增法。

树上前缀和与差分

前缀和和差分是线性结构上的有力工具,但是它们也可以搬到树上来。

树上前缀和

SiS_i 代表根节点到节点 ii 的权值总和,那么:

  • 如果是边权,那么 d(x,y)=Sx+Sy2×SLCA(x,y)d(x,y)=S_x+S_y-2\times S_{LCA(x,y)}
  • 如果是点权,那么 d(x,y)=Sx+SySLCA(x,y)Sfa[LCA(x,y)]d(x,y)=S_x+S_y-S_{LCA(x,y)}-S_{fa[LCA(x,y)]}(因为 LCA 处只能减一次)。

树上点差分

也就是对于点权的树上差分。也就是说,给定若干条路经,求出每个点经过的次数。那么:

// s -> t
d[s]++, d[t]++;
d[lca(s, t)]--, d[fa[lca(s, t)]]--;

当然不同的数值也可以改。

模板

查看代码
#include <iostream>
#include <cstdio>
#include <vector>

using namespace std;

int n, k, lg[50005], dep[50005];
int f[50005][18], sum[50005];
vector <int> G[50005];

void dfs(int x, int fa)
{
    f[x][0] = fa;
    dep[x] = dep[fa] + 1;   
    for (int i = 1; i <= lg[n]; ++i)
        f[x][i] = f[f[x][i - 1]][i - 1];
    for (int i = 0; i < G[x].size(); ++i)
        if (G[x][i] != fa) dfs(G[x][i], x);
}

int LCA(int x, int y)
{
    if (dep[x] < dep[y]) swap(x, y);
    for (int i = lg[n]; i >= 0; --i)
        if (dep[f[x][i]] >= dep[y]) x = f[x][i];
    if (x == y) return x;
    for (int i = lg[n]; i >= 0; --i)
        if (f[x][i] != f[y][i]) x = f[x][i], y = f[y][i];
    return f[x][0];
}

void get(int x, int fa)
{
    for (int i = 0; i < G[x].size(); ++i)
    {
        int y = G[x][i];
        if (y == fa) continue;
        get(y, x);
        sum[x] += sum[y];
    }
}

int main(void)
{
    scanf("%d%d", &n, &k);
    for (int i = 1, u, v; i < n; ++i)
    {
        scanf("%d%d", &u, &v);
        G[u].push_back(v), G[v].push_back(u);
    }
    for (int i = 1; i <= n; ++i) lg[i] = lg[i >> 1] + 1;
    dfs(1, 0);
    while (k--)
    {
        int s, t, lca;
        scanf("%d%d", &s, &t);
        lca = LCA(s, t);
        sum[s]++, sum[t]++;
        sum[lca]--, sum[f[lca][0]]--;
    }
    get(1, 0);
    int ans = 0;
    for (int i = 1; i <= n; ++i)
        ans = max(ans, sum[i]);
    printf("%d\n", ans);
    return 0;
}

树上边差分

还是直接将边前缀和搬过来:

d[s]++, d[t]++;
d[lca(s, t)] -= 2;

DFS 序列

树在进行 DFS 时,会有入栈出栈的顺序,而且每一个树恰好入栈一次、出栈一次。这样产生的序列就是树的 欧拉序。如果只记录一次节点,那么产生的是 DFS 序,访问到顺序记为时间戳 dfn。

概述

比如这样一棵树:

它的欧拉序就是 1 4 4 2 6 8 8 6 5 5 2 3 7 7 3 1

可以发现,欧拉序有以下性质:

  • 若树的大小为 nn,那么欧拉序的长度就等于 2n2n,每个数恰好出现了两次。
  • 每棵子树 xx 在欧拉序中一定是连续的一段,节点 xx 一定同时在这个连续段的两端。

而 DFS 序可以与树上差分结合起来,实现满足差分信息的树上信息高效维护。下面我们来看 DFS 序的应用:

单点修改

模板

单点增加,查询子树和。

根据 DFS 序列的性质,我们可以将树上信息转化到链上来维护。怎么转呢?可以发现,子树一定是在根后面连着的,那么我们记录 sizsiz 大小就可以了。

接下来就是 Fenwick 树了。

查看代码
#include <iostream>
#include <cstdio>
#include <vector>
#define lowbit(x) (x & -x)

using namespace std;
typedef long long i64;

int n, m, root;
int a[1000005], siz[1000005];
int dfn[1000005], num = 0;
vector <int> G[1000005];

i64 C[1000005];
void add(int x, int k) {
    while (x <= n) {
        C[x] += k;
        x += lowbit(x);
    }
}
i64 sum(int x) {
    i64 res = 0;
    while (x) {
        res += C[x];
        x -= lowbit(x);
    }
    return res;
}

void dfs(int x, int fa) {
    dfn[x] = ++num; siz[x] = 1;
    add(num, a[x]);
    for (auto y : G[x]) {
        if (y == fa) continue;
        dfs(y, x);
        siz[x] += siz[y];
    }
}

int main(void) {
    scanf("%d%d%d", &n, &m, &root);
    for (int i = 1; i <= n; ++i)
        scanf("%d", a + i);
    for (int i = 1; i < n; ++i) {
        int u, v;
        scanf("%d%d", &u, &v);
        G[u].emplace_back(v);
        G[v].emplace_back(u);
    }
    dfs(root, 0);
    while (m--) {
        int op, a, x;
        scanf("%d%d", &op, &a);
        if (op == 1) {
            scanf("%d", &x);
            add(dfn[a], x);
        } else {
            printf("%lld\n", sum(dfn[a] + siz[a] - 1) - sum(dfn[a] - 1));
        }
    }
    return 0;
}

子树修改

模板

子树所有节点增加 xx,子树节点和。

实际上是一样的,我们只需要将树状数组替换为线段树(当然,利用拆分信息树状数组也可以做)。

查看代码
#include <iostream>
#include <cstdio>
#include <vector>

using namespace std;
typedef long long i64;

int n, m, root;
int a[1000005], siz[1000005];
int dfn[1000005], idx[1000005], num = 0;
vector <int> G[1000005];

i64 T[4000005];
int tag[4000005];

void build(int o, int l, int r) {
    if (l == r) return T[o] = a[idx[l]], void();
    int mid = l + r >> 1;
    build(o << 1, l, mid);
    build(o << 1 | 1, mid + 1, r);
    T[o] = T[o << 1] + T[o << 1 | 1];
}
inline void pushdown(int o, int l, int r) {
    if (!tag[o]) return;
    int mid = l + r >> 1;
    T[o << 1] += 1ll * (mid - l + 1) * tag[o], T[o << 1 | 1] += 1ll * (r - mid) * tag[o];
    tag[o << 1] += tag[o], tag[o << 1 | 1] += tag[o];
    tag[o] = 0;
}
void update(int o, int l, int r, int x, int y, int k) {
    if (x <= l && r <= y) return T[o] += 1ll * (r - l + 1) * k, tag[o] += k, void();
    int mid = l + r >> 1; pushdown(o, l, r);
    if (x <= mid) update(o << 1, l, mid, x, y, k);
    if (mid < y) update(o << 1 | 1, mid + 1, r, x, y, k);
    T[o] = T[o << 1] + T[o << 1 | 1];
}
i64 query(int o, int l, int r, int x, int y) {
    if (x <= l && r <= y) return T[o];
    int mid = l + r >> 1; i64 res = 0; pushdown(o, l, r);
    if (x <= mid) res += query(o << 1, l, mid, x, y);
    if (mid < y) res += query(o << 1 | 1, mid + 1, r, x, y);
    return res;
}

void dfs(int x, int fa) {
    dfn[x] = ++num; idx[num] = x; siz[x] = 1; 
    for (auto y : G[x]) {
        if (y == fa) continue;
        dfs(y, x);
        siz[x] += siz[y];
    }
}

int main(void) {
    scanf("%d%d%d", &n, &m, &root);
    for (int i = 1; i <= n; ++i)
        scanf("%d", a + i);
    for (int i = 1; i < n; ++i) {
        int u, v;
        scanf("%d%d", &u, &v);
        G[u].emplace_back(v);
        G[v].emplace_back(u);
    }
    dfs(root, 0);
    build(1, 1, n);
    while (m--) {
        int op, a, x;
        scanf("%d%d", &op, &a);
        if (op == 1) {
            scanf("%d", &x);
            update(1, 1, n, dfn[a], dfn[a] + siz[a] - 1, x);
        } else {
            printf("%lld\n", query(1, 1, n, dfn[a], dfn[a] + siz[a] - 1));
        }
    }
    return 0;
}

链上修改

这里简单提一下,有兴趣可以写一下模板

我们说过,DFS 序维护的依旧是前缀和,所以利用树上差分的方式,配合树状数组可以快速修改与查询,会比重链剖分少一个 log\log

但是当维护的内容不满足差分的区间可减性,DFS 序就做不了了。

快速 LCA

DFS 序求 LCA 是常用方式中最快的 LCA 算法,并且是在线的。可以做到 O(nlogn)O(n\log n) 预处理,O(1)O(1) 查询。而欧拉序求 LCA 则有 22 倍的常数,至于利用笛卡尔树的 O(n)O(n) 预处理做法则并不使用,因为大部分树上问题都是带 log\log 的。尤其是对于虚树这询问 LCA 次数极多的东西,DFS 序的优势很大。

考虑树上的两个节点 u,vu,v 和其 LCA dddd 显然在 u,vu,v 之前出现。

如果 uu 不是 vv 的祖先,我们只需要求出 uuvv 的 DFS 序之间深度最小的一个节点,它的父亲就是 dd
如果是,那么直接令 uu 变成 dfn 比它大 11 的节点,就转化成了上一种情况。

只需要特判掉 u=vu=v,这样就可以直接改变 uu 的 DFS 序了。

查看代码
#include <bits/stdc++.h>
using namespace std;
const int N = 5e5 + 5;

int n, m, root, num, dfn[N], dep[N], lg[N], mi[20][N];
vector<int> G[N];

inline int get(int x, int y) { return dep[x] < dep[y] ? x : y; } // 这个 inline 有用
void dfs(int x, int fa) {
    mi[0][dfn[x] = ++num] = fa, dep[x] = dep[fa] + 1;
    for (int y : G[x]) if (y != fa) dfs(y, x); 
}
int LCA(int u, int v) {
    if (u == v) return u;
    if ((u = dfn[u]) > (v = dfn[v])) swap(u, v);
    int d = lg[v - u];
    return get(mi[d][u + 1], mi[d][v - (1 << d) + 1]);
}

int main(void) {
    scanf("%d%d%d", &n, &m, &root);
    for (int i = 2; i <= n; i++) lg[i] = lg[i >> 1] + 1;
    for (int i = 1, u, v; i < n; i++) {
        scanf("%d%d", &u, &v);
        G[u].push_back(v), G[v].push_back(u);
    } dfs(root, 0);
    for (int i = 1; i <= lg[n]; i++)
        for (int j = 1; j + (1 << i) - 1 <= n; j++)
            mi[i][j] = get(mi[i - 1][j], mi[i - 1][j + (1 << i - 1)]);
    for (int i = 1, u, v; i <= m; i++) scanf("%d%d", &u, &v), printf("%d\n", LCA(u, v));
    return 0;
}

树链与树链剖分

我们学过的很多内容,比如线段树,都只能处理序列,也就是链上的问题。当它跑到了树上,我们还可以使用 DFS 序列来进行处理。但是当维护的信息不满足差分性质后,DFS 序就显得无力了。这是候怎么办?要对树链进行处理了。

概念

我们先简单介绍一下相关概念。

树链是指树上的一条链,树链剖分就是将整棵树剖分成若干条链,使它组合成线性结构,然后可以在线性结构上工作的强大数据结构就可以派上用场了,也简称树剖。

树链剖分有多种形式,比如重链剖分长链剖分实链剖分,其中重链剖分最为常用,大部分树剖指的都是它。

这里只介绍最常用的重链剖分,长链剖分请参考《复杂树形问题》。

重链剖分

我们先来看一下这个问题:

你需要写一种数据结构,支持区间修改和区间查询。

这个我当然会!直接线段树敲上去不久完事了嘛!

那么再来一个:

给出一棵树,支持链上修改,链上查询,子树修改,子树查询。

这是什么

这个题是有的,模板

我们不会这种题,我们只会在链上做,那就需要使用树链剖分。

重链剖分可以将树上的任意一条路径划分成长度不超过 O(logn)O(\log n) 的连续链,每条链上的点深度互不相同(即自底向上的一条链,链上所有点的 LCA 为链的深度小的那个端点)。

我们首先需要了解几个概念:

  • 子树的大小:子树中节点的个数。
  • 重子节点(重节点):表示其子节点中子树最大的子结点。如果有多个子树最大的子结点,取其一。如果没有子节点,就无重子节点。
  • 轻子节点(轻节点):除了重子节点外的所有节点。特别地,树根是轻节点。
  • 重边:连接节点到它的重子节点的边。
  • 轻边:除了重边之外的所有边。
  • 重链:若干条首尾连接的重边,也就是说,一条重链的开头是一个轻节点,剩下的都是重节点。特别地,一个落单的节点也是重链。

这样整棵树就被剖分成若干条重链,可以证明链的规模是 O(logn)O(\log n) 的。

实现时要通过两次 dfs,大概像这样:

int f[100005], son[100005], top[100005]; // son 指重儿子,没有为 -1;top 指重链的顶端节点的标号
int dep[100005], siz[100005], dfn[100005], tot = 0; // dfn 指时间戳
vector <int> G[100005];

void dfs1(int x, int fa) {
    dep[x] = dep[fa] + 1; f[x] = fa; siz[x] = 1;
    for (int y : G[x])
        if (y != fa) {
            dfs1(y, x); siz[x] += siz[y];
            if (siz[y] > siz[son[x]]) son[x] = y; 
        }
}

void dfs2(int x, int topf) // topf 记录这个重链的顶点
{
    dfn[x] = ++tot; top[x] = topf;
    if (son[x] == -1) return;
    dfs2(son[x], topf); // 优先处理重链
    for (int y : G[x])
        if (y != f[x] && y != son[x]) dfs2(y, y); // 遍历轻儿子
}

注意到重链的处理总是优先的,也就是说,重链内的时间戳编号是连续的,那么就决定了我们维护的时候直接对应了一段序列上的区间。

建立一棵线段树,以 xxyy 的最短路径上加上 zz 为例,像这样:

while (top[x] != top[y]) // 如果它们不在一条重链上
{
    if (dep[top[x]] < dep[top[y]]) swap(x, y); // 要计算深的节点
    update(1, 1, n, dfn[top[x]], dfn[x], z); // 更新
    x = f[top[x]]; // 跳上来,注意重链的头部已经修改过了,跳到重链头的父亲上
}
if (dep[x] > dep[y]) swap(x, y);
update(1, 1, n, dfn[x], dfn[y], z); // 现在已经在一条重链上,更新

发没发现这个处理特别像 LCA 的求解?没错,LCA 确实也可以用树链剖分来求解,但是不如倍增简单,除非恰好这道题目需要用到树链剖分,我们才会使用树链剖分来求解 LCA。

那么对于子树的操作呢?由于是深度优先遍历,所以一棵子树内的时间戳编号也是连续的(依然具有 DFS 序的性质),也可以直接使用线段树维护:

update(1, 1, n, dfn[x], dfn[x] + siz[x] - 1, k); // 利用子树的大小直接计算

那么完整代码就很简单了:

查看代码
#include <iostream>
#include <cstdio>
#include <vector>
#include <cstring>

using namespace std;

int n, m, root, P;
int w[100005], f[100005], son[100005], top[100005];
int dep[100005], siz[100005], dfn[100005], tot = 0, a[100005];
vector <int> G[100005];

void dfs1(int x, int fa)
{
    dep[x] = dep[fa] + 1; f[x] = fa; siz[x] = 1;
    int maxx = -1;
    for (auto y : G[x])
        if (y != fa)
        {
            dfs1(y, x);
            siz[x] += siz[y];
            if (siz[y] > maxx) son[x] = y, maxx = siz[y];
        }
}

void dfs2(int x, int topf)
{
    dfn[x] = ++tot; top[x] = topf; a[tot] = w[x];
    if (son[x] == -1) return;
    dfs2(son[x], topf);
    for (auto y : G[x])
        if (y != f[x] && y != son[x]) dfs2(y, y);
}

int T[400005], tag[400005];
inline void maintain(int o) { T[o] = (T[o << 1] + T[o << 1 | 1]) % P; }
inline void pushdown(int o, int l, int r)
{
    if (!tag[o]) return;
    int mid = l + r >> 1, ls = o << 1, rs = o << 1 | 1;
    tag[ls] = (tag[ls] + tag[o]) % P, tag[rs] = (tag[rs] + tag[o]) % P;
    T[ls] = (T[ls] + 1ll * tag[o] * (mid - l + 1)) % P, T[rs] = (T[rs] + 1ll * tag[o] * (r - mid)) % P;
    tag[o] = 0;
}
void build(int o, int l, int r)
{
    if (l == r) {
        T[o] = a[l] % P;
        return;
    }
    int mid = l + r >> 1;
    build(o << 1, l, mid);
    build(o << 1 | 1, mid + 1, r);
    maintain(o);
}
void update(int o, int l, int r, int x, int y, int k)
{
    if (x <= l && r <= y) {
        T[o] = (T[o] + 1ll * k * (r - l + 1)) % P;
        tag[o] = (tag[o] + k) % P;
        return;
    }
    pushdown(o, l, r);
    int mid = l + r >> 1;
    if (x <= mid) update(o << 1, l, mid, x, y, k);
    if (mid < y) update(o << 1 | 1, mid + 1, r, x, y, k);
    maintain(o);
}
int query(int o, int l, int r, int ql, int qr)
{
    if (ql <= l && r <= qr) return T[o];
    pushdown(o, l, r);
    int mid = l + r >> 1, res = 0;
    if (ql <= mid) res = (res + query(o << 1, l, mid, ql, qr)) % P;
    if (mid < qr) res = (res + query(o << 1 | 1, mid + 1, r, ql, qr)) % P;
    return res;
}

int main(void)
{
    memset(son, 0xff, sizeof(son));
    scanf("%d%d%d%d", &n, &m, &root, &P);
    for (int i = 1; i <= n; ++i) scanf("%d", w + i);
    for (int i = 1, u, v; i < n; ++i)
    {
        scanf("%d%d", &u, &v);
        G[u].push_back(v), G[v].push_back(u);
    }
    dfs1(root, 0);
    dfs2(root, root);
    build(1, 1, n);
    while (m--)
    {
        int k, x, y, z;
        scanf("%d", &k);
        if (k == 1)
        {
            scanf("%d%d%d", &x, &y, &z);
            z %= P;
            while (top[x] != top[y])
            {
                if (dep[top[x]] < dep[top[y]]) swap(x, y);
                update(1, 1, n, dfn[top[x]], dfn[x], z);
                x = f[top[x]];
            }
            if (dep[x] > dep[y]) swap(x, y);
            update(1, 1, n, dfn[x], dfn[y], z);
        }
        else if (k == 2)
        {
            scanf("%d%d", &x, &y);
            int ans = 0;
            while (top[x] != top[y])
            {
                if (dep[top[x]] < dep[top[y]]) swap(x, y);
                ans = (ans + query(1, 1, n, dfn[top[x]], dfn[x])) % P;
                x = f[top[x]];
            }
            if (dep[x] > dep[y]) swap(x, y);
            ans = (ans + query(1, 1, n, dfn[x], dfn[y])) % P;
            printf("%d\n", ans);
        }
        else if (k == 3)
        {
            scanf("%d%d", &x, &k);
            update(1, 1, n, dfn[x], dfn[x] + siz[x] - 1, k);
        }
        else
        {
            scanf("%d", &x);
            printf("%d\n", query(1, 1, n, dfn[x], dfn[x] + siz[x] - 1));
        }
    }
    return 0;
}

算法理论实践复杂度为 O(nlog2n)O(n\log^2 n),但实际上远远达不到上界,重链剖分的常数很小。

但即使如此,树链剖分也是比 DFS 序慢的。如果能使用 DFS 序,更推荐用它。

树链剖分求 LCA

之前维护链的过程很像求 LCA 往上跳的过程,所以重链剖分也可以用来求 LCA。

int LCA(int x, int y) {
    while (top[x] != top[y]) { // 不在一条重链上
        if (dep[top[x]] < dep[top[y]]) swap(x, y);
        x = f[top[x]];
    }
    return dep[x] < dep[y] ? x : y;
}

一般来讲树链剖分会比倍增算法快一点,常数也很小(想想那个二维数组吧),甚至不会比 O(1)O(1) LCA 慢(因为有寻址)。

Problemset

我们来看一些有趣的题目,前面的一些题目都很简单,后面的部分题目相当复杂。

简单树形问题

不涉及什么算法,只需要树的有关知识,以及求解树的直径和树的重心的方法等即可。

[JLOI2012] 树

Portal.

直接 dfs 遍历树即可。

查看代码
#include <iostream>
#include <cstdio>
#include <vector>

using namespace std;

int n, c, ans = 0, a[100005];
vector <int> son[100005];

void dfs(int o, int sum)
{
    if (sum == c) return ++ans, void();
    if (sum > c) return;
    for (auto x : son[o])
        dfs(x, sum + a[x]);
}

int main(void)
{
    scanf("%d%d", &n, &c);
    for (int i = 1; i <= n; ++i) scanf("%d", a + i);
    for (int i = 1, u, v; i < n; ++i)
    {
        scanf("%d%d", &u, &v);
        son[u].push_back(v);
    }
    for (int i = 1; i <= n; ++i)
        dfs(i, a[i]);
    printf("%d\n", ans);
    return 0;
}

[YsOI2020] 植树

Portal.

类似于求树的重心,不过要统计的是子树的大小。注意 11 号节点肯定是可以的。

查看代码
#include <iostream>
#include <cstdio>
#include <vector>

using namespace std;

int n, s[1000005];
bool flag[1000005];
vector <int> G[1000005];

void dfs(int x, int fa)
{
    int num = 0;
    s[x] = 1;
    for (int i = 0; i < G[x].size(); ++i)
    {
        int y = G[x][i];
        if (y == fa) continue;
        dfs(y, x);
        s[x] += s[y];
        if (!num) num = s[y];
        if (num != s[y]) flag[x] = true;
    }
    if (x != 1 && num && num != n - s[x]) flag[x] = true;
}

int main(void)
{
    scanf("%d", &n);
    for (int i = 1, u, v; i < n; ++i)
    {
        scanf("%d%d", &u, &v);
        G[u].push_back(v), G[v].push_back(u);
    }
    dfs(1, 0);
    for (int i = 1; i <= n; ++i)
        if (!flag[i]) printf("%d ", i);
    putchar('\n');
    return 0;
}

[NOIP2014 提高组] 联合权值

Portal.

由于距离为 22,所以枚举每一个点,与它相邻的点两两互为联合点,然后使用 nn 项式的平方计算即可。

查看代码
#include <iostream>
#include <cstdio>
#include <vector>

#define MOD 10007

using namespace std;

inline int read(void)
{
    int x = 0, c = getchar();
    while (!isdigit(c)) c = getchar();
    while (isdigit(c)) x = (x<<3) + (x<<1) + (c^48), c = getchar();
    return x;
}

int n;
int W[200005];
vector <int> G[200005];

int main(void)
{
    n = read();
    for (int i = 1; i < n; ++i)
    {
        int u = read(), v = read();
        G[u].push_back(v);
        G[v].push_back(u);
    }
    for (int i = 1; i <= n; ++i) W[i] = read();
    int anssum = 0, ansmax = 0;
    for (int i = 1; i <= n; ++i)
    {
        int res = 0, ret = 0;
        int maxx1 = 0, maxx2 = 0;
        for (int j = 0; j < G[i].size(); ++j)
        {
            #define pocket W[G[i][j]]
            if (pocket > maxx1) maxx2 = maxx1, maxx1 = pocket;
            else if (pocket > maxx2) maxx2 = pocket;
            res = (res + pocket) % MOD;
            ret = (ret + pocket * pocket) % MOD;
        }
        anssum = (anssum + res * res % MOD - ret + MOD) % MOD;
        ansmax = max(ansmax, maxx1 * maxx2);
    }
    printf("%d %d\n", ansmax, anssum);
    return 0;
}

[NOI2011] 道路修建

Portal.

我们采用类似求树的重心的方法。统计一条边的费用时,一部分大小是子树的大小,另一部分是 nn 减去这个子树的大小。

查看代码
#include <iostream>
#include <cstdio>
#include <vector>

#define i64 long long

using namespace std;

struct edge
{
    int v, d;
    edge(int v = 0, int d = 0) :
        v(v), d(d) {}
};

int n;
i64 ans;
int s[1000005];
vector <edge> G[1000005];

void dfs(int x, int fa)
{
    s[x] = 1;
    for (int i = 0; i < G[x].size(); ++i)
    {
        int y = G[x][i].v, w = G[x][i].d;
        if (y == fa) continue;
        dfs(y, x);
        ans += (i64)w * abs(s[y] - (n - s[y]));
        s[x] += s[y];
    }
}

int main(void)
{
    scanf("%d", &n);
    for (int i = 1; i < n; ++i)
    {
        int u, v, d;
        scanf("%d%d%d", &u, &v, &d);
        G[u].push_back(edge(v, d));
        G[v].push_back(edge(u, d));
    }
    dfs(1, 0);
    printf("%lld\n", ans);
    return 0;
}

[Luogu P1395] 会议

Portal.

树中所有点到某个点的距离和中,到重心的距离和是最小的。我们只需要先求出树的重心,然后用 dfs 计算距离即可。

查看代码
#include <iostream>
#include <cstdio>
#include <vector>

using namespace std;

int n, ans = 1e7, pos, s[50005];
vector <int> G[50005];

void dfs(int x, int fa)
{
    s[x] = 1;
    int max_part = 0;
    for (int i = 0; i < G[x].size(); ++i)
    {
        int y = G[x][i];
        if (y == fa) continue;
        dfs(y, x);
        s[x] += s[y];
        max_part = max(max_part, s[y]);
    }
    max_part = max(max_part, n - s[x]);
    if (max_part < ans || (max_part == ans && x < pos))
    {
        pos = x;
        ans = max_part;
    }
}

int sum = 0, dep[50005];
void dfs2(int x, int fa)
{
    dep[x] = dep[fa] + 1;
    sum += dep[x];
    for (int i = 0; i < G[x].size(); ++i)
    {
        int y = G[x][i];
        if (y == fa) continue;
        dfs2(y, x);
    }
}

int main(void)
{
    scanf("%d", &n);
    for (int i = 1; i < n; ++i)
    {
        int u, v;
        scanf("%d%d", &u, &v);
        G[u].push_back(v);
        G[v].push_back(u);
    }
    dfs(1, 0);
    dep[0] = -1;
    dfs2(pos, 0);
    printf("%d %d\n", pos, sum);
    return 0;
}

LCA 的综合应用

包括 LCA、树上差分等内容。有的时候也会与别的算法综合(比如二分),但不涉及高难的算法。

[Luogu P3938] 斐波那契

Portal.

树的规模很大,但是深度很小,我们考虑不建树使用向上标记法。由于 dep 肯定是编号越大的节点越大,因此现在的问题就是有如何求解一个节点的爸爸。

发现节点和父亲的差都是斐波那契数,直接二分即可。

查看代码
#include <iostream>
#include <cstdio>
#include <algorithm>

using namespace std;
typedef long long i64;
const i64 MAXL = 1e12;

i64 f[105] = {0, 1};

int main(void)
{
    int i;
    for (i = 2;; ++i)
    {
        f[i] = f[i - 1] + f[i - 2];
        if (f[i] > MAXL) break;
    }

    int n;
    scanf("%d", &n);
    while (n--)
    {
        i64 a, b;
        scanf("%lld%lld", &a, &b);
        while (a != b)
        {
            if (a > b) swap(a, b);
            b -= (*(lower_bound(f, f + i + 1, b) - 1));
        }
        printf("%lld\n", a);
    }
    return 0;
}

[JLOI2014] 松鼠的新家

Portal.

直接使用树上点差分即可,但是要注意处理重复的情况。

查看代码
#include <iostream>
#include <cstdio>
#include <vector>

using namespace std;

int n, a[300005], dep[300005];
int f[300005][35], sum[300005];
vector <int> G[300005];

void dfs(int x, int fa)
{
    f[x][0] = fa;
    dep[x] = dep[fa] + 1;   
    for (int i = 1; i <= 30; ++i)
        f[x][i] = f[f[x][i - 1]][i - 1];
    for (auto y : G[x])
        if (y != fa) dfs(y, x);
}

int LCA(int x, int y)
{
    if (dep[x] < dep[y]) swap(x, y);
    for (int i = 30; i >= 0; --i)
        if (dep[f[x][i]] >= dep[y]) x = f[x][i];
    if (x == y) return x;
    for (int i = 30; i >= 0; --i)
        if (f[x][i] != f[y][i]) x = f[x][i], y = f[y][i];
    return f[x][0];
}

void get(int x, int fa)
{
    for (auto y : G[x])
        if (y != fa)
        {
            get(y, x);
            sum[x] += sum[y];
        }
}

int main(void)
{
    scanf("%d", &n);
    for (int i = 1; i <= n; ++i)   
        scanf("%d", a + i);
    for (int i = 1, u, v; i < n; ++i)
    {
        scanf("%d%d", &u, &v);
        G[u].push_back(v); G[v].push_back(u);
    }
    dfs(1, 0);
    for (int i = 1; i < n; ++i)
    {
        int lca = LCA(a[i], a[i + 1]);
        sum[a[i]]++; sum[a[i + 1]]++;
        sum[lca]--; sum[f[lca][0]]--;
    }
    get(1, 0);
    for (int i = 2; i <= n; ++i) --sum[a[i]];
    for (int i = 1; i <= n; ++i) printf("%d\n", sum[i]);
    return 0;
}

[BJOI2018] 求和

Portal.

注意到 kk 的范围很小,因此可以将不同的 kk 分开来做,那么这道题就成了树上前缀和的模板题。

查看代码
#include <iostream>
#include <cstdio>
#include <vector>

using namespace std;
typedef long long i64;
const int MOD = 998244353;

int n, m, lg[300005];
int f[300005][25], dep[300005];
i64 w[55][300005];
vector <int> G[300005];

void dfs(int x, int fa)
{
	if (fa == 0) dep[x] = 0;
	else dep[x] = dep[fa] + 1;
	f[x][0] = fa; w[1][x] = dep[x];
	for (int i = 2; i <= 50; ++i) w[i][x] = w[i - 1][x] * dep[x] % MOD;
	for (int i = 1; i <= lg[n]; ++i)
		f[x][i] = f[f[x][i - 1]][i - 1];
	for (auto y : G[x])
		if (y != fa) dfs(y, x);
}

void df5(int x, int fa)
{
    for (auto y : G[x])
        if (y != fa)
        {
            for (int i = 1; i <= 50; ++i) w[i][y] = (w[i][y] + w[i][x]) % MOD;
            df5(y, x);
        }
}

int LCA(int x, int y)
{
	if (dep[x] < dep[y]) swap(x, y);
	for (int i = lg[n]; i >= 0; --i)
		if (dep[f[x][i]] >= dep[y]) x = f[x][i];
	if (x == y) return x;
	for (int i = lg[n]; i >= 0; --i)
		if (f[x][i] != f[y][i]) x = f[x][i], y = f[y][i];
	return f[x][0];
}

int main(void)
{
	scanf("%d", &n);
	for (int i = 1, u, v; i < n; ++i)
	{
		scanf("%d%d", &u, &v);
		G[u].push_back(v), G[v].push_back(u);
	}
	for (int i = 2; i <= n; ++i) lg[i] = lg[i >> 1] + 1;
	dfs(1, 0);
    df5(1, 0);
	scanf("%d", &m);
	while (m--)
	{
		int x, y, k;
		scanf("%d%d%d", &x, &y, &k);
        int lca = LCA(x, y);
        printf("%lld\n", ((w[k][x] + w[k][y] - w[k][lca] - w[k][f[lca][0]]) % MOD + MOD) % MOD);
	}
	return 0;
}

[CF519E] A and B and Lecture Rooms

Portal.

如果两个点相同,那么所有都可以。如果这两个点之间的简单路径的中点不存在,那么没有可以的。

否则呢?假象两个点如果在同一深度上,那么只有其 LCA 的,包含着两个点的两棵子树上的点是不可以的。不在同一深度上,找到它们的中点,然后中点的子树,除了包含较深节点的那棵子树外,都是可以的。

查看代码
#include <bits/stdc++.h>
using namespace std;

int n, m, lg[100005]; 
int f[17][100005], dep[100005], siz[100005]; 
vector<int> G[100005]; 

void dfs(int x, int fa) {
    dep[x] = dep[f[0][x] = fa] + 1; siz[x] = 1; 
    for (int i = 1; i <= lg[n]; ++i) f[i][x] = f[i - 1][f[i - 1][x]]; 
    for (int y : G[x]) if (y != fa) dfs(y, x), siz[x] += siz[y]; 
}

int lca(int x, int y) {
    if (dep[x] < dep[y]) swap(x, y); 
    for (int i = lg[n]; i >= 0; --i) if (dep[f[i][x]] >= dep[y]) x = f[i][x]; 
    if (x == y) return x; 
    for (int i = lg[n]; i >= 0; --i) if (f[i][x] != f[i][y]) x = f[i][x], y = f[i][y]; 
    return f[0][x]; 
}

void calc(int x, int y) {
    if (x == y) return printf("%d\n", siz[1]), void(); int l = lca(x, y); 
    if ((dep[x] + dep[y] - 2 * dep[l]) % 2) return puts("0"), void(); 
    if (dep[x] == dep[y]) {
        for (int i = lg[n]; i >= 0; --i) if (f[i][x] != f[i][y]) x = f[i][x], y = f[i][y]; 
        return printf("%d\n", n - siz[x] - siz[y]), void(); 
    } 
    int up = (dep[x] + dep[y] - 2 * dep[l]) / 2; 
    if (dep[x] < dep[y]) swap(x, y); int x0 = x; 
    for (int i = lg[n]; i >= 0; --i) {
        if (up >> i & 1) x = f[i][x]; 
        if (up - 1 >> i & 1) x0 = f[i][x0]; 
    }
    printf("%d\n", siz[x] - siz[x0]); 
}

int main(void) {
    scanf("%d", &n); 
    for (int i = 2; i <= n; ++i) lg[i] = lg[i >> 1] + 1; 
    for (int i = 1; i < n; ++i) {
        int u, v; scanf("%d%d", &u, &v); 
        G[u].emplace_back(v); G[v].emplace_back(u);
    } dfs(1, 0); 
    for (scanf("%d", &m); m--; ) {
        int a, b; scanf("%d%d", &a, &b); 
        calc(a, b); 
    }
    return 0;
}

[NOIP2015 提高组] 运输计划

Portal.

像这种要求最大值最小的问题显然想到二分答案,然后使用树上边差分进行 check 即可,就找一个虫洞在所有不满足条件的运输计划上,取最大的一个来判断是否可行。

查看代码
#include <iostream>
#include <cstdio>
#include <vector>
#include <cstring>

using namespace std;

struct edge {
    int v, d;
    edge(int v = 0, int d = 0) :
        v(v), d(d) {}
};

int n, m, lg[300005];
int dep[300005], f[300005][20], dis[300005];
int s[300005], t[300005], lca[300005], dist[300005], C[300005];
int maxlen = 0, sum = 0, val[300005], ans;
vector <edge> G[300005];

void dfs(int x, int fa) {
    dep[x] = dep[fa] + 1, f[x][0] = fa;
    for (int i = 1; i <= lg[n]; ++i) f[x][i] = f[f[x][i - 1]][i - 1];
    for (int i = 0; i < G[x].size(); ++i) {
        int y = G[x][i].v, w = G[x][i].d;
        if (y == fa) continue;
        dis[y] = dis[x] + w, val[y] = w;
        dfs(y, x);
    }
}

int LCA(int x, int y) {
    if (dep[x] < dep[y]) swap(x, y);
    for (int i = lg[n]; i >= 0; --i)
        if (dep[f[x][i]] >= dep[y]) x = f[x][i];
    if (x == y) return x;
    for (int i = lg[n]; i >= 0; --i)
        if (f[x][i] != f[y][i]) x = f[x][i], y = f[y][i];
    return f[x][0];
}


void get(int x, int fa) {
    for (int i = 0; i < G[x].size(); ++i) {
        int y = G[x][i].v;
        if (y == fa) continue;
        get(y, x);
        C[x] += C[y];
    }
}

bool P(int x) { // 所有运输计划都不超过 x
    int cnt = 0, maxdis = 0;
    memset(C, 0, sizeof(C));
    for (int i = 1; i <= m; ++i)
        if (dist[i] > x) {
            C[s[i]]++, C[t[i]]++, C[lca[i]] -= 2;
            ++cnt;
        }
    if (cnt == 0) return true;
    get(1, 0);
    for (int i = 1; i <= n; ++i)
        if (C[i] == cnt) maxdis = max(maxdis, val[i]);
    if (maxlen - maxdis <= x) return true;
    return false;
}

int main(void) {
    scanf("%d%d", &n, &m);
    for (int i = 1, u, v, d; i < n; ++i) {
        scanf("%d%d%d", &u, &v, &d);
        G[u].push_back(edge(v, d)), G[v].push_back(edge(u, d));
    }
    for (int i = 2; i <= n; ++i) lg[i] = lg[i >> 1] + 1;
    dfs(1, 0);
    for (int i = 1; i <= m; ++i) {
        scanf("%d%d", s + i, t + i);
        lca[i] = LCA(s[i], t[i]);
        dist[i] = dis[s[i]] + dis[t[i]] - 2 * dis[lca[i]];
        sum += dist[i];
        maxlen = max(maxlen, dist[i]);
    }
    int L = -1, R = 300000005;
    while (L + 1 != R) {
        int mid = L + R >> 1;
        if (P(mid)) R = mid;
        else L = mid;
    }
    printf("%d\n", R);
    return 0;
}

[NOIP2013 提高组] 货车运输

Portal.

直觉告诉我们,走的路应该在最大生成树上。那我们先求出生成树,然后预处理 LCA,要记 w[x][i] 代表 xx 向上蹦 2i2^i 次所遇到的最小边权,然后直接做就行了。

查看代码
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <vector>

using namespace std;
const int INF = 1e9;

struct edge {
    int u, v, d;
    edge(int u = 0, int v = 0, int d = 0) : u(u), v(v), d(d) {}
    bool operator < (const edge &a) const { return d > a.d; }
}a[50005];

int n, m, q;
int bin[10005];
int find(int x) {
    if (bin[x] == x) return x;
    return bin[x] = find(bin[x]); 
}

vector <edge> G[10005];
inline void addedge(int u, int v, int d) { G[u].push_back(edge(u, v, d)); }

void Kruskal(void) {
    sort(a + 1, a + m + 1);
    for (int i = 1; i <= n; ++i) bin[i] = i;
    for (int i = 1; i <= m; ++i) {
        int x = find(a[i].u), y = find(a[i].v);
        if (x == y) continue;
        bin[x] = y;
        addedge(a[i].u, a[i].v, a[i].d);
        addedge(a[i].v, a[i].u, a[i].d);
    }
}

bool v[10005];
int dep[10005], lg[10005];
int f[10005][15], w[10005][15];
void dfs(int x, int fa) {
    v[x] = true;
    dep[x] = dep[fa] + 1;
    f[x][0] = fa;
    for (int i = 1; i <= lg[n]; ++i)
        f[x][i] = f[f[x][i - 1]][i - 1], w[x][i] = min(w[x][i - 1], w[f[x][i - 1]][i - 1]);
    for (int i = 0; i < G[x].size(); ++i)
        if (G[x][i].v != fa) {
            w[G[x][i].v][0] = G[x][i].d;
            dfs(G[x][i].v, x);
        }
}

int LCA(int x, int y) {
    if (find(x) != find(y)) return -1;
    if (dep[x] < dep[y]) swap(x, y);
    int ans = INF;
    for (int i = lg[n]; i >= 0; --i)
        if (dep[f[x][i]] >= dep[y]) ans = min(ans, w[x][i]), x = f[x][i];
    if (x == y) return ans;
    for (int i = lg[n]; i >= 0; --i)
        if (f[x][i] != f[y][i]) ans = min({ans, w[x][i], w[y][i]}), x = f[x][i], y = f[y][i];
    return min({ans, w[x][0], w[y][0]});
}

int main(void) {
    scanf("%d%d", &n, &m);
    for (int i = 1; i <= m; ++i)
        scanf("%d%d%d", &a[i].u, &a[i].v, &a[i].d);
    Kruskal();
    for (int i = 2; i <= n; ++i) lg[i] = lg[i >> 1] + 1;
    for (int i = 1; i <= n; ++i)
        if (!v[i]) dfs(i, 0);
    scanf("%d", &q);
    while (q--) {
        int x, y;
        scanf("%d%d", &x, &y);
        printf("%d\n", LCA(x, y));
    }
    return 0;
}

[NOIP2016 提高组] 天天爱跑步

Portal.

一棵包含 nn 个点的树,有 mm 个玩家,第 ii 个玩家的起点为 sis_i,终点为 tit_i。所有玩家在第 00 秒同时从自己的起点出发,以每秒跑一条边的速度向着终点跑去。

每个结点上都放置了一个观察员。在结点 jj 的观察员会选择在第 wjw_j 秒观察玩家,一个玩家能被这个观察员观察到当且仅当该玩家恰好在第 wjw_j 秒也正好到达了结点 jj 。问每个观察员会观察到多少人?

一个玩家能够被观察员 xx 观察到,当且仅当:

  • 这个观察员在 [s,LCA(s,t)][s,LCA(s,t)] 上,那么需要满足 dep[s]dep[x]=w[x]dep[s]-dep[x]=w[x],相当于 dep[s]=dep[x]+w[x]dep[s]=dep[x]+w[x]
  • 这个观察员在 (LCA(s,t),t](LCA(s,t),t] 上,那么需要满足 dep[s]+dep[x]2×dep[LCA(s,t)]=w[x]dep[s]+dep[x]-2\times dep[LCA(s,t)]=w[x],相当于 dep[s]2×dep[LCA(s,t)]=w[x]dep[x]dep[s]-2\times dep[LCA(s,t)]=w[x]-dep[x]

这个模型就很清晰了,因为右面的信息只和观察员的位置有关。因此对于每一个玩家,我们都使用树上差分的思路。将玩家拆分成两个(一个从 sslcalca,一个从 ttlcalca),然后把差分维护操作,给每一个节点都加上一种“物品”。放置相应节点编号的到 STL vector 中。然后进行 DFS。

建立数组 c1,c2c1,c2 来维护树上前缀和(因为有两种,分开维护比较方便)。使用 DFS 来统计树上前缀和,刚才已经把差分的操作放置到了 STL vector 中,我们只需要将这些操作执行,统计前后 cc 的变化就可以得到当前节点的答案了。

注意编号可能是负的,需要平移。

查看代码
#include <iostream>
#include <cstdio>
#include <vector>

using namespace std;
const int N = 300005;

int read(void) {
    int x = 0, c = getchar_unlocked();
    while (!isdigit(c)) c = getchar_unlocked();
    while (isdigit(c)) x = x * 10 + c - '0', c = getchar_unlocked();
    return x;
}

int n, m, f[N][20];
int dep[N], lg[N];
int w[N], ans[N], c1[N * 2], c2[N * 2];
vector <int> G[N];
vector <int> a1[N], b1[N], a2[N], b2[N];

void dfs(int x, int fa) {
    f[x][0] = fa; dep[x] = dep[fa] + 1;
    for (int i = 1; i <= lg[n]; ++i)
        f[x][i] = f[f[x][i - 1]][i - 1];
    for (auto y : G[x])
        if (y != fa) dfs(y, x);
}

int LCA(int x, int y) {
    if (dep[x] < dep[y]) swap(x, y);
    for (int i = lg[n]; i >= 0; --i)
        if (dep[f[x][i]] >= dep[y]) x = f[x][i];
    if (x == y) return x;
    for (int i = lg[n]; i >= 0; --i)
        if (f[x][i] != f[y][i]) x = f[x][i], y = f[y][i];
    return f[x][0];
}

void dfs2(int x, int fa) {
    int val1 = c1[w[x] + dep[x]], val2 = c2[w[x] - dep[x] + n];
    for (auto y : G[x]) {
        if (y == fa) continue;
        dfs2(y, x);
    }
    for (auto i : a1[x]) c1[i]++;
    for (auto i : b1[x]) c1[i]--;
    for (auto i : a2[x]) c2[i + n]++;
    for (auto i : b2[x]) c2[i + n]--;
    ans[x] = c1[w[x] + dep[x]] - val1 + c2[w[x] - dep[x] + n] - val2;
}

int main(void) {
    n = read(), m = read();
    for (int i = 2; i <= n; ++i) lg[i] = lg[i >> 1] + 1;
    for (int i = 1, u, v; i < n; ++i) {
        u = read(), v = read();
        G[u].emplace_back(v);
        G[v].emplace_back(u);
    }
    dfs(1, 0);
    for (int i = 1; i <= n; ++i) w[i] = read();
    while (m--) {
        int s = read(), t = read();
        int lca = LCA(s, t);
        a1[s].emplace_back(dep[s]);
        b1[f[lca][0]].emplace_back(dep[s]);
        a2[t].emplace_back(dep[s] - 2 * dep[lca]);
        b2[lca].emplace_back(dep[s] - 2 * dep[lca]);
    }
    dfs2(1, 0);
    for (int i = 1; i <= n; ++i) printf("%d ", ans[i]);
    putchar('\n');
    return 0;
}

DFS 序列

看起来重链剖分可以代替 DFS 序列,只是慢了一点。实际上不是。本质上重链剖分也是一种特殊的 DFS 序列,而重链剖分转为链式结构来维护也不能解决所有问题(比如树上倍增就很有用)。DFS 序列自身也有一条美妙的性质:每一个节点恰好出现两次,中间是子树。这里仅举一个例子,足够说明问题。

[CF176E] Archaeology

Portal.

有一棵 nn 个点的带权树,每个点都是黑色或白色,最初所有点都是白色的。有 qq 个询问:

  • 把点 xx 从白色变成黑色。
  • 把点 xx 从黑色变成白色。
  • 查询黑点的导出子树 ((用最少的边把所有的黑点连通起来的树)) 的总边权和,实际上就是虚树大小
    保证 1n,q105,1xn1 \leq n, q \leq 10^5, 1 \leq x \leq n

维护黑点的集合,并将它们按照 dfn 从小到大排序,设排序后的序列为 ff,那么答案就是:

12(d(f1,f2)++d(fn1,fn)+d(fn,f1))\frac{1}{2}(d(f_1,f_2)+\cdots+d(f_{n-1},f_n)+d(f_n,f_1))

DFS 序列的性质就可以证明:它相当于是遍历了这些点两遍。知道了这一点之后就可以直接维护,插入和删除的时候计算对答案的贡献即可。

查看代码
#include <iostream>
#include <cstdio>
#include <vector>
#include <cstring>
#include <set>
#define pii pair<int, int>

using namespace std;
typedef long long i64;

int n, q;
set<int> S;
int dfn[100005], son[100005], top[100005], num = 0;
int f[100005], siz[100005], dep[100005], idx[100005];
i64 dis[100005], ans = 0;
vector<pii> G[100005];

void dfs1(int x, int fa) {
    siz[x] = 1; f[x] = fa; dep[x] = dep[fa] + 1;
    int max_part = -1;
    for (int i = 0; i < G[x].size(); ++i) {
        int y = G[x][i].first, w = G[x][i].second;
        if (y != fa) {
            dis[y] = dis[x] + w;
            dfs1(y, x);
            siz[x] += siz[y];
            if (siz[y] > max_part) {
                son[x] = y;
                max_part = siz[y];
            }
        }
    }
}

void dfs2(int x, int topf) {
    dfn[x] = ++num; idx[num] = x; top[x] = topf;
    if (son[x] == -1) return;
    dfs2(son[x], topf);
    for (int i = 0; i < G[x].size(); ++i) {
        int y = G[x][i].first;
        if (y != f[x] && y != son[x]) dfs2(y, y);
    }
}

int LCA(int x, int y) {
    while (top[x] != top[y]) {
        if (dep[top[x]] < dep[top[y]]) swap(x, y);
        x = f[top[x]];
    }
    return dep[x] < dep[y] ? x : y;
}
i64 F(int x, int y) { return dis[x] + dis[y] - 2 * dis[LCA(x, y)]; }

i64 query(set<int>::iterator it) {
    auto pre = it, nxt = it;
    pre = (it == S.begin() ? S.end() : it); --pre;
    ++nxt; if (nxt == S.end()) nxt = S.begin();
    int l = idx[*pre], r = idx[*nxt], x = idx[*it];
    return F(l, x) + F(x, r) - F(l, r);
}

int main(void)
{
    memset(son, -1, sizeof(son));
    scanf("%d", &n);
    for (int i = 1; i < n; ++i) {
        int u, v, d;
        scanf("%d%d%d", &u, &v, &d);
        G[u].push_back({v, d});
        G[v].push_back({u, d});
    }
    dfs1(1, 0);
    dfs2(1, 1);
    scanf("%d", &q);
    while (q--) {
        char op = getchar(); int x;
        while (op != '+' && op != '-' && op != '?') op = getchar();
        if (op == '+') {
            scanf("%d", &x);
            if (S.find(dfn[x]) != S.end()) continue;
            auto it = S.insert(dfn[x]).first;
            if (S.size() > 2) ans += query(it);
            else if (S.size() == 2) {
                ++it;
                auto ot = (it == S.end() ? S.begin() : it);
                ans = F(idx[*ot], x) * 2;
            }
        } else if (op == '-') {
            scanf("%d", &x);
            auto it = S.find(dfn[x]);
            if (it == S.end()) continue;
            if (S.size() == 2) ans = 0;
            else if (S.size() > 2) ans -= query(it);
            S.erase(it);
        } else printf("%lld\n", ans / 2);
    }
    return 0;
}

[CF1149C] Tree Generator™

Portal.

考虑括号序的任意一个子序列代表了什么?树上路径的移动过程!由于直径必能表示成一个移动过程,因此设 f(l,r)f(l,r) 代表括号序 [l,r][l,r] 中删掉匹配括号之后的长度,最大的 f(l,r)f(l,r) 就是答案。

然鹅这东西看起来不是很好维护,但不是不能推出一点东西:设 ( = 1, ) = -1。剩余的括号一定形如 )))(((,设删完之后设 xx 个有括号,yy 个左括号,则 f(l,r)=x+y=max{s(k+1,r)s(l,k)}f(l,r)=x+y=\max\{s(k+1,r)-s(l,k)\}

这个东西就可以用小白逛公园线段树来维护了。

查看代码
#include <bits/stdc++.h>
using namespace std; 

int n, m, a[200005]; 
char s[200005]; 

struct Node {  
	int s, lmx, rmn, lrans, lans, rans, ans; 
	friend Node operator+ (const Node &a, const Node &b) {
		Node c; c.s = a.s + b.s; 
		c.lmx = max(a.lmx, a.s + b.lmx); 
		c.rmn = min(b.rmn, a.rmn + b.s); 
		c.lans = max({a.lans, a.lrans + b.lmx, b.lans - a.s}); 
		c.rans = max({b.rans, b.lrans - a.rmn, a.rans + b.s}); 
		c.lrans = max(a.lrans + b.s, b.lrans - a.s); 
		c.ans = max({a.ans, b.ans, a.rans + b.lmx, b.lans - a.rmn}); 
		return c; 
	}
} T[800005]; 

void build(int o, int l, int r) {
	if (l == r) {
		T[o].s = a[l]; T[o].lmx = max(a[l], 0); T[o].rmn = min(a[l], 0); 
		T[o].lans = T[o].rans = T[o].lrans = T[o].ans = 1; 
		return; 
	}
	int mid = l + r >> 1; build(o << 1, l, mid); build(o << 1 | 1, mid + 1, r); 
	T[o] = T[o << 1] + T[o << 1 | 1]; 
}
void update(int o, int l, int r, int x, int k) {
	if (l == r) {
		T[o].s = k; T[o].lmx = max(k, 0); T[o].rmn = min(k, 0); 
		return; 
	}
	int mid = l + r >> 1; 
	if (x <= mid) update(o << 1, l, mid, x, k); 
	else update(o << 1 | 1, mid + 1, r, x, k); 
	T[o] = T[o << 1] + T[o << 1 | 1]; 
}

int main(void) {
	scanf("%d%d%s", &n, &m, s + 1); n = n * 2 - 2; 
	for (int i = 1; i <= n; ++i) a[i] = (s[i] == '(' ? 1 : -1);  
	build(1, 1, n); printf("%d\n", T[1].ans); 
	while (m--) {
		int x, y; scanf("%d%d", &x, &y); swap(a[x], a[y]); 
		update(1, 1, n, x, a[x]); update(1, 1, n, y, a[y]); 
		printf("%d\n", T[1].ans); 
	}
	return 0; 
}

其它树上问题

这里是一些树上杂题。

【XR-3】核心城市

Portal.

考虑从叶子节点开始向中间推进。当叶子节点的数量达到 nkn-k 时就应停止。

查看代码
#include <iostream>
#include <cstdio>
#include <vector>
#include <queue>

using namespace std;

int n, k, dep[100005], deg[100005];
bool vis[100005];
vector<int> G[100005];
queue<int> Q;

int main(void)
{
    scanf("%d%d", &n, &k); k = n - k;
    for (int i = 1; i < n; ++i) {
        int u, v;
        scanf("%d%d", &u, &v);
        G[u].emplace_back(v); deg[u] += 1;
        G[v].emplace_back(u); deg[v] += 1;
    }
    int ans = 0;
    for (int i = 1; i <= n; ++i)
        if (G[i].size() == 1 && k >= 1) Q.push(i), vis[i] = true, --k, ans = dep[i] = 1, --deg[i];
    if (k >= 1) {
        while (!Q.empty()) {
            int x = Q.front(); Q.pop();
            for (int i = 0; i < G[x].size(); ++i) {
                int y = G[x][i];
                deg[y] -= 1;
                if (!vis[y] && deg[y] == 1) {
                    dep[y] = dep[x] + 1;
                    ans = max(ans, dep[y]);
                    vis[y] = true;
                    Q.push(y);
                    --k;
                    if (k < 1) {
                        printf("%d\n", ans);
                        return 0;
                    }
                } 
            }
        }
    }
    printf("%d\n", ans);
    return 0;
}

表达式树 | [CSP-J2020] 表达式

Portal.

先考虑不带修怎么做。扫描 ss,开一个栈,建立表达式树,遇到数就入栈,运算符就弹栈进行运算然后再入栈。最后 dfs 一次表达式树就可以计算出答案。

当修改时,要么答案会变,要么答案不变。记录一个数组 cc,代表答案是否会不变。如果与运算的时候一个数为零,那么改变另一个数没有意义。或运算同理。这种关系可以再一次进行 dfs,通过或运算向儿子传递。

查看代码
#include <iostream>
#include <cstdio>
#include <cstring>
#include <stack>

using namespace std;

int n, q, a[2000005], ck = 0;
int son[2000005][2];
bool flag[2000005], c[2000005]; // c 记录当前节点的值改变是否有用
char s[1000010];

bool dfs(int x, bool f) 
{
    f ^= flag[x];
    a[x] ^= f;
    if (x <= n) return a[x];
    bool p = dfs(son[x][0], f), q = dfs(son[x][1], f);
    if (a[x] == 2) {
        if (p == 0) c[son[x][1]] = 1;
        if (q == 0) c[son[x][0]] = 1;
        return p & q;
    } else {
        if (p == 1) c[son[x][1]] = 1;
        if (q == 1) c[son[x][0]] = 1;
        return p | q;
    }
}

void dfs2(int x)
{
    if (x <= n) return;
    c[son[x][0]] |= c[x];
    c[son[x][1]] |= c[x];
    dfs2(son[x][0]);
    dfs2(son[x][1]);
}

int main(void)
{
    fgets(s, 1000002, stdin);
    scanf("%d", &n); ck = n;
    for (int i = 1; i <= n; ++i) scanf("%d", a + i);
    stack<int> b;
    for (int i = 0; s[i]; i += 2) {
        if (s[i] == 'x') {
            int x = 0;
            ++i;
            while (s[i] != ' ') {
                x = x * 10 + s[i] - '0';
                ++i;
            }
            --i;
            b.push(x);
        } else if (s[i] == '!') flag[b.top()] ^= 1; else {
            int x = b.top(); b.pop();
            int y = b.top(); b.pop();
            b.push(++ck);
            if (s[i] == '&') a[ck] = 2;
            else a[ck] = 3;
            son[ck][0] = x; son[ck][1] = y;
        }
    }
    bool ans = dfs(ck, 0);
    dfs2(ck);
    scanf("%d", &q);
    while (q--) {
        int x;
        scanf("%d", &x);
        printf("%d\n", c[x] ? ans : !ans);
    }
    return 0;
}

[SDOI2011] 消防

Portal.

要选的话,往直径上选比较好,因为即使直径选不满而选到了非直径上的边,那么原来直径上的点也使答案距离更大。

因此随便搞一条直径,然后枚举直径上的一个起点,计算一个最远的终点(距离不超过 ss),选这一段。只考虑直径上的点,满足最远点距离最小的这一段即为答案。因为如果不是答案(直径外点到这段的距离更大),那么这条直径必然是假的。

这里偷懒了,直径上的查找没有使用单调队列。

查看代码
#include <bits/stdc++.h>
#define pii pair<int, int>
using namespace std; 

int n, s, X, Y, ison[300005]; // (X, Y) 为直径
int d[300005], fr[300005]; // X -> Y 从哪里来
int ans = 2e9; 
vector<pii> G[300005]; 

void dfs(int x, int fa) {
    fr[x] = fa; 
    for (auto [y, w] : G[x]) if (y != fa) d[y] = d[x] + w, dfs(y, x); 
}
void workA(void) { // 标记直径上的点, d[x] 为 x 到 Y 的距离
    memset(d, 0, sizeof d); int u = Y; ison[Y] = 1; 
    while (u != X) {
        for (auto [v, w] : G[u]) if (v == fr[u]) ison[v] = 1, d[v] = d[u] + w; 
        u = fr[u]; 
    }
}

int find(int x, int len) { // x 向 X 走 len 最远走到哪里
	for (auto [y, w] : G[x]) if (y == fr[x] && len >= w) return find(y, len - w); 
	return x; 
}
void workB(void) { // 枚举直径上的起点
	int u = Y; 
	while (u != X) {
		ans = min(ans, max(d[u], d[X] - d[find(u, s)])); 
		u = fr[u]; 
	}
	ans = min(ans, max(d[X], d[X] - d[find(X, s)])); 
}
void query(int x, int fa) {
    for (auto [y, w] : G[x]) if (!ison[y] && y != fa) 
		d[y] = d[x] + w, query(y, x); 
}
void workC(void) {
	memset(d, 0, sizeof d); 
	int u = Y; 
	while (u != X) query(u, 0), u = fr[u]; 
	query(X, 0); 
}

int main(void) {
	scanf("%d%d", &n, &s); 
	for (int i = 1; i <= n; ++i) {
		int u, v, w; scanf("%d%d%d", &u, &v, &w); 
		G[u].emplace_back(v, w); G[v].emplace_back(u, w); 
	} dfs(1, 0); 
	for (int i = 1; i <= n; ++i) if (d[i] > d[X]) X = i; 
	d[X] = 0; dfs(X, 0); 
	for (int i = 1; i <= n; ++i) if (d[i] > d[Y]) Y = i; 
	workA(); workB(); workC(); 
	for (int i = 1; i <= n; ++i) ans = max(ans, d[i]); 
	return !printf("%d\n", ans); 
}

[CF1707C] DFS Trees

Portal.

题中的求解方式是求搜索树,也就是说,以一个节点为根的时候,所有不在 MST 上的边必须都是返祖边。

对于一条边连接的两个点,两子树内的点都 +1,树上差分维护即可。

查看代码
#include <bits/stdc++.h>
using namespace std; 

int n, m, lg[100005]; 
struct Edge {
	int u, v; 
} e[200005]; 
bool vis[200005]; 
int bin[100005], f[18][100005], dep[100005], s[100005]; 
vector<int> G[100005]; 
int find(int x) { return bin[x] == x ? x : bin[x] = find(bin[x]); }

int LCA(int x, int y) {
	if (dep[x] < dep[y]) swap(x, y);
	for (int i = lg[n]; i >= 0; --i) if (dep[f[i][x]] >= dep[y]) x = f[i][x]; 
	if (x == y) return x; 
	for (int i = lg[n]; i >= 0; --i) if (f[i][x] != f[i][y]) x = f[i][x], y = f[i][y]; 
	return f[0][x];
}
void dfs(int x, int fa) {
	f[0][x] = fa; dep[x] = dep[fa] + 1; 
	for (int i = 1; i <= lg[n]; ++i) f[i][x] = f[i - 1][f[i - 1][x]]; 
	for (int y : G[x]) if (y != fa) dfs(y, x); 
}
void dfs2(int x, int fa) {
	s[x] += s[fa]; 
	for (int y : G[x]) if (y != fa) dfs2(y, x); 
}

int main(void) {
	scanf("%d%d", &n, &m); 
	for (int i = 2; i <= n; ++i) lg[i] = lg[i >> 1] + 1; 
	for (int i = 1; i <= m; ++i) scanf("%d%d", &e[i].u, &e[i].v); 
	for (int i = 1; i <= n; ++i) bin[i] = i; 
	for (int i = 1; i <= m; ++i) {
		int u = e[i].u, v = e[i].v; 
		if (find(u) == find(v)) continue; 
		G[u].emplace_back(v); G[v].emplace_back(u); 
		vis[i] = 1; bin[find(u)] = find(v); 
	} dfs(1, 0); 
	for (int i = 1; i <= m; ++i) if (!vis[i]) {
		int u = e[i].u, v = e[i].v, l = LCA(u, v); 
		if (dep[u] > dep[v]) swap(u, v); 
		if (l == u) { // (u, v) 路径上不行
			++s[1], ++s[v]; 
			int p = v; 
			for (int j = lg[n]; j >= 0; --j) if (dep[f[j][p]] > dep[u]) p = f[j][p]; 
			--s[p]; 
		} else ++s[u], ++s[v]; 
	} dfs2(1, 0); 
	for (int i = 1; i <= n; ++i) putchar(s[i] == m - (n - 1) ? '1' : '0'); 
	return putchar('\n'), 0; 
}

重链剖分

看起来树链剖分就是个板子,只是将链上的数据结构放到树上了。实际上不是,因为树自身也有许多性质。

[NOI2021] 轻重边

Portal.

有一棵 nn 个结点的树,树上的每一条边可能是轻边或者重边。接下来你需要对树进行 mm 次操作,在所有操作开始前,树上所有边都是轻边。操作有以下两种:

  1. 给定两个点 aabb,首先对于 aabb 路径上的所有点 xx(包含 aabb),你要将与 xx 相连的所有边变为轻边。然后再将 aabb 路径上包含的所有边变为重边。
  2. 给定两个点 aabb,你需要计算当前 aabb 的路径上一共包含多少条重边。

肯定是树剖(LCT 也可,但是笔者不会 LCT),但是如何简单维护呢?

考虑使用颜色来维护。初始时每一个点的颜色都不同;每一次修改的时候,我们都将 (a,b)(a,b) 间染上一个新的颜色,然后重边的判定法则就变成了:连接的两个端点的颜色相同(想一想看是不是这样)。

那么实现一个可以统计颜色相同的相邻对的线段树即可。

查看代码
#include <bits/stdc++.h>
using namespace std;

int n, m;
vector <int> G[100005];
int dep[100005], siz[100005], f[100005];
int dfn[100005], top[100005], son[100005], num = 0;

void dfs1(int x, int fa) {
    dep[x] = dep[fa] + 1; siz[x] = 1; f[x] = fa;
    int maxx = -1;
    for (auto y : G[x]) 
        if (y != fa) {
            dfs1(y, x);
            siz[x] += siz[y];
            if (siz[y] > maxx) {
                maxx = siz[y];
                son[x] = y;
            }
        }
}

void dfs2(int x, int topf) {
    dfn[x] = ++num; top[x] = topf;
    if (son[x] == -1) return;
    dfs2(son[x], topf);
    for (auto y : G[x])
        if (y != f[x] && y != son[x]) dfs2(y, y);
}

// Segment Tree
struct Node {
    int lc, rc, tag;
    int cnt;
    Node(int lc = 0, int rc = 0, int cnt = 0, int tag = 0) :
        lc(lc), rc(rc), tag(tag), cnt(cnt) {}
} T[400005];
inline Node hb(const Node &a, const Node &b) {
    Node c(a.lc, b.rc, a.cnt + b.cnt + (a.rc == b.lc));
    return c;
}
inline void pushdown(int o, int l, int r) {
    if (!T[o].tag) return;
    int mid = l + r >> 1;
    T[o << 1] = Node(T[o].tag, T[o].tag, mid - l, T[o].tag);
    T[o << 1 | 1] = Node(T[o].tag, T[o].tag, r - mid - 1, T[o].tag);
    T[o].tag = 0;   
}
void update(int o, int l, int r, int x, int y, int k) {
    if (x <= l && r <= y) {
        T[o] = Node(k, k, r - l, k);
        return;
    }
    pushdown(o, l, r); int mid = l + r >> 1;
    if (x <= mid) update(o << 1, l, mid, x, y, k);
    if (mid < y) update(o << 1 | 1, mid + 1, r, x, y, k);
    T[o] = hb(T[o << 1], T[o << 1 | 1]);
}
Node query(int o, int l, int r, int x, int y) {
    if (x <= l && r <= y) return T[o];
    pushdown(o, l, r); int mid = l + r >> 1;
    if (y <= mid) return query(o << 1, l, mid, x, y);
    if (mid < x) return query(o << 1 | 1, mid + 1, r, x, y);
    return hb(query(o << 1, l, mid, x, y), query(o << 1 | 1, mid + 1, r, x, y));
}
int query(int x, int y) {
    bool flag = 0;
    Node ans1, ans2, tmp;
    while (top[x] != top[y]) {
        if (dep[top[x]] < dep[top[y]]) swap(x, y), flag = !flag;
        tmp = query(1, 1, n, dfn[top[x]], dfn[x]);
        if (flag) ans2 = Node(tmp.lc, ans2.rc, ans2.cnt + tmp.cnt + (ans2.lc == tmp.rc));
        else ans1 = Node(ans1.lc, tmp.lc, ans1.cnt + tmp.cnt + (ans1.rc == tmp.rc));
        x = f[top[x]];
    }
    if (dep[x] < dep[y]) swap(x, y), flag = !flag;
    tmp = query(1, 1, n, dfn[y], dfn[x]);
    if (flag) ans2 = Node(tmp.lc, ans2.rc, ans2.cnt + tmp.cnt + (ans2.lc == tmp.rc));
    else ans1 = Node(ans1.lc, tmp.lc, ans1.cnt + tmp.cnt + (ans1.rc == tmp.rc));
    return ans1.cnt + ans2.cnt + (ans1.rc == ans2.lc);
}

int main(void) {
    int TT; scanf("%d", &TT);
    while (TT--) {
        scanf("%d%d", &n, &m);
        num = 0;
        memset(son, -1, sizeof(son));
        memset(T, 0, sizeof(T));
        for (int i = 1; i <= n; ++i) G[i].clear();
        for (int i = 1; i < n; ++i) {
            int u, v;
            scanf("%d%d", &u, &v);
            G[u].emplace_back(v);
            G[v].emplace_back(u);
        }
        for (int i = 1; i <= n; ++i) update(1, 1, n, i, i, -i);
        dfs1(1, 0);
        dfs2(1, 1);
        for (int i = 1; i <= m; ++i) {
            int op, x, y;
            scanf("%d%d%d", &op, &x, &y);
            if (op == 1) {
                while (top[x] != top[y]) {
                    if (dep[top[x]] < dep[top[y]]) swap(x, y);
                    update(1, 1, n, dfn[top[x]], dfn[x], i);
                    x = f[top[x]];
                }
                if (dep[x] < dep[y]) swap(x, y);
                update(1, 1, n, dfn[y], dfn[x], i);
            } else printf("%d\n", query(x, y));
        }
    }
    return 0;
}

[SCOI2015] 情报传递

Portal.

nn 名情报员形成树形结构,每天会派发以下两种任务中的一个任务:

  • 指派 TT 号情报员搜集情报;
  • 将一条情报从 XX 号情报员经最短路径传递给 YY 号情报员。

情报员最初处于潜伏阶段,此时所有情报员的危险值为 00;一旦某个情报员开始搜集情报,他的危险值就会持续增加,每天增加 11 点(开始搜集情报的当天危险值仍为 00,第 22 天为 11)。

每条情报都有一个风险控制值 CC。参与传递这条情报的危险值大于 CC 的情报员将对该条情报构成威胁。问对于每个传递情报任务,参与传递的情报员有多少个,其中对该条情报构成威胁的情报员有多少个。

n2×105,Q2×105,0<Pi,CiN,1Ti,Xi,Yinn\le 2\times 10^5,Q\le 2\times 10^5,0<P_i,C_i\le N,1\le T_i,X_i,Y_i\le n

考虑这个限制条件是什么意思。假定一个人开始作死的时间为 tt,一条情报传递任务的时间为 ii,那么需要满足 t+c>it+c>i,也就是 tic1t\ge i-c-1

那么,离线,按照 ic1i-c-1 排序,然后依次将冒险者添加进来(单点修改),查询就是链上距离和链上查询。

查看代码
#include <bits/stdc++.h>
using namespace std;

struct Question {
    int x, y, c, id;
    Question(int x = 0, int y = 0, int c = 0, int id = 0) :
        x(x), y(y), c(c), id(id) {}
    bool operator < (const Question &a) const {
        return c < a.c;
    }
} Q[200005];
struct Operation {
    int x, t;
    Operation(int x = 0, int t = 0) :
        x(x), t(t) {}
} A[200005];
int tot = 0, tot2 = 0;

int n, q, root;
bool vis[200005];
int f[200005], dep[200005], siz[200005], dis[200005];
int son[200005], top[200005], dfn[200005], num = 0;
vector<int> G[200005];
int ans[200005];

void dfs1(int x) {
    dep[x] = dep[f[x]] + 1; siz[x] = 1;
    int max_part = -1;
    for (int y : G[x]) {
        dfs1(y);
        siz[x] += siz[y];
        if (siz[y] > max_part) {
            max_part = siz[y];
            son[x] = y;
        }
    }
}

void dfs2(int x, int topf) {
    dfn[x] = ++num; top[x] = topf;
    if (son[x] == -1) return;
    dfs2(son[x], topf);
    for (int y : G[x])
        if (y != son[x]) dfs2(y, y);
}

int LCA(int x, int y) {
    while (top[x] != top[y]) {
        if (dep[top[x]] < dep[top[y]]) swap(x, y);
        x = f[top[x]];
    }
    return dep[x] < dep[y] ? x : y;
}

int T[800005];
void update(int o, int l, int r, int x) {
    if (l == r) return T[o] += 1, void();
    int mid = l + r >> 1;
    if (x <= mid) update(o << 1, l, mid, x);
    else update(o << 1 | 1, mid + 1, r, x);
    T[o] = T[o << 1] + T[o << 1 | 1];
}
int query(int o, int l, int r, int x, int y) {
    if (x <= l && r <= y) return T[o];
    int mid = l + r >> 1, res = 0;
    if (x <= mid) res += query(o << 1, l, mid, x, y);
    if (mid < y) res += query(o << 1 | 1, mid + 1, r, x, y);
    return res;
}

int main(void) {
    memset(son, -1, sizeof(son));
    scanf("%d", &n);
    for (int i = 1; i <= n; ++i) {
        scanf("%d", f + i);
        if (f[i] == 0) root = i;
        else G[f[i]].emplace_back(i);
    }
    dfs1(root);
    dfs2(root, root);
    scanf("%d", &q);
    for (int i = 1; i <= q; ++i) {
        int op, x, y, c;
        scanf("%d%d", &op, &x);
        if (op == 1) {
            scanf("%d%d", &y, &c); ++tot;
            Q[tot] = Question(x, y, i - c - 1, tot);
        } else A[++tot2] = Operation(x, i);
    }
    sort(Q + 1, Q + tot + 1);
    for (int i = 1, p = 1; i <= tot; ++i) {
        while (p <= tot2 && A[p].t <= Q[i].c) {
            if (!vis[A[p].x]) update(1, 1, n, dfn[A[p].x]);
            vis[A[p].x] = true; ++p;
        }
        int x = Q[i].x, y = Q[i].y, res = 0;
        dis[Q[i].id] = dep[x] + dep[y] - 2 * dep[LCA(x, y)] + 1;
        while (top[x] != top[y]) {
            if (dep[top[x]] < dep[top[y]]) swap(x, y);
            res += query(1, 1, n, dfn[top[x]], dfn[x]);
            x = f[top[x]];
        }
        if (dep[x] < dep[y]) swap(x, y);
        res += query(1, 1, n, dfn[y], dfn[x]);
        ans[Q[i].id] = res;
    }
    for (int i = 1; i <= tot; ++i)
        printf("%d %d\n", dis[i], ans[i]);
    return 0;
}

[BJOI2014] 大融合

Portal.

给定一个无向图,动态加边,查询一条边的负载(经过这条边的简单路径的数量),保证任意时刻图都是一片森林。

解决动态图的常规手段依然是动态树,但是此题没有删边操作,也没有强制在线,所以可以考虑使用树剖来求解

将询问全部读入,然后建立出最终的森林,并对每一棵树都进行树链剖分。建立一个树状数组维护树上子树的大小的前缀和,利用差分的方式进行修改(因为子树大小只需要单点查询,写线段树有点小题大做了)。初始时要将所有所有点的子树大小都初始化为 11,因为它们都是孤立的。同时维护一个并查集,用于查询一个节点所在的集合。

扫描每一个操作:

  • 连边。设要将 yy 连接到 xxxxyy 的祖先,那么并查集进行合并操作,现在在并查集中,find(y) 的结果一定是 find(x) 的结果。现在需要维护子树大小。yy 所对应的子树大小不变,xx 以及 xx 的父亲所对应的子树大小需要加上 siz[y]siz[y]。但是由于树还没有建完,所以从 find(x) 的父亲开始,需要减去 siz[y]siz[y]
  • 查询。当前 yy 的子树大小是一部分,整个集合中剩下的是一部分,两者的乘积就是答案。
查看代码
#include <bits/stdc++.h>
#define lowbit(x) (x & -x)
using namespace std;

struct Question {
    char op;
    int x, y;
} a[100005];

int n, q;
vector<int> G[100005];
int dep[100005], siz[100005], f[100005];
int dfn[100005], top[100005], son[100005], num = 0;

struct UnionFind_Set {
    int f[100005];
    void init(void) { for (int i = 1; i <= n; ++i) f[i] = i; }
    int find(int x) {
        if (f[x] == x) return x;
        return f[x] = find(f[x]);
    }
} S;

void dfs1(int x, int fa) {
    dep[x] = dep[fa] + 1; siz[x] = 1; f[x] = fa;
    int max_part = -1;
    for (int y : G[x])
        if (y != fa) {
            dfs1(y, x);
            siz[x] += siz[y];
            if (siz[y] > max_part) {
                son[x] = y;
                max_part = siz[y];
            }
        }
}
void dfs2(int x, int topf) {
    dfn[x] = ++num; top[x] = topf;
    if (son[x] == -1) return;
    dfs2(son[x], topf);
    for (int y : G[x])
        if (y != f[x] && y != son[x]) dfs2(y, y);
}

int C[100005];
void add(int x, int k) {
    if (x == 0) return;
    for (; x <= n; x += lowbit(x))
        C[x] += k;
}
int query(int x) {
    int res = 0;
    for (; x >= 1; x -= lowbit(x)) res += C[x];
    return res;
}

void update(int x, int k) {
    while (x >= 1) {
        // 修改:top[x] -> x
        add(dfn[top[x]], k);
        add(dfn[x] + 1, -k); // 因为是差分,减的地方要加上 1
        x = f[top[x]];
    }
}

int main(void) {
    scanf("%d%d", &n, &q);
    for (int i = 1; i <= q; ++i) {
        a[i].op = getchar();
        while (a[i].op != 'A' && a[i].op != 'Q') a[i].op = getchar();
        scanf("%d %d", &a[i].x, &a[i].y);
        if (a[i].op == 'A') {
            G[a[i].x].emplace_back(a[i].y);
            G[a[i].y].emplace_back(a[i].x);
        }
    }
    memset(son, -1, sizeof(son));
    for (int i = 1; i <= n; ++i)
        if (!dfn[i]) {
            dfs1(i, 0);
            dfs2(i, i);
        }
    for (int i = 1; i <= n; ++i) {
        update(i, 1); 
        update(f[i], -1);
    }
    S.init();
    for (int i = 1; i <= q; ++i) {
        int op = a[i].op, x = a[i].x, y = a[i].y;
        if (dep[x] > dep[y]) swap(x, y); // x 是 y 的祖先
        uint sy = query(dfn[y]);
        if (op == 'A') {
            S.f[y] = S.find(x);
            update(x, sy);
            update(f[S.f[x]], -sy);
        } else {
            uint s = query(dfn[S.find(x)]);
            printf("%u\n", sy * (s - sy));
        }
    }
    return 0;
}

[LNOI2014] LCA

Portal.

给出一个 nn 个节点的有根树(编号为 00n1n-1,根节点为 00)。有 mm 次询问,每次询问给出 l r zl\ r\ z,求 i=lrdep[LCA(i,z)]\sum_{i=l}^r dep[LCA(i,z)]1n,m5×1041\le n,m\le 5\times 10^4

对于一个询问来说,所有的 LCA 都在 zz 到根节点的路径上,深度代表的含义是当前点到根节点的点数。把 xx 到根的路径上的点全部 +1+1,求 yy 到根的路径的权值,就是 LCA(x,y)LCA(x,y) 的深度。于是我们采用差分的方式求解询问,将询问离线,依次将点加入系统进行维护。

查看代码
#include <bits/stdc++.h>
using namespace std;
const int MOD = 201314;

int n, m;
int f[50005], dep[50005], siz[50005];
int dfn[50005], son[50005], top[50005], num = 0;
vector<int> G[50005];

void dfs1(int x) {
    dep[x] = dep[f[x]] + 1; siz[x] = 1;
    int max_part = -1;
    for (int y : G[x]) {
        dfs1(y); siz[x] += siz[y];
        if (siz[y] > max_part) max_part = siz[y], son[x] = y;
    }
}
void dfs2(int x, int topf) {
    dfn[x] = ++num; top[x] = topf;
    if (son[x] == -1) return; 
    dfs2(son[x], topf);
    for (int y : G[x]) 
        if (y != son[x]) dfs2(y, y);
}

int T[200005], tag[200005];
inline void pushdown(int o, int l, int r) {
    if (!tag[o]) return;
    int mid = l + r >> 1;
    tag[o << 1] += tag[o], tag[o << 1 | 1] += tag[o];
    T[o << 1] += tag[o] * (mid - l + 1), T[o << 1 | 1] += tag[o] * (r - mid);
    tag[o] = 0;
}
void update(int o, int l, int r, int x, int y) {
    if (x <= l && r <= y) return tag[o] += 1, T[o] += r - l + 1, void();
    int mid = l + r >> 1; pushdown(o, l, r);
    if (x <= mid) update(o << 1, l, mid, x, y);
    if (mid < y) update(o << 1 | 1, mid + 1, r, x, y);
    T[o] = (T[o << 1] + T[o << 1 | 1]) % MOD;
}
int query(int o, int l, int r, int x, int y) {
    if (x <= l && r <= y) return T[o];
    int mid = l + r >> 1, res = 0; pushdown(o, l, r);
    if (x <= mid) res = (res + query(o << 1, l, mid, x, y)) % MOD;
    if (mid < y) res = (res + query(o << 1 | 1, mid + 1, r, x, y)) % MOD;
    return res;
}

void update(int x, int y) {
    while (top[x] != top[y]) {
        if (dep[top[x]] < dep[top[y]]) swap(x, y);
        update(1, 1, n, dfn[top[x]], dfn[x]);
        x = f[top[x]];
    }
    if (dep[x] > dep[y]) swap(x, y);
    update(1, 1, n, dfn[x], dfn[y]);
}
int query(int x, int y) {
    int ans = 0;
    while (top[x] != top[y]) {
        if (dep[top[x]] < dep[top[y]]) swap(x, y);
        ans = (ans + query(1, 1, n, dfn[top[x]], dfn[x])) % MOD;
        x = f[top[x]];
    }
    if (dep[x] > dep[y]) swap(x, y);
    ans = (ans + query(1, 1, n, dfn[x], dfn[y])) % MOD;
    return ans;
}

struct Question {
    int id, r, z, flag;
    Question(int id = 0, int r = 0, int z = 0, int flag = 0) :
        id(id), r(r), z(z), flag(flag) {}
    bool operator < (const Question &a) const {
        return r < a.r;
    }
} Q[100005];
int ans[50005];

int main(void) {
    memset(son, 0xff, sizeof(son)); 
    scanf("%d%d", &n, &m);
    for (int i = 2; i <= n; ++i) { 
        scanf("%d", f + i); ++f[i]; 
        G[f[i]].emplace_back(i); 
    } 
    dfs1(1); dfs2(1, 1);
    for (int i = 1; i <= m; ++i) {
        int l, r, z; scanf("%d%d%d", &l, &r, &z); ++l, ++r, ++z;
        Q[(i << 1) - 1] = Question(i, l - 1, z, -1);
        Q[(i << 1)] = Question(i, r, z, 1);
    }
    sort(Q + 1, Q + m * 2 + 1);
    int r = 0;
    for (int i = 1; i <= m * 2; ++i) {
        while (r < Q[i].r) update(1, ++r);
        ans[Q[i].id] += Q[i].flag * query(1, Q[i].z);
    }
    for (int i = 1; i <= m; ++i) printf("%d\n", (ans[i] + MOD) % MOD);
    return 0;
}

[Ynoi2017] 由乃的 OJ

Portal.

回想当初的贪心是怎么做的,于是我们用线段树维护 0,10,1 经过这些运算之后会变成什么,要维护 6464 个,因此可以压进一个 unsigned long long 进行维护。由于树剖的特性,所以正反都需要维护。合并的时候很好处理,可以讨论一下当前位什么时候是 11

查看代码
#include <bits/stdc++.h>
using namespace std;
typedef unsigned long long u64; 
const u64 INF = 0 - 1; 

int n, m, k, op[100005]; 
u64 val[100005]; 
int f[100005], siz[100005], dep[100005];
int top[100005], dfn[100005], idx[100005], num, son[100005];
vector<int> G[100005]; 

void dfs1(int x, int fa) {
    dep[x] = dep[f[x] = fa] + 1; siz[x] = 1; 
    for (int y : G[x]) if (y != fa) {
        dfs1(y, x); siz[x] += siz[y]; 
        if (siz[y] > siz[son[x]]) son[x] = y; 
    }
}
void dfs2(int x, int topf) {
    idx[dfn[x] = ++num] = x; top[x] = topf; 
    if (son[x]) dfs2(son[x], topf);
    for (int y : G[x]) if (y != f[x] && y != son[x]) dfs2(y, y);  
}

u64 calc(u64 v, int x) {
    if (op[x] == 1) return v & val[x]; 
    if (op[x] == 2) return v | val[x]; 
    return v ^ val[x]; 
}
struct Node {
    u64 a0, a1, b0, b1; 
    Node() : a0(0), a1(0), b0(0), b1(0) {}
    // a: 0/1 从左到右
    // b: 0/1 从右到左
} T[400005];
inline Node operator+ (const Node &a, const Node &b) {
    Node c; 
    c.a0 = ((a.a0 & b.a1) | (~a.a0 & b.a0));
    c.a1 = ((a.a1 & b.a1) | (~a.a1 & b.a0)); 
    c.b0 = ((b.b0 & a.b1) | (~b.b0 & a.b0));
    c.b1 = ((b.b1 & a.b1) | (~b.b1 & a.b0)); 
    return c; 
}
void build(int o, int l, int r) {
    if (l == r) {
        T[o].a0 = T[o].b0 = calc(0, idx[l]); 
        T[o].a1 = T[o].b1 = calc(INF, idx[l]); 
        return; 
    } int mid = l + r >> 1; 
    build(o << 1, l, mid); build(o << 1 | 1, mid + 1, r); 
    T[o] = T[o << 1] + T[o << 1 | 1]; 
}
void update(int o, int l, int r, int x) {
    if (l == r) {
        T[o].a0 = T[o].b0 = calc(0, idx[l]); 
        T[o].a1 = T[o].b1 = calc(INF, idx[l]); 
        return;
    } int mid = l + r >> 1; 
    if (x <= mid) update(o << 1, l, mid, x); 
    else update(o << 1 | 1, mid + 1, r, x); 
    T[o] = T[o << 1] + T[o << 1 | 1]; 
}
Node query(int o, int l, int r, int x, int y) {
    if (x <= l && r <= y) return T[o];
    int mid = l + r >> 1; 
    if (y <= mid) return query(o << 1, l, mid, x, y);
    if (mid < x) return query(o << 1 | 1, mid + 1, r, x, y);
    return query(o << 1, l, mid, x, y) + query(o << 1 | 1, mid + 1, r, x, y);
}

Node ans1[100005], ans2[100005]; 
int tot1, tot2; 
Node query(int x, int y) {
    tot1 = tot2 = 0; 
    while (top[x] != top[y]) {
        if (dep[top[x]] >= dep[top[y]]) {
            ans1[++tot1] = query(1, 1, n, dfn[top[x]], dfn[x]); 
            x = f[top[x]]; 
        } else {
            ans2[++tot2] = query(1, 1, n, dfn[top[y]], dfn[y]); 
            y = f[top[y]]; 
        }
    }
    if (dep[x] > dep[y]) ans1[++tot1] = query(1, 1, n, dfn[y], dfn[x]); 
    else ans2[++tot2] = query(1, 1, n, dfn[x], dfn[y]); 
    for (int i = 1; i <= tot1; ++i) swap(ans1[i].a0, ans1[i].b0), swap(ans1[i].a1, ans1[i].b1);
    Node ans; 
    if (tot1) {
        ans = ans1[1]; 
        for (int i = 2; i <= tot1; ++i) ans = ans + ans1[i];
        if (tot2) ans = ans + ans2[tot2]; 
    } else ans = ans2[tot2]; 
    for (int i = tot2 - 1; i >= 1; --i) ans = ans + ans2[i]; 
    return ans; 
}

int main(void) {
    scanf("%d%d%d", &n, &m, &k); 
    for (int i = 1; i <= n; ++i) scanf("%d%llu", op + i, val + i); 
    for (int i = 1; i < n; ++i) {
        int u, v; scanf("%d%d", &u, &v); 
        G[u].emplace_back(v); G[v].emplace_back(u); 
    } dfs1(1, 0); dfs2(1, 1); build(1, 1, n); 
    while (m--) {
        int opt, x, y; u64 z; scanf("%d%d%d%llu", &opt, &x, &y, &z); 
        if (opt == 1) {
            u64 ans = 0; Node t = query(x, y); 
            for (int i = 63; i >= 0; --i) {
                u64 t0 = (t.a0 >> i) & 1ull; 
                u64 t1 = (t.a1 >> i) & 1ull; 
                if ((1ull << i) > z || t0 >= t1) ans |= (t0 ? (1ull << i) : 0);
                else ans |= (t1 ? (1ull << i) : 0), z -= (1ull << i);
            }
            printf("%llu\n", ans); 
        } else {
            op[x] = y; val[x] = z; 
            update(1, 1, n, dfn[x]); 
        }
    }
    return 0;
}

[CF1017G] The Tree

Portal.

维护一棵树:

  • 1 x:如果 xx 为白色,那么将其染黑,否则对这个节点的儿子进行递归操作;
  • 2 x:将 xx 子树上的所有节点染成白色。
  • 3 x:查询 xx 的颜色。

关键问题是,这个 1 操作是什么鬼?如果节点 yy 会被 1 x 影响到,那说明 xyx\sim y 中除了 yy 的节点都被染黑了。

简单树上乱搞

一些有趣的题。

[Ynoi Easy Round 2021] TEST_68

Portal.

发现很多点的答案应该是一样的。我们将所有点加入 01-Trie,找出一个异或值最大的点对 (p,q)(p,q),只有 (1,p),(1,q)(1,p),(1,q) 这两条链上的答案可能与最大答案不同。考虑由 11 开始遍历链(因为 11 的限制是最严的,而两条链分别遍历一次),不断解放树上的节点,将它们加入 Trie,找出异或的最大值。

时间复杂度 O(nlogV)O(n\log V)

查看代码
#include <bits/stdc++.h>
using namespace std;
typedef long long i64; 
const int N = 500005; 

i64 a[500005];

int val[N * 60], ch[N * 60][2], tot; 
void insert(int id) {
    int x = 0; 
    for (int i = 59; i >= 0; --i) {
        int c = a[id] >> i & 1; 
        if (!ch[x][c]) ch[x][c] = ++tot; 
        x = ch[x][c]; 
    }
    val[x] = id;
}
int query(int id) { // 返回异或值最大的节点编号
    int x = 0; 
    for (int i = 59; i >= 0; --i) {
        int c = a[id] >> i & 1; 
        if (!ch[x][c ^ 1]) x = ch[x][c]; 
        else x = ch[x][c ^ 1];
    }
    return val[x]; 
}

int n, p, q; 
int f[500005], dep[500005], son[2];
bool v[500005]; 
i64 Ans[500005], ans, s[500005]; 
vector<int> G[500005]; 

int LCA(int x, int y) {
    if (dep[x] < dep[y]) swap(x, y); 
    v[x] = v[y] = 1; 
    while (dep[x] > dep[y]) v[x = f[x]] = 1; 
    if (x == y) return x;
    while (x != y) v[x = f[x]] = 1, v[y = f[y]] = 1; 
    return x; 
}

void dfs2(int x) {
    insert(x); s[x] = max(s[x], a[x] ^ a[query(x)]); 
    for (int y : G[x]) dfs2(y), s[x] = max(s[x], s[y]); 
}

void dfs1(int x, int type) {
    Ans[x] = max(Ans[x], s[f[x]]); s[x] = s[f[x]];
    insert(x); s[x] = max(s[x], a[x] ^ a[query(x)]);
    for (int y : G[x]) if (!v[y] || y == son[type]) dfs2(y), s[x] = max(s[x], s[y]);
    for (int y : G[x]) if (v[y] && y != son[type]) dfs1(y, type);
}

int main(void) {
    scanf("%d", &n); dep[1] = 1; 
    for (int i = 2; i <= n; ++i) {
        scanf("%d", f + i), G[f[i]].emplace_back(i); 
        dep[i] = dep[f[i]] + 1; 
    }
    for (int i = 1; i <= n; ++i) {
        scanf("%lld", a + i); insert(i); int x = query(i);
        if ((a[i] ^ a[x]) > ans) ans = a[p = i] ^ a[q = x];
    }
    
    int lca = LCA(p, q), tmp = lca;
    for (int x : G[lca]) if (v[x]) son[1] = son[0], son[0] = x; 
    while (tmp) v[tmp = f[tmp]] = 1;

    memset(val, 0, sizeof val); memset(ch, 0, sizeof ch); tot = 0; dfs1(1, 0);
    memset(s, 0, sizeof s); 
    memset(val, 0, sizeof val); memset(ch, 0, sizeof ch); tot = 0; dfs1(1, 1);

    for (int i = 1; i <= n; ++i) printf("%lld\n", v[i] ? Ans[i] : ans); 
    return 0;
}

[CF19E] Fairy

Portal.

本来就是二分图的可以随便删。对于非二分图,只能删掉被所有奇环覆盖的边,而且不能被偶环覆盖。找一棵搜索树,然后检查所有返祖边即可(因为搜索树不存在横叉边)。

查看代码
#include <bits/stdc++.h>
using namespace std;

int n, m, cnt, sp, s[10005]; 
vector<pair<int, int> > G[10005]; 
vector<int> ans; 
bool vis[10005], dis[10005], pas[10005]; 

void dfs(int x) {
    vis[x] = 1; 
    for (auto [y, i] : G[x]) if (!vis[y]) {
        dis[y] = dis[x] ^ 1; pas[i] = 1; 
        dfs(y); 
    } else if (!pas[i]) {
        pas[i] = 1; 
        if (dis[y] == dis[x]) {
            ++cnt; 
            sp = i; 
            ++s[x], --s[y]; 
        } else --s[x], ++s[y]; 
    }
}
void dfs2(int x) {
    vis[x] = 1; 
    for (auto [y, i] : G[x]) if (!vis[y]) {
        dfs2(y); 
        if (s[y] == cnt) ans.emplace_back(i); 
        s[x] += s[y]; 
    }
}

int main(void) {
    ios::sync_with_stdio(0); 
    cin >> n >> m; 
    for (int i = 1, u, v; i <= m; ++i) {
        cin >> u >> v; 
        G[u].emplace_back(v, i); 
        G[v].emplace_back(u, i); 
    }
    for (int i = 1; i <= n; ++i) if (!vis[i]) dfs(i); 
    if (cnt == 0) {
        cout << m << "\n"; 
        for (int i = 1; i <= m; ++i) cout << i << " "; 
        return cout << "\n", 0; 
    }
    if (cnt == 1) ans.emplace_back(sp); 
    for (int i = 1; i <= n; ++i) vis[i] = 0; 
    for (int i = 1; i <= n; ++i) if (!vis[i]) dfs2(i); 

    sort(ans.begin(), ans.end()); 
    cout << ans.size() << "\n"; 
    for (int x : ans) cout << x << " "; 
    return cout << "\n", 0;
}

【XR-4】 复读

Portal.

枚举一个轮回能够到达的点,将其所有需要构造的子树合并。

查看代码
#include <bits/stdc++.h>
using namespace std;

int n, m, ans = 1e9, pos, dx;  
struct Node {
    int ls, rs; 
} T[2005], T2[2005]; 

int get(void) {
    int c = getchar() - '0', x = ++n; 
    if (c & 1) T[x].ls = get(); 
    if (c & 2) T[x].rs = get(); 
    return x; 
}

void dfs2(int x, int y) {
    if (x == pos || y == dx) dx = y, y = 1; 
    if (T[x].ls) {
        if (!T2[y].ls) T2[y].ls = ++m; 
        dfs2(T[x].ls, T2[y].ls); 
    }
    if (T[x].rs) {
        if (!T2[y].rs) T2[y].rs = ++m; 
        dfs2(T[x].rs, T2[y].rs); 
    }
}

void dfs(int x, int dep) {
    m = 1; memset(T2, 0, sizeof T2); 
    pos = x; dx = 0; dfs2(1, 1); 
    ans = min(ans, (m - 1) * 2 - dep + 1); 
    if (T[x].ls) dfs(T[x].ls, dep + 1); 
    if (T[x].rs) dfs(T[x].rs, dep + 1); 
}

int main(void) {
    get(); dfs(1, 1); 
    return !printf("%d\n", ans);
}

评论

若无法加载,请尝试刷新,欢迎讨论、交流和提出意见,支持 Markdown 与 LaTeX 语法(公式与文字间必须有空格)!