抱歉,您的浏览器无法访问本站
本页面需要浏览器支持(启用)JavaScript
了解详情 >

线段树(Segment Tree)是一种二叉搜索树,1977 年由 Jon Louis Bentley 发明,可以较为灵活且效率较高地解决信息可合并的序列维护问题。而树状数组可以维护序列的前缀和。

更新日志

2023/7/14

开始大规模地更改代码,重构文章。

2023/6/30

补充没有理解透彻的内容,增加的部分内容和习题,删除了冗余的习题。

树状数组

又称 Fenwick 树、二叉索引树(BIT)。支持维护前缀后缀的信息。

概述

树状数组将序列拆分成了恰好 nn 个区间,对于每一个前缀求解都可以拆成 logp\log p 个区间进行求解,而且自带一个卡不掉的 1/21/2 的常数,随机数据下则为 1/41/4 的常数!我们通过 lowbit\operatorname{lowbit} 来支持树状数组的工作。

一个显式的树状数组
一个显式的树状数组

模板,区间和我们可以用前缀和相减来求解,代码如下:

#include <bits/stdc++.h>
using namespace std;
typedef long long i64; 

int n, m; 
int a[500005]; i64 C[500005]; 

void add(int x, int k) { for (; x <= n; x += x & -x) C[x] += k; }
i64 sum(int x) { i64 r = 0; for (; x; x -= x & -x) r += C[x]; return r; }

int main(void) {
    scanf("%d%d", &n, &m); 
    for (int i = 1; i <= n; ++i) scanf("%d", a + i), add(i, a[i]); 
    while (m--) {
        int op, x, y; scanf("%d%d%d", &op, &x, &y); 
        if (op == 1) add(x, y); 
        else printf("%lld\n", sum(y) - sum(x - 1)); 
    }
    return 0;
}

树状数组自身也有许多漂亮的操作,虽然效率上略微胜于线段树和平衡树,但是可扩展性和直观程度上却不如它们。下面我们来看一些必须掌握的。

线性建树

对于树状数组上的每个节点都向上传递,具体过程如下:

for (int i = 1; i <= n; ++i) {
    int x; cin >> x; C[i] += x; 
    if (i + lowbit(i) <= n) C[i + lowbit(i)] += C[i]; 
}

差分与前缀和

树状数组可以轻松维护序列的高阶前缀和,首先将原序列差分可以直接解决区间加单点查询

这里直接给出方法。对于 kk 阶前缀和,写出 (yx+k1k1)\dbinom{y-x+k-1}{k-1} 的多项式形式,然后 yy 表示的是下标,xx 表示的是当前位置的值。时间复杂度 O(kqlogn)O(kq\log n)

权值树状数组

构建原序列的权值数列,然后利用树状数组统计。下面的代码可以快速解决逆序对问题

查看代码
#include <bits/stdc++.h>
#define lowbit(x) (x & (-(x)))

using namespace std;
typedef long long i64;

inline int read(void) {
    int x = 0, c = getchar_unlocked(), f = 1;
    while (!isdigit(c)) {if (c == '-') f = -1; c = getchar_unlocked();}
    while (isdigit(c)) x = (x<<3) + (x<<1) + (c^48), c = getchar_unlocked();
    return x * f;
}

int n, m, a[500005], b[500005];
int C[500005];

void update(int x, int k) {
    while (x <= n) {
        C[x] += k;
        x += lowbit(x);
    }
}

int query(int x) {
    i64 res = 0;
    while (x) {
        res += C[x];
        x -= lowbit(x);
    }
    return res;
}

int main(void) {
    n = read();
    for (int i = 1; i <= n; ++i)
        b[i] = a[i] = read();
    sort(b + 1, b + n + 1);
    m = unique(b + 1, b + n + 1) - (b + 1);
    for (int i = 1; i <= n; ++i)
        a[i] = lower_bound(b + 1, b + m + 1, a[i]) - b;
    i64 ans = 0;
    for (int i = n; i >= 1; --i) {
        ans += query(a[i] - 1);
        update(a[i], 1);
    }
    printf("%lld\n", ans);
    return 0;
}

权值数组也可以实现名次树,但是当强制在线时就寄掉了。但是这引出了一个重要 trick:树状数组倍增。

树状数组二分与倍增

我们当然可以使用二分套树状数组达到 O(nlog2n)O(n\log^2 n) 的复杂度,然而有没有更好的方式适配树状数组这种结构呢?有!倍增!

查询一个权值树状数组里的 kk 小值。

我们从二进制高位到低位枚举,时间复杂度为 O(nlogn)O(n\log n)

// 权值树状数组查询第 k 小
int kth(int k) {
    int sum = 0, x = 0;
    for (int i = 17; i >= 0; --i) { // 需满足 sum < k
        x += 1 << i; // 尝试扩展
        if (x >= n || sum + C[x] >= k)  x -= 1 << i; // x 不在树状数组范围内,或扩展失败
        else sum += C[x];
    }
    return x + 1;
}

简介线段树

“线段树”只是 Segment Tree 的一种称法,因为线段树可以理解为是由很多线段组成的,其它叫法包括区间树(interval tree)、范围树等等。但这些称法一般用于特殊领域(如计算几何),本文均用线段树来代表 Segment Tree。

线段树是一种基于分治思想的二叉树结构,有如下特征:

  • 线段树的每一个节点都代表一个区间。
  • 线段树具有唯一的根节点,代表统计范围,一般为 [1,n][1,n]
  • 线段树的每个叶子节点长度都为 11,形如 [x,x][x,x]
  • 一般我们定义,若 mid=(l+r)÷2mid=\lfloor(l+r)\div2\rfloor,那么节点 [l,r][l,r] 的左子节点是 [l,mid][l,mid],右子节点是 [mid+1,r][mid+1,r]
一棵线段树
一棵线段树

对于上图这棵维护区间 [1,4][1,4] 的线段树而言,可以发现,一个节点的左子节点是它的编号乘 22,右子节点是乘 2211。我们可以利用这点方便地来存储线段树。

但凑巧的是,这颗线段树是满二叉树,其它情况类似于这种:

另一棵线段树
另一棵线段树

其实去掉最后一层这树仍是满二叉树,这种情况依然可以使用上述方法存树。

线段树的存储

线段树的正常写法是堆式线段树。

其实就是只用一个数组 TT 存储线段树,用堆的编号来表示线段树的左儿子和右儿子(lc = o << 1, rc = o << 1 | 1),不过进行操作的时候要多传两个数据 llrr

注意,线段树的节点必须开四倍空间!否则如果遇到非满二叉树的线段树,二倍空间就会爆炸!

在网上你能看到这样一种堆式线段树:

struct N {
    int l, r;
    int val;
} T[4*MAXN];

注意,它是记录了当前节点 oo区间 [l,r][l,r],在传参时可以省掉两个参数(听不懂?那就不管,往下看就行)。
有时候要维护的信息特别复杂,我们会将数组 TT 的类型改为结构体,但还是不会使用记录区间的方式。

一般我们使用堆式线段树中的数组方式,而不记录左右儿子(不记慢不了多少)。接下来若不是特殊情况,我们均使用这种方式。

线段树的建树

接下来我们谈谈如何建树,我们再来看这棵线段树:

嘻嘻,还是我
嘻嘻,还是我

最后一层若当作原序列的值,即 [i,i][i,i] 保存 AiA_i 的值。由于线段树是二叉树结构,可以很方便地从下往上传递信息。以区间和为例,令节点 [l,r][l,r] 表示 i=lrAi\sum\limits_{i=l}^{r}A_i,显然 [l,r]=[l,mid]+[mid+1,r][l,r]=[l,mid]+[mid+1,r](这里的区间代表区间所对应的值)。

比如原序列是 1 2 3 4,那么对于节点 11~77,可推算出它们的值分别为 10 3 7 1 2 3 4

那么建树的代码大概就像这样:

#define L(x) ((x)<<1)
#define R(x) ((x)<<1|1)

// 当然,你也可以不用宏定义。

int T[4*N];

inline void maintain(int o) {
    T[o] = T[o << 1] + T[o << 1 | 1];
    // 从下往上传递信息。事实上你也可以写在需要调用 maintain 函数的函数里,不过有时传递的信息较为复杂,还是建议写一个 maintain 函数。网上有的教程把它写作 pushup,至于为什么,接下来你了解到 pushdown 就知道了。
}

void build(int o, int l, int r) { //o 代表当前维护结点的标号,l 和 r 代表所对应的区间
    if (l == r) return T[o] = a[l], void(); //如果这是叶子节点,赋值
    int mid = l + r >> 1; // 计算中值
    build(o << 1, l, mid); // 为左半段建树
    build(o << 1 | 1, mid + 1, r); // 为右半段建树
    maintain(o); // 计算当前结点的值
}

mainbuild(1, 1, n) 来调用 build
由于每个节点只访问了一次,所以建树的时间复杂度为 O(n)\mathcal{O}(n)

点修改与区间查询

模板

点修改

还是以这棵线段树为例:

我又来啦
我又来啦

根据刚才的数据,初始化后它应该长这样:

初始化后的线段树
初始化后的线段树

我们先来进行点修改,比如要给原序列的第 22 个元素加上 11,那么这棵线段树会怎么变化呢?
可以发现,节点 442211 都会加上 11。线段树就变成了这个样子:

点修改后的线段树
点修改后的线段树

那代码怎么实现呢?一般来讲,根节点 11 总是线段树执行的入口,从根节点出发,递归找到需要修改的叶子节点,这里代码如下:

void update(int o, int l, int r, int x, int k) { //给原序列第 x 个元素加上 k。
    if (l == r) return T[o] += k, void(); // 这是叶子节点,直接加
    int mid = l + r >> 1;
    if (x <= mid) update(o << 1, l, mid, x, k); // 叶子节点在 [l,mid] 处。
    else update(o << 1 | 1, mid+1, r, x, k); // 叶子节点在 [mid+1,r] 处。
    maintain(o); //重新计算这个节点的值。
}

由于线段树的层数在 log\log 级别,所以点修改的时间复杂度为 O(logn)\mathcal{O}(\log n)

区间查询

查询区间 [l,r][l,r] 的和,从根节点开始,递归执行下列过程:

  1. 若当前区间 [l,r][l,r] 完全覆盖了需要求解的范围,那么直接返回答案。
  2. 若当前区间与左子节点有重叠,访问左子节点 [l,mid][l,mid]
  3. 若当前区间与右子节点有重叠,访问右子节点 [mid+1,r][mid+1,r](注意不是访问左子节点后就不用访问右子节点了)。

那怎么解释这个东西呢?还是看那棵线段树 (它的出镜率为什么这么高)

嗯,又是我
嗯,又是我

比如现在我们要查 [1,3][1,3]
[1,3][1,3][1,2][1,2][3,4][3,4] 都有重叠,所以我们要分别访问。
[1,2][1,2] 完全覆盖,直接返回。
[3,4][3,4] 左子节点有覆盖,右子节点没有,访问左子节点。
[3,3][3,3] 直接返回。
所以答案是 4+3=74+3=7

那么代码就长这样:

int query(int o, int l, int r, int ql, int qr) { //[ql,qr] 是要查的区间
    if (ql <= l && r <= qr) return T[o]; //完全包含
    int mid = l + r >> 1, res = 0;
    // 接下来,只要你在(哪怕只有一个元素),我就查
    if (ql <= mid) res += query(o << 1, l, mid, ql, qr); //左子节点
    if (mid < qr) res += query(o << 1 | 1, mid+1, r, ql, qr); //右子节点
    return res;
}

updatequery 的时间复杂度也是 O(logn)\mathcal{O}(\log n)

以下代码就可以通过刚才的模板了。

查看代码
#include <bits/stdc++.h>
using namespace std;

inline int read(void) {
    int x = 0, c = getchar(), f = 1;
    while (!isdigit(c)) {if (c == '-') f = -1; c = getchar();}
    while (isdigit(c)) x = (x<<3) + (x<<1) + (c^48), c = getchar();
    return x * f;
}

int n, m;
int T[2000005];
int a[500005];

void build(int o, int l, int r) {
    if (l == r) return T[o] = a[l], void();
    int mid = l + r >> 1;
    build(o << 1, l, mid);
    build(o << 1 | 1, mid+1, r);
    T[o] = T[o << 1] + T[o << 1 | 1];
}

void update(int o, int l, int r, int x, int k) {
    if (l == r) return T[o] += k, void();
    int mid = l + r >> 1;
    if (x <= mid) update(o << 1, l, mid, x, k);
    else update(o << 1 | 1, mid+1, r, x, k);
    T[o] = T[o << 1] + T[o << 1 | 1];
}

int query(int o, int l, int r, int ql, int qr) {
    if (ql <= l && r <= qr) return T[o];
    int mid = l + r >> 1, res = 0;
    if (ql <= mid) res += query(o << 1, l, mid, ql, qr);
    if (mid < qr) res += query(o << 1 | 1, mid+1, r, ql, qr);
    return res;
}

int main(void) {
    n = read(), m = read();
    for (int i = 1; i <= n; ++i) a[i] = read();
    build(1, 1, n);
    while (m--) {
        int op = read(), x = read(), y = read();
        if (op == 1) update(1, 1, n, x, y);
        else printf("%d\n", query(1, 1, n, x, y));
    }
    return 0;
}

Problemset

在讨论区间修改之前,我们先看几道线段树的题目。

[Luogu P4513] 小白逛公园

Portal.

最大子段和可以使用 O(nlogn)\mathcal{O}(n \log n) 的分治法进行求解,因为这个子段要么在序列的左半段,要么在右半段,要么跨越中点。加上多组询问,这就是线段树嘛!

最大和的子段在中点两端好说,现在就来看一下跨越中点的情况。
线段树的每个节点维护三个值:最大子段和、最大前缀和、最大后缀和所对应的区间(此区间是线段树的节点所对应的区间)。那么最大子段和跨越中点时,就是前半区间的最大后缀和,加上后半区间的最大前缀和。

维护四个信息:区间和 sumsum,仅靠左端的最大连续和 lmaxlmax,靠右段的 rmaxrmax,以及区间最大子段和 datdat

query 的时候,我们需要看它是否完全在左区间还是完全在右区间,都不是就是跨区间,需要根据左右节点的查询结果计算当前答案。

查看代码
#include <iostream>
#include <cstdio>
#include <algorithm>

using namespace std;

int n, m;
int a[500005];

struct Node {
    int sum, lmax, rmax, dat;
    Node (int sum = 0, int lmax = 0, int rmax = 0, int dat = 0) :
        sum(sum), lmax(lmax), rmax(rmax), dat(dat) {}
}T[2000005];

inline void maintain(int o) {
    int ls = o << 1, rs = o << 1 | 1;
    T[o].sum = T[ls].sum + T[rs].sum;
    T[o].lmax = max(T[ls].lmax, T[ls].sum + T[rs].lmax);
    T[o].rmax = max(T[rs].rmax, T[rs].sum + T[ls].rmax);
    T[o].dat = max({T[ls].dat, T[rs].dat, T[ls].rmax + T[rs].lmax});
}

void build(int o, int l, int r)
{
    if (l == r)
    {
        T[o].sum = T[o].lmax = T[o].rmax = T[o].dat = a[l];
        return;
    }
    int mid = l + r >> 1, ls = o << 1, rs = o << 1 | 1;
    build(ls, l, mid);
    build(rs, mid + 1, r);
    maintain(o);
}

void update(int o, int l, int r, int x, int k)
{
    if (l == r)
    {
        T[o].sum = T[o].lmax = T[o].rmax = T[o].dat = k;
        return;
    }
    int mid = l + r >> 1;
    if (x <= mid) update(o << 1, l, mid, x, k);
    else update(o << 1 | 1, mid + 1, r, x, k);
    maintain(o);
}

Node query(int o, int l, int r, int ql, int qr)
{
    if (ql <= l && r <= qr) return T[o];
    int mid = l + r >> 1, ls = o << 1, rs = o << 1 | 1;
    if (qr <= mid) return query(ls, l, mid, ql, qr);
    if (ql > mid) return query(rs, mid + 1, r, ql, qr);
    Node x = query(ls, l, mid, ql, qr), y = query(rs, mid + 1, r, ql, qr), res;
    res.sum = x.sum + y.sum;
    res.lmax = max(x.lmax, x.sum + y.lmax);
    res.rmax = max(y.rmax, y.sum + x.rmax);
    res.dat = max({x.dat, y.dat, x.rmax + y.lmax});
    return res;
}

int main(void)
{
    scanf("%d%d", &n, &m);
    for (int i = 1; i <= n; ++i) scanf("%d", a + i);
    build(1, 1, n);
    while (m--)
    {
        int k, x, y;
        scanf("%d%d%d", &k, &x, &y);
        if (k == 1) 
        {
            if (x > y) swap(x, y);
            printf("%d\n", query(1, 1, n, x, y).dat);
        }
        else update(1, 1, n, x, y);
    }
    return 0;
}

[UVa 1400] “Ray, Pass me the dishes!”

Portal.

给定一个序列和多组询问 (l,r)(l,r),查询区间 [l,r][l,r] 的最大子段和,并给出答案对应的字典序最小的子区间

这回要求输出答案的区间了(),但是也没有什么好怕的。我们只需要记录一个 max_sub 来记录区间。

首先是建树,像这样:

void build(int o, int l, int r)
{
    if (l == r)
    {
        maxsub[o] = make_pair(l, r);
        maxpre[o] = l;
        maxsuf[o] = r;
        return;
    }
    // 以上显然

    int mid = l + r >> 1, lc = o << 1, rc = o << 1 | 1;
    build(lc, l, mid);
    build(rc, mid+1, r);

    // maintain
}

如何维护这一节点呢?根据刚才所说,对应三种情况:

maxsub[o] = better(maxsub[lc], maxsub[rc]); // 左右区间
maxsub[o] = better(maxsub[o], make_pair(maxsuf[lc], maxpre[rc])); // 跨越中点

其中 better 函数用于比较哪个子区间更好。

maxsufmaxpre 怎么维护呢?以 maxpre 为例子,像这样:

LL v1 = sum(l, maxpre[lc]);
LL v2 = sum(l, maxpre[rc]);
if (v1 == v2) maxpre[o] = min(maxpre[lc], maxpre[rc]); // 右端点肯定是越小越好的
else maxpre[o] = v1 > v2 ? maxpre[lc] : maxpre[rc];

其中 sum 指原序列的区间和,容易用前缀和求解。

接下来是查询,大概像这样:

Interval query(int o, int l, int r, int ql, int qr) {
    if (ql <= l && r <= qr) return maxsub[o]; // 在区间范围内
    int mid = l + r >> 1, lc = o << 1, rc = o << 1 | 1;
    if (qr <= mid) return query(lc, l, mid, ql, qr); // 完全在左半端
    if (ql > mid) return query(rc, mid+1, r, ql, qr); // 完全在右半段
    Interval ans = better(query(lc, l, mid, ql, qr), query(rc, mid+1, r, ql, qr)); // 不跨越中点
    return better(ans, make_pair(calc_suf(lc, l, mid, ql).L, calc_pre(rc, mid+1, r, qr).R)); // 跨越中点
}

注意求解前缀和后缀的函数,这里的写法完全符合刚才的定义,这里给出 calc_pre 的实现,calc_suf 的实现大致相同。

Interval calc_pre(int o, int l, int r, int qr)
{
    if (maxpre[o] <= qr) return make_pair(l, maxpre[o]);  // 完全在查询范围内
    int mid = l + r >> 1, lc = o << 1, rc = o << 1 | 1;
    if (qr <= mid) return calc_pre(lc, l, mid, qr); // 在左半端
    // 注意它要么完全在左半段要么跨越中点,不可能全在右半段
    Interval ans = make_pair(l, calc_pre(rc, mid+1, r, qr).R); // 跨越中点
    return better(ans, make_pair(l, maxpre[lc])); // 与完全在左半段比较
}

下面是完整代码:

查看代码
#include <iostream>
#include <cstdio>

#define Interval pair<int, int>
#define L first
#define R second
#define LL long long

using namespace std;

inline int read(void)
{
    int x = 0, c = getchar(), f = 1;
    while (!isdigit(c)) {if (c == '-') f = -1; c = getchar();}
    while (isdigit(c)) x = (x<<3) + (x<<1) + (c^48), c = getchar();
    return x * f;
}

int n, m;
int a[500005];
Interval maxsub[2000005];
LL prefix_sum[500005];
int maxpre[2000005], maxsuf[2000005];

inline LL sum(int l, int r) {return prefix_sum[r] - prefix_sum[l-1];}
inline LL sum(Interval x) {return sum(x.L, x.R);}
inline Interval better(Interval a, Interval b)
{
    if (sum(a) != sum(b)) return sum(a) > sum(b) ? a : b;
    return a < b ? a : b;
}

void build(int o, int l, int r)
{
    if (l == r)
    {
        maxsub[o] = make_pair(l, r);
        maxpre[o] = l;
        maxsuf[o] = r;
        return;
    }

    int mid = l + r >> 1, lc = o << 1, rc = o << 1 | 1;
    build(lc, l, mid);
    build(rc, mid+1, r);

    LL v1 = sum(l, maxpre[lc]);
    LL v2 = sum(l, maxpre[rc]);
    if (v1 == v2) maxpre[o] = min(maxpre[lc], maxpre[rc]);
    else maxpre[o] = v1 > v2 ? maxpre[lc] : maxpre[rc];

    v1 = sum(maxsuf[lc], r);
    v2 = sum(maxsuf[rc], r);
    if (v1 == v2) maxsuf[o] = min(maxsuf[lc], maxsuf[rc]);
    else maxsuf[o] = v1 > v2 ? maxsuf[lc] : maxsuf[rc];

    maxsub[o] = better(maxsub[lc], maxsub[rc]);
    maxsub[o] = better(maxsub[o], make_pair(maxsuf[lc], maxpre[rc]));
}

Interval calc_pre(int o, int l, int r, int qr)
{
    if (maxpre[o] <= qr) return make_pair(l, maxpre[o]);
    int mid = l + r >> 1, lc = o << 1, rc = o << 1 | 1;
    if (qr <= mid) return calc_pre(lc, l, mid, qr);
    Interval ans = make_pair(l, calc_pre(rc, mid+1, r, qr).R);
    return better(ans, make_pair(l, maxpre[lc]));
}

Interval calc_suf(int o, int l, int r, int ql)
{
    if (maxsuf[o] >= ql) return make_pair(maxsuf[o], r);
    int mid = l + r >> 1, lc = o << 1, rc = o << 1 | 1;
    if (ql > mid) return calc_suf(rc, mid+1, r, ql);
    Interval ans = make_pair(calc_suf(lc, l, mid, ql).L, r);
    return better(ans, make_pair(maxsuf[rc], r));
}

Interval query(int o, int l, int r, int ql, int qr)
{
    if (ql <= l && r <= qr) return maxsub[o];
    int mid = l + r >> 1, lc = o << 1, rc = o << 1 | 1;
    if (qr <= mid) return query(lc, l, mid, ql, qr);
    if (ql > mid) return query(rc, mid+1, r, ql, qr);
    Interval ans = better(query(lc, l, mid, ql, qr), query(rc, mid+1, r, ql, qr));
    return better(ans, make_pair(calc_suf(lc, l, mid, ql).L, calc_pre(rc, mid+1, r, qr).R));
}

int main(void)
{
    int kase = 0;
    while (scanf("%d%d", &n, &m) == 2)
    {
        printf("Case %d:\n", ++kase);
        for (int i = 1; i <= n; ++i) 
        {
            a[i] = read();
            prefix_sum[i] = prefix_sum[i-1] + a[i];
        }
        build(1, 1, n);
        while (m--)
        {
            int l = read(), r = read();
            Interval ans = query(1, 1, n, l, r);
            printf("%d %d\n", ans.L, ans.R);
        }
    }
    return 0;
}

根据以上可以发现,线段树可以维护的是容易按照区间进行划分和合并,这一点又称满足区间可加性。关于这一点,接下来还会详细叙述。

区间 GCD

Portal

这个问题看上去很棘手,怎么办呢?强烈建议读者停下来自行思考——想一想 gcd\gcd 的性质,利用在树状数组学过的内容将原问题转换为可以用点修改实现的。还有一点可以发现:gcd\gcd 满足区间可加性,可以通过小区间的 gcd\gcd 求出大区间的 gcd\gcd

请读者先自行撕烤,然后再看解答。

查看解答

根据 gcd(x,y)=gcd(x,yx)\gcd(x,y)=\gcd(x,y-x),而且还有 gcd(x,y,z)=gcd(x,yx,zy)\gcd(x,y,z)=\gcd(x,y-x,z-y),这是什么?差分序列!那么我们可以用支持单点修改的线段树来解决这个问题,这样的话,Q l r 相当于求 gcd(a[l], query(1, 1, n, l + 1, r)AA 数组的值可以用一个支持“区间修改,单点查询”的树状数组实现。线段树修改时,需要进行两次单点修改。

你可能会问一个问题,负数怎么办?实际上 gcd\gcd 的性质对负数同样成立,但是你的输出总不能是负的吧,所以我们在输出时 abs 一下就好。注意由于有负数,所以 gcd 的代码实现要更改(因为取模运算有坑)。

查看代码
#include <iostream>
#include <cstdio>

using namespace std;
using i64 = long long;

inline i64 read(void) {
    i64 x = 0;
    int c = getchar(), f = 1;
    while (!isdigit(c)) {if (c == '-') f = -1; c = getchar();}
    while (isdigit(c)) x = (x<<3) + (x<<1) + (c^48), c = getchar();
    return x * f;
}

i64 gcd(i64 a, i64 b) {
    if (b == 0) return a;
    return gcd(b, (a % b + b) % b);
}

int n, m;
i64 a[500010], b[500010];

class FenwickTree {
    private:
        i64 C[500010];
        #define lowbit(x) (x & -x)
    public:
        inline void add(int x, i64 k) {
            while (x <= n) {
                C[x] += k;
                x += lowbit(x);
            }
        }
        inline i64 sum(int x) {
            i64 res = 0;
            while (x) {
                res += C[x];
                x -= lowbit(x);
            }
            return res;
        }
}F;

class SegmentTree
{
    private:
        i64 T[2000050];
        inline void maintain(int o) {
            T[o] = gcd(T[o << 1], T[o << 1 | 1]);
        }
    public:
        void build(int o, int l, int r) {
            if (l == r) return T[o] = b[l], void();
            int mid = l + r >> 1;
            build(o << 1, l, mid);
            build(o << 1 | 1, mid + 1, r);
            maintain(o);
        }
        void update(int o, int l, int r, int x, i64 k) {
            if (l == r) {
                T[o] += k;
                return;
            }
            int mid = l + r >> 1;
            if (x <= mid) update(o << 1, l, mid, x, k);
            else update(o << 1 | 1, mid + 1, r, x, k);
            maintain(o);
        }
        i64 query(int o, int l, int r, int ql, int qr) {
            if (ql <= l && r <= qr) return T[o];
            int mid = l + r >> 1;
            i64 res = 0;
            if (ql <= mid) res = gcd(res, query(o << 1, l, mid, ql, qr)); 
            if (mid < qr) res = gcd(res, query(o << 1 | 1, mid + 1, r, ql, qr));
            return res;
        }
} S;

int main(void)
{
    n = read(), m = read();
    for (int i = 1; i <= n; ++i) {
        a[i] = read();
        b[i] = a[i] - a[i-1];
    }
    S.build(1, 1, n);
    // 树状数组不建树,到时候直接加上 a[l] 即可
    while (m--) {
        char c;
        cin >> c;
        int l = read(), r = read();
        if (c == 'C') {
            i64 k = read();
            F.add(l, k);
            S.update(1, 1, n, l, k);
            if (r < n) {
                F.add(r + 1, -k);
                S.update(1, 1, n, r + 1, -k);
            }
        }
        else printf("%lld\n", abs(gcd(a[l] + F.sum(l), l < r ? S.query(1, 1, n, l + 1, r) : 0)));
    }
    return 0;
}

区间修改与延迟标记

根据刚才的学习,可以发现线段树是个很厉害的数据结构,但它的威力可不止如此,来看,还有更厉害的:

延迟标记的介绍

[Luogu 3372]【模板】线段树 1

这回可不一样了,点修改只会影响树中的 logn\log n 个节点,而区间修改最坏情况下会影响区间中的所有节点,这可怎么办?我们这里要引入一个叫做“延迟标记”的东西(或者叫它懒标记,即 lazy tag)。

试想,如果我们在一次修改操作中发现节点 oo 代表的区间 [ol,or][o_l,o_r] 中要修改的区间 [l,r][l,r] 被完全覆盖,那么更新点 oo 的子树就是徒劳的。可以给节点 oo 做一个标记,省掉接下来的操作。就是打完标记后我们可以立即返回,此标记代表“该节点曾经被修改过,但其子节点尚未更新”。

如果在后续的指令中,需要从节点 oo 向下递归,那么我们就下传 oo 的标记,并清空 oo 的标记。

接下来我们看一下这道题该怎么写。首先建树和维护当前节点的过程没有变化,而对于修改操作需要这样写:

void update(int o, int l, int r, int x, int y, int k) //区间 [x,y] 加上 k
{
    if (x <= l && r <= y) //在区间范围内
    {
        T[o] += (LL)k * (r-l+1);
        tag[o] += k;
        return;
    }
    pushdown(o, l, r);

    int mid = l + r >> 1;
    if (x <= mid) update(o << 1, l, mid, x, y, k);
    if (mid+1 <= y) update(o << 1 | 1, mid+1, r, x, y, k);
    maintain(o);
}

这里有几点需要注意,首先是 TT 数组的修改,别忘了这是区间修改,要加上的数需要乘上区间的长度。

然后是 pushdown 函数(有的版本写成 spread),需要这样:

inline void pushdown(int o, int l, int r)
{
    if (tag[o]) //标记不是 0 才有必要下传,但非要传也非不可,慢不了多少
    {
        tag[o << 1] += tag[o];
        tag[o << 1 | 1] += tag[o];
        // 下传标记

        int mid = l + r >> 1;

        // 注意区间的长度
        T[o << 1] += (LL)tag[o] * (mid-l+1);
        T[o << 1 | 1] += (LL)tag[o] * (r-mid);

        tag[o] = 0; // 清除父亲节点的标记(因为下传了)
    }
}

需要分别修改左右儿子标记的值和数值。

最后是递归的过程,由于是区间修改,所以左右都需要判断(mid + 1 <= y 有的版本会写成 mid < y )。

注意查询的时候也需要下传标记(否则你怎么查,子节点没法计算了)。

注意到了吧?pushdown 的反义词是 pushup,所以有人把 maintain 写成 pushup

想一想,为什么以上操作都可以保证最后的结果时间复杂度是正确的呢(建议手玩)?

查询操作道理基本相同,相信大家可以自己写出来。

那么对于这道题而言:

查看代码
#include <iostream>
#include <cstdio>

#define LL long long

using namespace std;

inline int read(void) {
    int x = 0, c = getchar(), f = 1;
    while (!isdigit(c)){if (c == '-') f = -1; c = getchar();}
    while (isdigit(c)) x = (x<<3) + (x<<1) + (c^48), c = getchar();
    return x * f;
}

int n, m;
LL a[100005];
LL T[400005], tag[400005];

inline void maintain(int o) {
    T[o] = T[o << 1] + T[o << 1 | 1];
}

inline void pushdown(int o, int l, int r) {
    if (tag[o]) {
        tag[o << 1] += tag[o];
        tag[o << 1 | 1] += tag[o];
        int mid = l + r >> 1;
        T[o << 1] += (LL)tag[o] * (mid-l+1);
        T[o << 1 | 1] += (LL)tag[o] * (r-mid);
        tag[o] = 0;
    }
}

void build(int o, int l, int r) {
    if (l == r) return T[o] = a[l], void();
    int mid = l + r >> 1;
    build(o << 1, l, mid);
    build(o << 1 | 1, mid+1, r);
    maintain(o);
}

void update(int o, int l, int r, int x, int y, int k) {
    if (x <= l && r <= y) {
        T[o] += (LL)k * (r-l+1);
        tag[o] += k;
        return;
    }
    pushdown(o, l, r); int mid = l + r >> 1;
    if (x <= mid) update(o << 1, l, mid, x, y, k);
    if (mid+1 <= y) update(o << 1 | 1, mid+1, r, x, y, k);
    maintain(o);
}

LL query(int o, int l, int r, int ql, int qr) {
    if (ql <= l && r <= qr) return T[o];
    int mid = l + r >> 1; LL res = 0; pushdown(o, l, r);
    if (ql <= mid) res += query(o << 1, l, mid, ql, qr);
    if (mid < qr) res += query(o << 1 | 1, mid+1, r, ql, qr);
    return res;
}

int main(void) {
    n = read(), m = read();
    for (int i = 1; i <= n; ++i)
        a[i] = read();
    build(1, 1, n);
    while (m--) {
        int op = read();
        if (op == 1) {
            int x = read(), y = read(), k = read();
            update(1, 1, n, x, y, k);
        } else {
            int x = read(), y = read();
            printf("%lld\n", query(1, 1, n, x, y));
        }
    }
    return 0;
}

是不是有点意思了?还有更复杂的。

多组延迟标记

你以为延迟标记只能由有一组?只要你愿意,都可以整出一百组(不过好像也没有一百组)!

[UVa 11992] Fast Matrix Operations

Portal

有一个 rrcc 列的全零矩阵,矩阵不超过 2020 行,支持子矩阵加,子矩阵赋值和查询子矩阵和、最小值和最大值。

由于矩阵最多有 2020 行,所以可以每行造一棵线段树,那么本体转化为一维问题。

现在由于有两种操作,那么就有两个标记,但两个标记总得有个顺序吧!否则不乱套了!
由于先加后赋值是没有任何意义的,所以我们规定先赋值后加。

值得一提的是,对于这种要维护信息较多的线段树,建议使用结构体,否则代码会显得很乱。

不过这里笔者有点懒,未把自己的代码改成全用结构体,仅在查询时使用了结构体,请大家谅解。

查看代码
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cstring>

#define o << 1 ((o) << 1)
#define o << 1 | 1 (o << 1 | 1)

using namespace std;

struct Answer
{
    int sum, min, max;
    Answer(int s = 0, int i = 1000000002, int a = -1)
    {
        sum = s;
        min = i;
        max = a;
    }
};

inline Answer up(Answer a, Answer b)
{
    return Answer(a.sum + b.sum, min(a.min, b.min), max(a.max, b.max));
}

struct SegmentTree
{
    int sumv[1000005], minv[1000005], maxv[1000005];
    int addv[1000005], setv[1000005];

    inline void init(void)
    {
        // 没有初值,可以这样建树。
        memset(sumv, 0, sizeof(sumv));
        memset(minv, 0, sizeof(minv));
        memset(maxv, 0, sizeof(maxv));
        memset(setv, -1, sizeof(setv));
        memset(addv, 0, sizeof(addv));
    }

    inline void maintain(int o)
    {
        sumv[o] = sumv[o << 1] + sumv[o << 1 | 1];
        minv[o] = min(minv[o << 1], minv[o << 1 | 1]);
        maxv[o] = max(maxv[o << 1], maxv[o << 1 | 1]);
    }

    inline void pushdown(int o, int l, int r)
    {
        int mid = l + r >> 1;
        // 先搞 set,再搞 add
        if (setv[o] >= 0)
        {
            setv[o << 1] = setv[o << 1 | 1] = setv[o];
            addv[o << 1] = addv[o << 1 | 1] = 0; // 有 set 标记需清空 add 标记。

            sumv[o << 1] = (mid - l + 1) * setv[o];
            sumv[o << 1 | 1] = (r - mid) * setv[o];
            minv[o << 1] = minv[o << 1 | 1] = maxv[o << 1] = maxv[o << 1 | 1] = setv[o];

            setv[o] = -1;
        }
        if (addv[o] > 0)
        {
            addv[o << 1] += addv[o];
            addv[o << 1 | 1] += addv[o];
            
            sumv[o << 1] += (mid - l + 1) * addv[o];
            sumv[o << 1 | 1] += (r - mid) * addv[o];
            minv[o << 1] += addv[o];
            minv[o << 1 | 1] += addv[o];
            maxv[o << 1] += addv[o];
            maxv[o << 1 | 1] += addv[o];

            addv[o] = 0;
        }
    }

    inline void update_add(int o, int l, int r, int x, int y, int k)
    {
        if (x <= l && r <= y)
        {
            addv[o] += k;
            sumv[o] += (r - l + 1) * k;
            minv[o] += k;
            maxv[o] += k;
            return;
        }
        pushdown(o, l, r);
        int mid = l + r >> 1;
        if (x <= mid) update_add(o << 1, l, mid, x, y, k);
        if (mid + 1 <= y) update_add(o << 1 | 1, mid+1, r, x, y, k);
        maintain(o);
    }

    inline void update_set(int o, int l, int r, int x, int y, int k)
    {
        if (x <= l && r <= y)
        {
            addv[o] = 0;
            setv[o] = k;
            sumv[o] = (r - l + 1) * k;
            minv[o] = maxv[o] = k;
            return;
        }
        pushdown(o, l, r);
        int mid = l + r >> 1;
        if (x <= mid) update_set(o << 1, l, mid, x, y, k);
        if (mid + 1 <= y) update_set(o << 1 | 1, mid+1, r, x, y, k);
        maintain(o);
    }

    // 强烈不建议在这里使用全局变量计算答案,这是禁忌,会让代码很乱。
    inline Answer query(int o, int l, int r, int ql, int qr)
    {
        if (ql <= l && r <= qr) return Answer(sumv[o], minv[o], maxv[o]);
        pushdown(o, l, r);
        int mid = l + r >> 1;
        Answer res;
        if (ql <= mid) res = up(res, query(o << 1, l, mid, ql, qr));
        if (qr >= mid + 1) res = up(res, query(o << 1 | 1, mid+1, r, ql, qr));
        return res;
    }
}T[21];

inline int read(void)
{
    int x = 0, c = getchar();
    while (!isdigit(c)) c = getchar();
    while (isdigit(c)) x = (x<<3) + (x<<1) + (c^48), c = getchar();
    return x;
}

int r, c, m;

int main(void)
{
    while (scanf("%d%d%d", &r, &c, &m) == 3)
    {
        for (int i = 1; i <= r; ++i) T[i].init();
        while (m--)
        {
            int op = read();
            if (op == 1)
            {
                int x1 = read(), y1 = read(), x2 = read(), y2 = read(), v = read();
                for (int i = x1; i <= x2; ++i)
                    T[i].update_add(1, 1, c, y1, y2, v);
            }
            else if (op == 2)
            {
                int x1 = read(), y1 = read(), x2 = read(), y2 = read(), v = read();
                for (int i = x1; i <= x2; ++i)
                    T[i].update_set(1, 1, c, y1, y2, v);
            }
            else
            {
                int x1 = read(), y1 = read(), x2 = read(), y2 = read();
                int sumr = 0, minr = 1000000002, maxr = -1;
                for (int i = x1; i <= x2; ++i)
                {
                    Answer ret = T[i].query(1, 1, c, y1, y2);
                    sumr += ret.sum;
                    minr = min(minr, ret.min);
                    maxr = max(maxr, ret.max);
                }
                printf("%d %d %d\n", sumr, minr, maxr);
            }
        }
    }   
    return 0;
}

[AHOI2009] 维护序列

Portal.

区间加,区间乘,区间求和。

根据刚才的经验,要么是现加后乘,要么是先乘后加,但是都可以吗?注意,先加后乘是无法表示的,因为当乘的标记袭来后,原来的加的标记就必须变成一个分数,这就完蛋了。

实现较为简单,这里不做赘述。

查看代码
#include <iostream>
#include <cstdio>

#define LL long long
#define o << 1 ((o) << 1)
#define o << 1 | 1 (o << 1 | 1)

using namespace std;

int n, p;
int a[100005];
int T[400005];
int addv[400005], mulv[400005];

inline int read(void) {
    int x = 0, c = getchar();
    while (!isdigit(c)) c = getchar();
    while (isdigit(c)) x = (x<<3) + (x<<1) + (c^48), c = getchar();
    return x;
}

inline void maintain(int o) {
    T[o] = (T[o << 1] + T[o << 1 | 1]) % p;
}

void build(int o, int l, int r)
{
    mulv[o] = 1;
    if (l == r)
    {
        T[o] = a[l] % p;
        return;
    }
    int mid = l + r >> 1;
    build(o << 1, l, mid);
    build(o << 1 | 1, mid+1, r);
    maintain(o);
}

inline void pushdown(int o, int l, int r)
{
    int mid = l + r >> 1;

    T[o << 1] = int(((LL)T[o << 1] * mulv[o] % p + (LL)addv[o] * (mid - l + 1) % p) % p);
    T[o << 1 | 1] = int(((LL)T[o << 1 | 1] * mulv[o] % p + (LL)addv[o] * (r - mid) % p) % p);

    mulv[o << 1] = int((LL)mulv[o << 1] * mulv[o] % p);
    mulv[o << 1 | 1] = int((LL)mulv[o << 1 | 1] * mulv[o] % p);

    addv[o << 1] = int(((LL)addv[o << 1] * mulv[o] + addv[o]) % p);
    addv[o << 1 | 1] = int(((LL)addv[o << 1 | 1] * mulv[o] + addv[o]) % p);

    mulv[o] = 1;
    addv[o] = 0;
}

void update_mul(int o, int l, int r, int x, int y, int k)
{
    if (x <= l && r <= y)
    {
        addv[o] = int(addv[o] * (LL)k % p);
        mulv[o] = int(mulv[o] * (LL)k % p);
        T[o] = int((LL)T[o] * k % p);
        return;
    }
    pushdown(o, l, r);
    int mid = l + r >> 1;
    if (x <= mid) update_mul(o << 1, l, mid, x, y, k);
    if (mid + 1 <= y) update_mul(o << 1 | 1, mid+1, r, x, y, k);
    maintain(o);
}

void update_add(int o, int l, int r, int x, int y, int k)
{
    if (x <= l && r <= y)
    {
        addv[o] = (addv[o] + k) % p;
        T[o] = int((T[o] + (LL)k * (r - l + 1)) % p);
        return;
    }
    pushdown(o, l, r);
    int mid = l + r >> 1;
    if (x <= mid) update_add(o << 1, l, mid, x, y, k);
    if (mid + 1 <= y) update_add(o << 1 | 1, mid+1, r, x, y, k);
    maintain(o);
}

int query(int o, int l, int r, int ql, int qr)
{
    if(ql <= l && r <= qr) return T[o];
    pushdown(o, l, r);
    int mid = l + r >> 1, res = 0;
    if (ql <= mid) res = (res + query(o << 1, l, mid, ql, qr)) % p;
    if (mid + 1 <= qr) res = (res + query(o << 1 | 1, mid+1, r, ql, qr)) % p;
    return res;
}

int main(void)
{
    n = read(), p = read();
    for (int i = 1; i <= n; ++i) a[i] = read();
    int m = read();
    build(1, 1, n);
    while (m--)
    {
        int op = read(), x = read(), y = read();
        if (op == 1)
        {
            int k = read();
            update_mul(1, 1, n, x, y, k);
        }
        else if (op == 2)
        {
            int k = read();
            update_add(1, 1, n, x, y, k);
        }
        else printf("%d\n", query(1, 1, n, x, y));
    }
    return 0;
}

线段树的本质

线段树能干什么呢?

区间可加性

记得之前提到的“区间可加性”吗?刚才的区间乘方操作满足这一性质吗?线段树的工作原理是将两个小区间的值合并成大区间的值。比如在最初的区间加区间查询问题中,我们可以通过 i=lmidAi+i=mid+1rAi\sum_{i=l}^{mid}A_i+\sum_{i=mid+1}^{r}A_i 来得到 i=lrAi\sum_{i=l}^{r}A_i,可以合并。

延迟标记与其它

标记是什么?它是一个“欠条”,相当于告诉线段树我在这欠了东西,继续向下递归需要 pushdown。而且能标记的东西必须可以高效更新当前节点的信息。

不下传标记查询没有办法进行?实际上标记可以永久化,就是在查询的时候累计一下标记,而且常数会小一点。但是只限于特殊的标记,比如区间加是可以做的,以最初的区间加区间查询和为例:

查看代码
#include <bits/stdc++.h>
using namespace std; 
typedef long long i64; 

int n, m, a[400005]; 
i64 T[400005], tag[400005]; 

void build(int o, int l, int r) {
    if (l == r) return T[o] = a[l], void(); 
    int mid = l + r >> 1; 
    build(o << 1, l, mid); build(o << 1 | 1, mid + 1, r); 
    T[o] = T[o << 1] + T[o << 1 | 1]; 
}
void update(int o, int l, int r, int x, int y, i64 k) {
    if (x <= l && r <= y) return tag[o] += k, T[o] += (r - l + 1) * k, void(); 
    int mid = l + r >> 1; 
    if (x <= mid) update(o << 1, l, mid, x, y, k); 
    if (mid < y) update(o << 1 | 1, mid + 1, r, x, y, k); 
    T[o] = T[o << 1] + T[o << 1 | 1] + tag[o] * (r - l + 1); 
}
i64 query(int o, int l, int r, int x, int y, i64 t) {
    if (x <= l && r <= y) return T[o] + t * (r - l + 1); 
    int mid = l + r >> 1; i64 ans = 0; t += tag[o]; 
    if (x <= mid) ans += query(o << 1, l, mid, x, y, t); 
    if (mid < y) ans += query(o << 1 | 1, mid + 1, r, x, y, t); 
    return ans; 
}

int main(void) {
    scanf("%d%d", &n, &m); 
    for (int i = 1; i <= n; ++i) scanf("%d", a + i); 
    build(1, 1, n); 
    while (m--) {
        int op, l, r, k; scanf("%d%d%d", &op, &l, &r); 
        if (op == 1) scanf("%d", &k), update(1, 1, n, l, r, k); 
        else printf("%lld\n", query(1, 1, n, l, r, 0)); 
    }
    return 0; 
}

但是区间赋值不行,因为查询的过程中无法累加,操作的先后顺序会改变结果,不知道哪个是先做的,无法维护。当标记难以下传时,可以考虑使用标记永久化。

权值线段树

对于序列 AA 构造一个序列 BB,其中 BiB_i 表示 AA 中数值 ii 出现的次数,也就是 aj=ia_j=ijj 的个数,这样的 BB 称之为 AA权值数列,对 BB 造一棵线段树就是权值线段树

主要应用于一些计数问题,和可持久化搭配有奇效。为了使以后主席树(可持久化权值线段树,应用很多)的学习更加顺利,我们这里通过一道题来谈一下代码实现:

逆序对

啊,不要问我问什么是这道题,因为它太经典了。

我们知道这道题可以用归并排序或者树状数组解决。今天我们再来用权值线段树解决它。

这种东西一般都需要先离散化。考虑枚举 jj,对于每个 jj 只需要找到在它之前有多少个大于它的 aia_i 即可,对 AAj1j-1 位建立权值线段树,每次只需要查询线段树上 [aj+1,n][a_j+1,n] 的和即可,然后修改对于权值线段树来说就是点修改。

查看代码
#include <bits/stdc++.h>
using namespace std;

int n;
int a[500005], T[2000005];

void init(void) {
    static int tmp[500005];
    for (int i = 1; i <= n; ++i)
        tmp[i] = a[i];
    sort(tmp + 1, tmp + n + 1);
    int m = unique(tmp + 1, tmp + n + 1) - (tmp + 1);
    for (int i = 1; i <= n; ++i)
        a[i] = lower_bound(tmp + 1, tmp + m + 1, a[i]) - tmp;
}

int query(int o, int l, int r, int ql, int qr) {
    if (ql <= l && r <= qr) return T[o];
    int mid = l + r >> 1, res = 0;
    if (ql <= mid) res += query(o << 1, l, mid, ql, qr);
    if (mid < qr) res += query(o << 1 | 1, mid + 1, r, ql, qr);
    return res;
}

void update(int o, int l, int r, int x) {
    if (l == r) return T[o]++, void();
    int mid = l + r >> 1;
    if (x <= mid) update(o << 1, l, mid, x);
    else update(o << 1 | 1, mid + 1, r, x);
    T[o] = T[o << 1] + T[o << 1 | 1];
}

int main(void) {
    scanf("%d", &n);
    for (int i = 1; i <= n; ++i) scanf("%d", a + i);
    init();
    long long ans = 0;
    for (int i = 1; i <= n; ++i) {
        ans += query(1, 1, n, a[i] + 1, n);
        update(1, 1, n, a[i]);
    }
    printf("%lld\n", ans);
    return 0;
}

注意,虽然上述代码是正确的,但是对这道题来说显然不是最优的,因为查询不是简单的区间查询,而是一端固定的区间。但是用更通用的方式来写显然不易出错。

动态开点线段树

通过记录左右儿子的编号,而不是使用完全二叉树的编号法则,这种方式称之为动态开点。代码大概长这样:

struct Node {
    int lc, rc; // 左右节点编号
    int dat; // 当前维护的值
}T[SIZE * 2]; // 终于只需要二倍空间啦!
int root, tot; // 根节点编号,节点个数

int newNode(void) {
    ++tot;
    T[tot].lc = T[tot].rc = T[tot].dat = 0;
    return tot;
}

void update(int o, int l, int r, int x, int k) {
    if (l == r) return T[o].dat += k, void(); 
    int mid = l + r >> 1;
    if (x <= mid) {
        if (!T[o].lc) T[o].lc = build();
        update(T[o].lc, l, mid, x, k);
    } else {
        if (!T[o].rc) T[o].rc = build();
        update(T[o].rc, mid + 1, r, x, k);
    } maintain(o);
}

int main(void) {
    tot = 0;
    root = build(); // 建树
}

线段树二分

权值线段树上是可以二分的。

[PA2015] Siano.

一片 nn 亩的土地,第 ii 亩土地的草每天会长高 aia_i 厘米。一共会进行 mm 次收割,其中第 ii 次收割在第 did_i 天,并把所有高度大于等于 bib_i 的部分全部割去。

每次收割得到的草的高度总和是多少?

首先发现一个问题,长得快的草一定长得高。那么将草的生长速度从小到大排序,每次割掉的一定是一个后缀区间。

使用线段树维护,查询时在线段树上二分(递归时看左子树是否满足,然后判断进入哪一棵子树)出最后一个大于等于 bb 的点即可。

查看代码
#include <bits/stdc++.h>
using namespace std;
typedef long long i64;

int n, m;
int a[500005];
i64 s[500005];
struct Node {
    i64 sum, setv, addv, maxx;
    Node() : setv(-1) {}
} T[2000005];

void grow(int o, int l, int r, i64 val) {
    T[o].addv += val;
    T[o].sum += (s[r] - s[l - 1]) * val;
    T[o].maxx += a[r] * val;
}
void cut(int o, int l, int r, i64 val) {
    T[o].setv = T[o].maxx = val; T[o].addv = 0;
    T[o].sum = (r - l + 1) * val;
}
void pushdown(int o, int l, int r) {
    int mid = l + r >> 1;
    if (T[o].setv != -1) {
        cut(o << 1, l, mid, T[o].setv);
        cut(o << 1 | 1, mid + 1, r, T[o].setv);
        T[o].setv = -1;
    }
    if (T[o].addv) {
        grow(o << 1, l, mid, T[o].addv);
        grow(o << 1 | 1, mid + 1, r, T[o].addv);
        T[o].addv = 0;
    }
}

i64 modify(int o, int l, int r, int x, int y, i64 val) {
    if (x > y) return 0;
    if (x <= l && r <= y) {
        i64 tmp = T[o].sum; cut(o, l, r, val);
        return tmp - T[o].sum;
    }
    i64 res = 0; int mid = l + r >> 1; pushdown(o, l, r);
    if (x <= mid) res += modify(o << 1, l, mid, x, y, val);
    if (mid < y) res += modify(o << 1 | 1, mid + 1, r, x, y, val);
    T[o].sum = T[o << 1].sum + T[o << 1 | 1].sum; 
    T[o].maxx = T[o << 1 | 1].maxx;
    return res;
}
int find(int o, int l, int r, i64 val) {
    if (l == r) return T[o].sum < val ? n + 1 : l;
    int mid = l + r >> 1; pushdown(o, l, r);
    if (T[o << 1].maxx >= val) return find(o << 1, l, mid, val);
    return find(o << 1 | 1, mid + 1, r, val);
}

int main(void) {
    scanf("%d%d", &n, &m);
    for (int i = 1; i <= n; ++i) scanf("%d", a + i); 
    sort(a + 1, a + n + 1);
    for (int i = 1; i <= n; ++i) s[i] = s[i - 1] + a[i];
    i64 last = 0;
    while (m--) {
        i64 d, b; scanf("%lld%lld", &d, &b); 
        grow(1, 1, n, d - last); last = d;
        printf("%lld\n", modify(1, 1, n, find(1, 1, n, b), n, b));
    }
    return 0;
}

线段树的分裂与合并

对于动态开点的权值线段树,它们可以进行分裂和合并的操作。

线段树合并

假设现在有两棵维护相同值域的基于动态开点实现的权值线段树,现在我们想要将它们维护的值相加。这就需要通过线段树合并来实现,从两个根节点开始同步遍历两棵线段树,也就是说,两个指针 o1,o2o_1,o_2,在实现中采用 p,qp,q,所代表的子区间是一致的。

如果两个其中之一为空,那么返回那个非空的。如果都不是空的,那么需要递归合并两棵子树,然后删去节点 qq,以 pp 作为合并的节点(维护最大值)。

int merge(int p, int q, int l, int r) {
    if (!p) return q; if (!q) return p;
    if (l == r) {
        T[p].dat += T[q].dat;
        return p;
    }
    int mid = l + r >> 1;
    T[p].lc = merge(T[p].lc, T[q].lc, l, mid);
    T[p].rc = merge(T[p].rc, T[q].rc, mid + 1, r);
    T[p].dat = max(T[T[p].lc].dat, T[T[p].rc].dat);
    return p;
}

时间复杂度与线段树的规模一致。这样将 qq 合并到 pp 之后会导致 qq 的结构被破坏,所以这样只能离线。如果实时新建节点可以做到在线,这样的空间复杂度为 O(nlogn)O(n\log n)

int merge(int p, int q, int l, int r) {
    if (p == 0 || q == 0) return p + q; 
    int o = ++tot; 
    if (l == r) {
        T[o].dat = T[p].dat + T[q].dat;
        return o; 
    } int mid = l + r >> 1; 
    T[o].ls = merge(T[p].ls, T[q].rs, l, mid); 
    T[o].rs = merge(T[p].rs, T[q].rs, mid + 1, r); 
    pushup(o); return o; 
}

模板。差分操作,对每一个节点都使用一棵动态开点权值线段树来维护信息,最后前缀和一次做线段树合并回答询问。

查看代码
#include <bits/stdc++.h>
using namespace std;
const int N = 100000; 

int n, m, dep[100005], mi[17][100005], dfn[100005], num, lg[100005]; 
int f[100005], ans[100005]; 
vector<int> G[100005]; 

inline int get(int x, int y) { return dep[x] < dep[y] ? x : y; }
void dfs(int x, int fa) {
    mi[0][dfn[x] = ++num] = fa; dep[x] = dep[fa] + 1; f[x] = fa; 
    for (int y : G[x]) if (y != fa) dfs(y, x);
}
int LCA(int u, int v) {
    if (u == v) return u; 
    if ((u = dfn[u]) > (v = dfn[v])) swap(u, v);
    int k = lg[v - u];
    return get(mi[k][u + 1], mi[k][v - (1 << k) + 1]);
}

struct Node {
    int ls, rs;
    int cnt, ans; 
} T[6000005];
int root[100005], tot; 
inline void pushup(int o) {
    if (T[T[o].ls].cnt >= T[T[o].rs].cnt) T[o].cnt = T[T[o].ls].cnt, T[o].ans = T[T[o].ls].ans; 
    else T[o].cnt = T[T[o].rs].cnt, T[o].ans = T[T[o].rs].ans; 
}
void update(int &o, int l, int r, int x, int k) {
    if (!o) o = ++tot; 
    if (l == r) return T[o].cnt += k, T[o].ans = x, void(); 
    int mid = l + r >> 1; 
    if (x <= mid) update(T[o].ls, l, mid, x, k);
    else update(T[o].rs, mid + 1, r, x, k);
    pushup(o);
}
int merge(int p, int q, int l, int r) {
    if (!p || !q) return p + q; 
    if (l == r) return T[p].cnt += T[q].cnt, p; 
    int mid = l + r >> 1; 
    T[p].ls = merge(T[p].ls, T[q].ls, l, mid); 
    T[p].rs = merge(T[p].rs, T[q].rs, mid + 1, r);
    pushup(p); return p; 
}
void calc(int x, int fa) {
    for (int y : G[x]) if (y != fa) {
        calc(y, x); 
        root[x] = merge(root[x], root[y], 1, N);
    }
    ans[x] = T[root[x]].ans; 
    if (!T[root[x]].cnt) ans[x] = 0; 
}

int main(void) {
    scanf("%d%d", &n, &m);
    for (int i = 2; i <= n; ++i) lg[i] = lg[i >> 1] + 1; 
    for (int i = 1; i < n; ++i) {
        int u, v; scanf("%d%d", &u, &v); 
        G[u].emplace_back(v); G[v].emplace_back(u);
    } dfs(1, 0);
    for (int i = 1; i <= lg[n]; ++i)
        for (int j = 1; j + (1 << i) - 1 <= n; ++j)
            mi[i][j] = get(mi[i - 1][j], mi[i - 1][j + (1 << i - 1)]);
    while (m--) {
        int x, y, z; scanf("%d%d%d", &x, &y, &z);
        int d = LCA(x, y);
        update(root[x], 1, N, z, 1); update(root[y], 1, N, z, 1);
        update(root[d], 1, N, z, -1); update(root[f[d]], 1, N, z, -1);
    } calc(1, 0);
    for (int i = 1; i <= n; ++i) printf("%d\n", ans[i]);
    return 0;
}

线段树分裂

是将一个可重集前 kk 小的数之后的数分成两个集合,这样线段树就会分裂成两棵线段树。

可以仿照 FHQ-Treap 的思路,我们可以实现 O(logn)O(\log n) 的线段树分裂。

代填坑。

Problemset

感觉内容很多?的确如此,基础数据结构可以解决很多问题,下面是一些经典题。

简单问题

主要如何拆分或变形要处理的内容,使得更容易维护。以及如何合并简单的标记。

[Luogu P1438] 无聊的数列

Portal.

区间加等差数列,单点查询。

等差数列看作一个整体当成标记的话非常难维护,因为首项一直在改变。对于一次操作,可以拆成对区间的 k-d*l(为当前的 ll)和 +d*i(为当前下标)。这个 +d*? 的操作只需要开一个标记,然后再查询的时候乘上当前的 ll 就可以了。代码如下:

查看代码
#include <bits/stdc++.h>
using namespace std;
using i64 = long long;

int n, m;
int a[100005];
i64 tagk[400005], tagd[400005];

inline void pushdown(int o, int l, int r)
{
    int mid = l + r >> 1;
    tagk[o << 1] += tagk[o], tagk[o << 1 | 1] += tagk[o];
    tagd[o << 1] += tagd[o], tagd[o << 1 | 1] += tagd[o];
    tagk[o] = tagd[o] = 0;
}

void update(int o, int l, int r, int x, int y, int k, int d)
{
    if (x <= l && r <= y)
    {
        tagk[o] += k;
        tagd[o] += d;
        return;
    }
    pushdown(o, l, r);
    int mid = l + r >> 1;
    if (x <= mid) update(o << 1, l, mid, x, y, k, d);
    if (mid + 1 <= y) update(o << 1 | 1, mid + 1, r, x, y, k, d);
}

i64 query(int o, int l, int r, int p)
{
    if (l == r) return tagk[o] + l * tagd[o];
    int mid = l + r >> 1;
    pushdown(o, l, r);
    if (p <= mid) return query(o << 1, l, mid, p);
    return query(o << 1 | 1, mid + 1, r, p);
}

int main(void)
{
    scanf("%d%d", &n, &m);
    for (int i = 1; i <= n; ++i) scanf("%d", &a[i]);
    while (m--)
    {
        int opt; scanf("%d", &opt);
        if (opt == 1)
        {
            int l, r; i64 k, d;
            scanf("%d%d%lld%lld", &l, &r, &k, &d);
            update(1, 1, n, l, r, k - d * l, d);
        }
        else 
        {
            int p; scanf("%d", &p);
            printf("%lld\n", query(1, 1, n, p) + a[p]);
        }
    }
    return 0;
}

为什么是单点查询?因为查询的时候每个下标是变化的,+d*?? 一直在变化,只能做单点。

[Luogu P6327] 区间加区间 sin 和

Portal.

高中课本介绍了三角函数的和差角公式:

sin(α+β)=sinαcosβ+cosαsinβcos(α+β)=cosαcosβsinαsinβ\sin(\alpha+\beta)=\sin \alpha \cos \beta + \cos\alpha\sin\beta\\ \cos(\alpha+\beta)=\cos\alpha\cos\beta - \sin\alpha\sin\beta

维护两个量 SinCos,记录一个标记 tagupdatepushdown 的时候用公式维护加上标记的值的三角函数值即可。

查看代码
#include <bits/stdc++.h>
using namespace std;
typedef long long i64; 

int read(void) {
    int x = 0, c = getchar_unlocked();
    while (!isdigit(c)) c = getchar_unlocked();
    while (isdigit(c)) x = x * 10 + c - '0', c = getchar_unlocked();
    return x;
}

int n, m;
int a[200005];
i64 tag[800005];
double Sin[800005], Cos[800005];

inline void maintain(int o) {
    Sin[o] = Sin[o << 1] + Sin[o << 1 | 1];
    Cos[o] = Cos[o << 1] + Cos[o << 1 | 1];
}

inline void maintain(int o, double sinx, double cosx) {
    double sina = Sin[o], cosa = Cos[o];
    Sin[o] = sina * cosx + cosa * sinx;
    Cos[o] = cosa * cosx - sina * sinx;
}

void build(int o, int l, int r) {
    if (l == r) {
        Sin[o] = sin(a[l]);
        Cos[o] = cos(a[l]);
        return;
    }
    int mid = l + r >> 1;
    build(o << 1, l, mid);
    build(o << 1 | 1, mid + 1, r);
    maintain(o);
}

inline void pushdown(int o) {
    if (!tag[o]) return;
    double sinx = sin(tag[o]), cosx = cos(tag[o]);
    maintain(o << 1, sinx, cosx);
    maintain(o << 1 | 1, sinx, cosx);
    tag[o << 1] += tag[o];
    tag[o << 1 | 1] += tag[o];
    tag[o] = 0;
}

int k;
double sink, cosk;
void update(int o, int l, int r, int x, int y) {
    if (x <= l && r <= y) {
        maintain(o, sink, cosk);
        tag[o] += k;
        return;
    }
    pushdown(o);
    int mid = l + r >> 1;
    if (x <= mid) update(o << 1, l, mid, x, y);
    if (mid < y) update(o << 1 | 1, mid + 1, r, x, y);
    maintain(o);
}

double query(int o, int l, int r, int ql, int qr) {
    if (ql <= l && r <= qr) return Sin[o];
    pushdown(o);
    double res = 0;
    int mid = l + r >> 1;
    if (ql <= mid) res += query(o << 1, l, mid, ql, qr);
    if (mid < qr) res += query(o << 1 | 1, mid + 1, r, ql, qr);
    return res;
}

int main(void) {
    n = read();
    for (int i = 1; i <= n; ++i) a[i] = read();
    build(1, 1, n);
    m = read();
    while (m--) {
        int op = read();
        if (op == 1) {
            int l = read(), r = read();
            k = read();
            sink = sin(k), cosk = cos(k);
            update(1, 1, n, l, r);
        } else {
            int l = read(), r = read();
            printf("%.1lf\n", query(1, 1, n, l, r));
        }
    }
    return 0;
}

「Wdsr-2.7」文文的摄影布置

Portal.

观察条件 Ai+Akmin{Bj},i<j<kA_i+A_k-\min\{B_j\},i<j<k,我们在线段树的节点中维护 AA 的最大值和 BB 的最小值,以及区间答案 ansans

现在难就难在满足线段树的“区间可加性”,也就是如何从左右儿子合并出当前节点的答案。min{Bj}\min\{B_j\} 可以当成一个值,就是区间 BB 最小值。肯定可以三个数全从左子节点或右子节点过来,也可以两个数从一个节点过来,一个数从另一个节点过来。这样的话,我们记 lmaxlmax 代表 AiBjA_i-B_j 的最大值,rmaxrmax 代表 AkBjA_k-B_j 的最大值。这两个可以简单维护,要么从左右节点单独过来,要么两个下标在不同的区间,而且由于 i<j<ki<j<k,所以顺序一定。这样 ansans 就要么是左子节点的 lmaxlmax 和右子节点的 amaxamax 合并过来,要么是从右子节点的 rmaxrmax 和左子节点的 lmaxlmax 合并过来。

那么 lmaxlmaxrmaxrmax 呢?大致同理,要么都在叶子节点,要么跨区间,跨区间的时候就是通过维护的 A,BA,B 值来计算即可。

查看代码
#include <bits/stdc++.h>
using namespace std;
const int INF = 1e9;

int n, m;
int a[500005], b[500005];

struct Node {
    int amax, bmin;
    int lmax, rmax, ans;
} T[2000005];

inline Node merge(Node a, Node b) {
    Node ans;
    ans.amax = max(a.amax, b.amax);
    ans.bmin = min(a.bmin, b.bmin);
    ans.lmax = max({a.lmax, b.lmax, a.amax - b.bmin});
    ans.rmax = max({a.rmax, b.rmax, b.amax - a.bmin});
    ans.ans = max({a.ans, b.ans, a.amax + b.rmax, a.lmax + b.amax});
    return ans;
}

void build(int o, int l, int r) {
    T[o].lmax = T[o].rmax = T[o].ans = -INF; // 初始什么都没有,是负无穷
    if (l == r) return T[o].amax = a[l], T[o].bmin = b[l], void();
    int mid = l + r >> 1, ls = o << 1, rs = ls | 1;
    build(ls, l, mid); build(rs, mid + 1, r);
    T[o] = merge(T[ls], T[rs]);
}

void update(int o, int l, int r, int x, int k) {
    if (l == r) return T[o].amax = a[l], T[o].bmin = b[l], void();
    int mid = l + r >> 1, ls = o << 1, rs = ls | 1;
    if (x <= mid) update(ls, l, mid, x, k);
    else update(rs, mid + 1, r, x, k);
    T[o] = merge(T[ls], T[rs]);
}

Node query(int o, int l, int r, int ql, int qr) {
    if (ql <= l && r <= qr) return T[o];
    int mid = l + r >> 1, ls = o << 1, rs = ls | 1;
    if (qr <= mid) return query(ls, l, mid, ql, qr);
    if (mid < ql) return query(rs, mid + 1, r, ql, qr);
    return merge(query(ls, l, mid, ql, qr), query(rs, mid + 1, r, ql, qr));
}

int main(void) {
    scanf("%d%d", &n, &m);
    for (int i = 1; i <= n; ++i) scanf("%d", a + i);
    for (int i = 1; i <= n; ++i) scanf("%d", b + i);
    build(1, 1, n);
    while (m--) {
        int op, x, y;
        scanf("%d%d%d", &op, &x, &y);
        if (op == 1) {
            a[x] = y;
            update(1, 1, n, x, y);
        } else if (op == 2) {
            b[x] = y;
            update(1, 1, n, x, y);
        } else printf("%d\n", query(1, 1, n, x, y).ans);
    }
    return 0;
}

[NOIP2016 提高组] 蚯蚓

Portal.

蚯蚓长度增加这一事我们用一个延迟标记 deltadelta 完成,然后使用三个队列模拟优先队列(因为分裂越晚的蚯蚓长度只能更短)。

查看代码
#include <bits/stdc++.h>
using namespace std;
const int INF = 1e9;

struct Queue {
    int Q[7000005], L = 1, R = 0;
    inline void push(int x) { Q[++R] = x; }
    inline void pop(void) { ++L; }
    inline int front(void) { return L <= R ? Q[L] : -INF; }
    inline bool empty(void) { return L > R; }
} A, B, C;

int n, m, q, u, v, t, delta = 0;
int a[100005], ans[7100005], tot = 0;

int main(void)
{
    scanf("%d%d%d%d%d%d", &n, &m, &q, &u, &v, &t);
    for (int i = 1; i <= n; ++i) scanf("%d", a + i);
    sort(a + 1, a + n + 1);
    for (int i = n; i >= 1; --i) A.push(a[i]);
    for (int i = 1; i <= m; ++i, delta += q) {
        int x;
        if (A.front() >= B.front() && A.front() >= C.front()) x = A.front(), A.pop();
        else if (B.front() >= A.front() && B.front() >= C.front()) x = B.front(), B.pop();
        else x = C.front(), C.pop();
        x += delta;
        if (i % t == 0) printf("%d ", x);
        int y = 1ll * u * x / v;
        B.push(y - delta - q); C.push(x - y - delta - q);
    }
    putchar('\n');
    for (int i = A.L; i <= A.R; ++i) ans[++tot] = A.Q[i];
    for (int i = B.L; i <= B.R; ++i) ans[++tot] = B.Q[i];
    for (int i = C.L; i <= C.R; ++i) ans[++tot] = C.Q[i];
    sort(ans + 1, ans + tot + 1, greater<int>());
    for (int i = 1; i <= tot; ++i)
        if (i % t == 0) printf("%d ", ans[i] + delta);
    putchar('\n');
    return 0;
}

[SDOI2009] HH 的项链

Portal.

显然同一种类只有最右面的会有用。将询问按照右端点排序,然后可以让前缀和不断向右扩展,方便查询,使用树状数组维护前缀和即可。

查看代码
#include <bits/stdc++.h>
#define lowbit(x) (x & -x)

using namespace std;

struct Question {
    int l, r, id;
    bool operator < (const Question &a) const {
        return r < a.r;
    }
} q[1000005];

int n, m, C[1000005];
int a[1000005], ans[1000005];
int last[1000005];
void add(int x, int k) {
    for (; x <= n; x += lowbit(x)) 
        C[x] += k;
}
int sum(int x) {
    int res = 0;
    for (; x; x -= lowbit(x)) res += C[x];
    return res;
}

int main(void)
{
    scanf("%d", &n);
    for (int i = 1; i <= n; ++i) scanf("%d", a + i);
    scanf("%d", &m);
    for (int i = 1; i <= m; ++i)
        scanf("%d%d", &q[i].l, &q[i].r), q[i].id = i; 
    sort(q + 1, q + m + 1);
    int r = 1;
    for (int i = 1; i <= m; ++i) {
        while (r <= q[i].r) {
            if (last[a[r]]) add(last[a[r]], -1);
            add(last[a[r]] = r, 1);
            ++r;
        }
        ans[q[i].id] = sum(q[i].r) - sum(q[i].l - 1);
    }
    for (int i = 1; i <= m; ++i)
        printf("%d\n", ans[i]);
    return 0;
}

[GZOI2017] 配对统计

Portal.

挖掘配对的性质,发现将配对的数排序后一个数的配对只能是它左边第一个或者是它右边第一个。将询问按照右端点排序,然后使用双指针加树状数组来维护当前询问的答案即可。

查看代码
#include <bits/stdc++.h>
#define lowbit(x) ((x) & (-(x)))
#define X first
#define Y second

using namespace std;
typedef long long i64;
typedef pair<int, int> pii;

int n, m, tot = 0;
pii b[600005];
bool cmp(pii a, pii b) {
    if (a.Y != b.Y) return a.Y < b.Y;
    return a.X < b.X;
}

struct Node {
    int val, pos;
    bool operator < (Node &a) const {
        return val < a.val;
    }
}a[300005];

struct Question {
    int l, r;
    int pos, ans;
    bool operator < (const Question &a) const {
        if (r != a.r) return r < a.r;
        return l < a.l;
    }
}Q[300005];

// ============ Fenwick Tree ============
int C[300005];
void update(int x) {
    while (x <= n) {
        C[x]++;
        x += lowbit(x);
    }
}
int query(int x) {
    int res = 0;
    while (x) {
        res += C[x];
        x -= lowbit(x);
    }
    return res;
}

void add(int l, int r) {
    if (l > r) swap(l, r);
    ++tot;
    b[tot].X = l, b[tot].Y = r;    
}

int main(void)
{
    scanf("%d%d", &n, &m);
    if (n == 1) return puts("0"), 0;
    for (int i = 1; i <= n; ++i)
        scanf("%d", &a[i].val), a[i].pos = i;
    sort(a + 1, a + n + 1);
    
    add(a[1].pos, a[2].pos);
    add(a[n - 1].pos, a[n].pos);
    for (int i = 2; i < n; ++i) {
        int l = a[i].val - a[i - 1].val, r = a[i + 1].val - a[i].val;
        if (l == r) add(a[i - 1].pos, a[i].pos), add(a[i].pos, a[i + 1].pos);
        else if (l < r) add(a[i - 1].pos, a[i].pos);
        else add(a[i].pos, a[i + 1].pos);
    }
    sort(b + 1, b + tot + 1, cmp);
    
    for (int i = 1; i <= m; ++i) {
        scanf("%d%d", &Q[i].l, &Q[i].r);
        Q[i].pos = i;
    }
    sort(Q + 1, Q + m + 1);
    
    i64 ans = 0;
    for (int i = 1, j = 0; i <= m; ++i) {
        while (j < tot && b[j + 1].Y <= Q[i].r) {
            ++j;
            update(b[j].X);
        }
        ans += 1ll * Q[i].pos * (j - query(Q[i].l - 1));
    }
    
    printf("%lld\n", ans);
    return 0;
}

「Wdsr-3」令人感伤的红雨

Portal.

实际上 Ω(l,r)=max{0,lA(1,r)}\Omega(l,r)=\max\{0,l-A(1,r)\},因此考虑如何维护前缀 AA。设所有“最值点”为 bb,前缀加会导致一些最值点消失,并查集维护每个位置所对应的最值点,链表维护最值点的存在情况即可。

查看代码
#include <bits/stdc++.h>
using namespace std; 

int n, m; 
int a[6000005], fa[6000005], nxt[6000005], b[6000005];

int find(int x) {
    if (fa[x] == x) return x; 
    return fa[x] = find(fa[x]); 
}

int main(void) {
    scanf("%d%d", &n, &m); memset(b, -1, sizeof b); 
    for (int i = 1; i <= n; ++i) scanf("%d", a + i), fa[i] = i, nxt[i] = n + 1; 
    for (int i = 2, p = 1; i <= n; ++i) {
        if (a[i] >= a[p]) nxt[p] = i, b[p] = a[i] - a[p], p = i; 
        else fa[i] = p; 
    }
    while (m--) {
        int op, x, y; scanf("%d%d%d", &op, &x, &y); 
        if (op == 1) {
            int t = find(x); b[t] -= y; 
            while (nxt[t] <= n && b[t] < 0) {
                b[t] += b[nxt[t]]; fa[nxt[t]] = t; 
                nxt[t] = nxt[nxt[t]]; 
            }
        } else printf("%d\n", max(0, x - find(y))); 
    }
    return 0; 
}

技巧性问题

这里是线段树的一些经典应用。

[Luogu P4145] 上帝造题的七分钟 2 / 花神游历各国

Portal.

懒标记?

如果您能提出质疑,那么笔者为您点赞。如果不能,你可能要重新去看《线段树的本质》一节(笔者要被扣工资了

要注意到的是,如果使用延迟标记,那么当前的区间和是无法维护的。因为它不像区间加区间 sin 和这种东西可以进行拆解,每个数开平方后区间的和无法简单维护。

但是区间开方这种东西,很容易就开到 11 了。然而对着 11 开方是没有用的。所以如果区间的最大值是 11,那么区间开方这种操作就没必要进行了。

我们直接使用线段树,但是不需要延迟标记,维护到叶子节点为止。如果一个区间已经比 11 小,那么就不用维护了。

查看代码
#include <bits/stdc++.h>
using namespace std;
typedef long long i64; 

int n, m;
i64 a[100005];
i64 sum[400005], maxx[400005];

inline void maintain(int o)
{
    sum[o] = sum[o << 1] + sum[o << 1 | 1];
    maxx[o] = max(maxx[o << 1], maxx[o << 1 | 1]);
}

void build(int o, int l, int r)
{
    if (l == r) return sum[o] = maxx[o] = a[l], void();
    int mid = l + r >> 1;
    build(o << 1, l, mid);
    build(o << 1 | 1, mid + 1, r);
    maintain(o);
}

void update(int o, int l, int r, int x, int y)
{
    if (l == r) // 叶子节点重新维护
    {
        sum[o] = sqrt(sum[o]);
        maxx[o] = sqrt(maxx[o]);
        return;
    }
    int mid = l + r >> 1;
    if (x <= mid && maxx[o << 1] > 1) update(o << 1, l, mid, x, y); // 最大值大于 1 才修改
    if (mid < y && maxx[o << 1 | 1] > 1) update(o << 1 | 1, mid + 1, r, x, y);
    maintain(o);
}

i64 query(int o, int l, int r, int ql, int qr)
{
    if (ql <= l && r <= qr) return sum[o];
    int mid = l + r >> 1;
    i64 res = 0;
    if (ql <= mid) res += query(o << 1, l, mid, ql, qr);
    if (mid < qr) res += query(o << 1 | 1, mid + 1, r, ql, qr);
    return res;
}

int main(void)
{
    scanf("%d", &n);
    for (int i = 1; i <= n; ++i) scanf("%lld", a + i);
    scanf("%d", &m);
    build(1, 1, n);
    while (m--)
    {
        int k, l, r;
        scanf("%d%d%d", &k, &l, &r);
        if (l > r) swap(l, r);
        if (k == 0) update(1, 1, n, l, r);
        else printf("%lld\n", query(1, 1, n, l, r));
    }
    return 0;
}

[TJOI2018] 数学计算

Portal.

这不是模拟吗(

但是不行,我们知道除法是没有随时取模性质的,就算算逆元,也没有保证互质,逆元不一定有。

注意到最多除一次,以时间建立一棵线段树,根节点维护的是当前的 xx 值。对于一个乘操作,我们就将当前询问的编号乘上 xx,对于除法,我们就将这个编号改为 11

查看代码
#include <iostream>
#include <cstdio>

#define i64 long long

using namespace std;

int Q, M;
int T[400005];

void build(int o, int l, int r) {
    T[o] = 1;
    if (l == r) return;
    int mid = l + r >> 1;
    build(o << 1, l, mid);
    build(o << 1 | 1, mid + 1, r);
}

void update(int o, int l, int r, int x, int k) {
    if (l == r) {
        T[o] = (k == 0) ? 1 : k;
        return;
    }
    int mid = l + r >> 1, ls = o << 1, rs = ls | 1;
    if (x <= mid) update(ls, l, mid, x, k);
    else update(rs, mid + 1, r, x, k);
    T[o] = (i64)T[ls] * T[rs] % M;
}

int main(void) {
    int TT;
    scanf("%d", &TT);
    while (TT--) {
        scanf("%d%d", &Q, &M);
        build(1, 1, Q);
        for (int i = 1; i <= Q; ++i) {
            int op, x;
            scanf("%d%d", &op, &x);
            if (op == 1) update(1, 1, Q, i, x);
            else update(1, 1, Q, x, 0);
            printf("%d\n", T[1] % M);
        }
    }
    return 0;
}
这种基于时间的操作非常常见,请读者一定要熟记。

[SHOI2015] 脑洞治疗仪

Portal.

对于操作二,考虑线段树上二分:需要先满足填的区间,然后从左子树开始尝试填满(见代码)。

查看代码
#include <bits/stdc++.h>
using namespace std;

struct Node {
    int sum, lmax, rmax, dat, len;
} T[800005];
int tag[800005], len[800005];

int n, m;
Node hb(const Node &a, const Node &b) {
    Node c;
    c.sum = a.sum + b.sum; c.len = a.len + b.len;
    c.lmax = (a.lmax == a.len ? a.len + b.lmax : a.lmax);
    c.rmax = (b.rmax == b.len ? b.len + a.rmax : b.rmax);
    c.dat = max({a.dat, b.dat, a.rmax + b.lmax});
    return c;
}
void build(int o, int l, int r) {
    tag[o] = -1; T[o].len = r - l + 1;
    if (l == r) return T[o].sum = 1, void();
    int mid = l + r >> 1;
    build(o << 1, l, mid); build(o << 1 | 1, mid + 1, r);
    T[o] = hb(T[o << 1], T[o << 1 | 1]);
}
void maketag(int o, int l, int r, int k) {
    T[o].sum = k * (r - l + 1);
    T[o].lmax = T[o].rmax = T[o].dat = (1 - k) * (r - l + 1);
    tag[o] = k;
}
void pushdown(int o, int l, int r) {
    if (tag[o] == -1) return;
    int mid = l + r >> 1;
    maketag(o << 1, l, mid, tag[o]);
    maketag(o << 1 | 1, mid + 1, r, tag[o]);
    tag[o] = -1;
}
void update(int o, int l, int r, int x, int y, int k) {
    if (x <= l && r <= y) return maketag(o, l, r, k);
    int mid = l + r >> 1; pushdown(o, l, r);
    if (x <= mid) update(o << 1, l, mid, x, y, k);
    if (mid < y) update(o << 1 | 1, mid + 1, r, x, y, k);
    T[o] = hb(T[o << 1], T[o << 1 | 1]);
}
Node query(int o, int l, int r, int x, int y) {
    if (x <= l && r <= y) return T[o];
    int mid = l + r >> 1; pushdown(o, l, r);
    if (y <= mid) return query(o << 1, l, mid, x, y);
    if (mid < x) return query(o << 1 | 1, mid + 1, r, x, y);
    return hb(query(o << 1, l, mid, x, y), query(o << 1 | 1, mid + 1, r, x, y));
}
int cont(int o, int l, int r, int x, int y) {
    if (x <= l && r <= y) return T[o].sum;
    int res = 0, mid = l + r >> 1; pushdown(o, l, r);
    if (x <= mid) res += cont(o << 1, l, mid, x, y);
    if (mid < y) res += cont(o << 1 | 1, mid + 1, r, x, y);
    return res;
}
int dag(int o, int l, int r, int x, int y, int k) {
    if (k == 0) return 0;
    if (x <= l && r <= y && T[o].len - T[o].sum <= k) {
        int t = T[o].len - T[o].sum;
        maketag(o, l, r, 1);
        return k - t;
    }
    pushdown(o, l, r); int ans = 0, mid = l + r >> 1;
    if (y <= mid) ans = dag(o << 1, l, mid, x, y, k);
    else if (mid < x) ans = dag(o << 1 | 1, mid + 1, r, x, y, k);
    else ans = dag(o << 1 | 1, mid + 1, r, x, y, dag(o << 1, l, mid, x, y, k));
    return T[o] = hb(T[o << 1], T[o << 1 | 1]), ans;
}

int main(void) {
    scanf("%d%d", &n, &m); build(1, 1, n);
    while (m--) {
        int op, l, r, l1, r1; scanf("%d%d%d", &op, &l, &r);
        if (op == 0) update(1, 1, n, l, r, 0);
        else if (op == 1) {
            scanf("%d%d", &l1, &r1);
            int x = cont(1, 1, n, l, r);
            if (x == 0) continue; update(1, 1, n, l, r, 0);
            dag(1, 1, n, l1, r1, x);
        } else printf("%d\n", query(1, 1, n, l, r).dat);
    }
    return 0;
}

[THUSC2015] 平方运算

Portal.

模意义下区间平方是存在循环节的,多次平方后必定会陷入循环。那么线段树直接暴力维护,提前预处理出每个数的循环,进入了循环节之后就可以开始打标记,维护一个偏移量代表循环到哪里即可。

查看代码
#include <bits/stdc++.h>
using namespace std;
typedef long long i64; 

int gcd(int x, int y) { if (y == 0) return x; return gcd(y, x % y); }
int lcm(int x, int y) { return x / gcd(x, y) * y; }

int n, m, l = 1, P, a[100005]; 
int vis[10005], p[10005]; 
bool lp[400005]; int now[400005], tag[400005]; 
i64 T[400005][60]; 

inline void chk(int o) {
    if (p[T[o][0]]) {
        for (int i = 1; i < l; ++i) T[o][i] = T[o][i - 1] * T[o][i - 1] % P; 
        lp[o] = 1; 
    }
}
inline void pushup(int o) {
    lp[o] = (lp[o << 1] && lp[o << 1 | 1]); now[o] = 0; 
    if (!lp[o]) T[o][0] = T[o << 1][now[o << 1]] + T[o << 1 | 1][now[o << 1 | 1]]; 
    else {
        int lx = now[o << 1], rx = now[o << 1 | 1]; 
        for (int i = 0; i < l; ++i) {
            T[o][i] = T[o << 1][lx] + T[o << 1 | 1][rx]; 
            lx = (lx + 1) % l, rx = (rx + 1) % l; 
        }
    }
}

void build(int o, int l, int r) {
    if (l == r) return T[o][0] = a[l], chk(o); 
    int mid = l + r >> 1; build(o << 1, l, mid); build(o << 1 | 1, mid + 1, r); 
    pushup(o); 
}
inline void maketag(int o, int k) {
    tag[o] = (tag[o] + k) % l; 
    now[o] = (now[o] + k) % l; 
}
inline void pushdown(int o) {
    if (!tag[o]) return; 
    maketag(o << 1, tag[o]); maketag(o << 1 | 1, tag[o]); 
    tag[o] = 0; 
}
void update(int o, int l, int r, int x, int y) {
    if (x <= l && r <= y && lp[o]) return maketag(o, 1); 
    if (l == r) return T[o][0] = T[o][0] * T[o][0] % P, chk(o); 
    pushdown(o); int mid = l + r >> 1; 
    if (x <= mid) update(o << 1, l, mid, x, y); 
    if (mid < y) update(o << 1 | 1, mid + 1, r, x, y); 
    pushup(o); 
}
i64 query(int o, int l, int r, int x, int y) {
    if (x <= l && r <= y) return T[o][now[o]]; 
    pushdown(o); int mid = l + r >> 1; i64 res = 0; 
    if (x <= mid) res += query(o << 1, l, mid, x, y); 
    if (mid < y) res += query(o << 1 | 1, mid + 1, r, x, y); 
    return res; 
}

void findloop(int x) {
    for (int i = 1, y = x;; y = y * y % P, ++i) 
        if (vis[y]) { p[y] = i - vis[y]; break; }
        else vis[y] = i; 
    for (int y = x; vis[y]; y = y * y % P) vis[y] = 0; 
}

int main(void) {
    scanf("%d%d%d", &n, &m, &P); 
    for (int i = 0; i < P; ++i) findloop(i); 
    for (int i = 0; i < P; ++i) if (p[i]) l = lcm(l, p[i]); 
    for (int i = 1; i <= n; ++i) scanf("%d", a + i); 
    build(1, 1, n); 
    while (m--) {
        int op, l, r; scanf("%d%d%d", &op, &l, &r); 
        if (op == 1) update(1, 1, n, l, r); 
        else printf("%lld\n", query(1, 1, n, l, r)); 
    }
    return 0; 
}

综合应用

这里是一些简单的综合题。

[HEOI2016] 排序

Portal.

给定一个 11nn 的排列,进行 mm 次操作,可以是将给定的区间升序或者降序排序。问最后第 qq 个位置上的数字。

先来考虑一个简单的问题,01 排序怎么做?维护区间 01 的数量,排序的时候直接将后面的改为 1,前面的改为 0,可以使用线段树完成。

现在考虑怎么求解原问题。如果将所有 x\ge x 的数都设置为 11<x<x 的都设置为 00,那么照样求解,如果第 qq 个位置是 11 就说明 qq 代表的数一定 x\ge x。最终二分出的结果就是答案。

查看代码
#include <bits/stdc++.h>
using namespace std;
const int MOD = 201314;

int n, m, q;
int a[100005], b[100005];
int op[100005], l[100005], r[100005]; 
int T[400005], tag[400005];

void build(int o, int l, int r) {
    tag[o] = -1;
    if (l == r) return T[o] = (b[l] == 1), void();
    int mid = l + r >> 1; 
    build(o << 1, l, mid); build(o << 1 | 1, mid + 1, r);
    T[o] = T[o << 1] + T[o << 1 | 1];
}
inline void pushdown(int o, int l, int r) {
    if (tag[o] == -1) return;
    int mid = l + r >> 1; 
    tag[o << 1] = tag[o << 1 | 1] = tag[o];
    T[o << 1] = tag[o] * (mid-l+1); T[o << 1 | 1] = tag[o] * (r-mid);
    tag[o] = -1;
}
void update(int o, int l, int r, int x, int y, int k) {
    if (x > y) return;
    if (x <= l && r <= y) return tag[o] = k, T[o] = k * (r-l+1), void();
    int mid = l + r >> 1; pushdown(o, l, r);
    if (x <= mid) update(o << 1, l, mid, x, y, k);
    if (mid < y) update(o << 1 | 1, mid + 1, r, x, y, k); 
    T[o] = T[o << 1] + T[o << 1 | 1];
}
int query(int o, int l, int r, int x, int y) {
    if (x <= l && r <= y) return T[o];
    int mid = l + r >> 1, res = 0; pushdown(o, l, r);
    if (x <= mid) res += query(o << 1, l, mid, x, y);
    if (mid < y) res += query(o << 1 | 1, mid + 1, r, x, y);
    return res;
}

bool P(int x) {
    for (int i = 1; i <= n; ++i) b[i] = (a[i] >= x ? 1 : 0);
    build(1, 1, n);
    for (int i = 1; i <= m; ++i) {
        int k = query(1, 1, n, l[i], r[i]);
        if (op[i] == 0) {
            update(1, 1, n, l[i], r[i] - k, 0);
            update(1, 1, n, r[i] - k + 1, r[i], 1);
        } else {
            update(1, 1, n, l[i], l[i] + k - 1, 1);
            update(1, 1, n, l[i] + k, r[i], 0);
        }
    }
    return query(1, 1, n, q, q) == 1;
}

int main(void) {
    scanf("%d%d", &n, &m);
    for (int i = 1; i <= n; ++i) scanf("%d", a + i);
    for (int i = 1; i <= m; ++i) scanf("%d%d%d", op + i, l + i, r + i);
    scanf("%d", &q);
    int L = 0, R = n + 1;
    while (L + 1 != R) {
        int mid = L + R >> 1;
        if (P(mid)) L = mid;
        else R = mid;
    }
    printf("%d\n", L);
    return 0;
}

[Luogu P5278] 算术天才⑨与等差数列

Portal.

发现条件非常严苛,因此可以考虑哈希之类的方法,这里不做赘述。

一段区间可以重排为等差数列,当且仅当满足(d=0d=0 先特判掉):

  • maxmin=d×(len1)\max -\min =d\times (len-1)
  • gcdi=lr1(ai+1ai)=d\gcd_{i=l}^{r-1}(a_{i+1}-a_i)=d
  • 序列中没有重复的元素。

用线段树维护即可。第三条可以使用 set、map 维护一个数最左边的出现位置,然后用线段树维护这个值的最小值,如果这个数小于 ll,那么一定没有重复元素。

查看代码
#include <bits/stdc++.h>
using namespace std;
const int INF = 0x3f3f3f3f;

int gcd(int x, int y) {
    if (y == 0) return x;
    return gcd(y, x % y);
}

int n, m;
int a[300005], c[300005], pre[300005];
unordered_map<int, set<int>> mp;

struct Node {
    int mx, mn, mx_pre;
    friend Node operator+ (const Node &a, const Node &b) {
        Node c;
        c.mx = max(a.mx, b.mx);
        c.mn = min(a.mn, b.mn);
        c.mx_pre = max(a.mx_pre, b.mx_pre);
        return c;
    }
} T[1200005];
void build(int o, int l, int r) {
    if (l == r) {
        T[o].mx = T[o].mn = a[l]; T[o].mx_pre = pre[l];
        return;
    }
    int mid = l + r >> 1;
    build(o << 1, l, mid); build(o << 1 | 1, mid + 1, r);
    T[o] = T[o << 1] + T[o << 1 | 1];
}
void update(int o, int l, int r, int x) {
    if (l == r) {
        T[o].mx = T[o].mn = a[l]; T[o].mx_pre = pre[l];
        return;
    }
    int mid = l + r >> 1;
    if (x <= mid) update(o << 1, l, mid, x);
    else update(o << 1 | 1, mid + 1, r, x);
    T[o] = T[o << 1] + T[o << 1 | 1];
}
Node query(int o, int l, int r, int x, int y) {
    if (x <= l && r <= y) return T[o];
    int mid = l + r >> 1; Node res = {-1, INF, 0};
    if (x <= mid) res = res + query(o << 1, l, mid, x, y);
    if (mid < y) res = res + query(o << 1 | 1, mid + 1, r, x, y);
    return res;
}
int tt[1200005];
void buildx(int o, int l, int r) {
    if (l == r) return tt[o] = c[l], void();
    int mid = l + r >> 1;
    buildx(o << 1, l, mid); buildx(o << 1 | 1, mid + 1, r);
    tt[o] = gcd(tt[o << 1], tt[o << 1 | 1]);
}
void updatex(int o, int l, int r, int x) {
    if (l == r) return tt[o] = c[x], void();
    int mid = l + r >> 1;
    if (x <= mid) updatex(o << 1, l, mid, x);
    else updatex(o << 1 | 1, mid + 1, r, x);
    tt[o] = gcd(tt[o << 1], tt[o << 1 | 1]);
}
int queryx(int o, int l, int r, int x, int y) {
    if (x <= l && r <= y) return tt[o];
    int mid = l + r >> 1, res = 0;
    if (x <= mid) res = gcd(res, queryx(o << 1, l, mid, x, y));
    if (mid < y) res = gcd(res, queryx(o << 1 | 1, mid + 1, r, x, y));
    return res;
}

bool solve(int l, int r, int k) {
    if (l == r) return true;
    Node t = query(1, 1, n, l, r);
    int g = queryx(1, 1, n - 1, l, r - 1);
    if (t.mx - t.mn != 1ll * k * (r - l)) return false;
    if (k && t.mx_pre >= l) return false;
    if (g != k) return false;
    return true;
}

int main(void) {
    scanf("%d%d", &n, &m);
    for (int i = 1; i <= n; ++i) {
        scanf("%d", a + i);
        if (mp[a[i]].empty()) pre[i] = -1;
        else {
            auto it = mp[a[i]].end(); --it;
            pre[i] = *it;
        }
        mp[a[i]].insert(i);
    }
    for (int i = 1; i < n; ++i) c[i] = abs(a[i + 1] - a[i]);
    build(1, 1, n); if (n - 1) buildx(1, 1, n - 1); 
    int cnt = 0, op, x, y, k;
    while (m--) {
        scanf("%d%d%d", &op, &x, &y); x ^= cnt; y ^= cnt;
        if (op == 1) {
            auto it = mp[a[x]].find(x); ++it;
            if (it != mp[a[x]].end()) pre[*it] = pre[x], update(1, 1, n, *it);
            mp[a[x]].erase(x); a[x] = y; mp[a[x]].insert(x);
            it = mp[a[x]].upper_bound(x);
            if (it != mp[a[x]].end()) pre[*it] = x, update(1, 1, n, *it);
            --it;
            if (it != mp[a[x]].begin()) --it, pre[x] = *it;
            else pre[x] = -1;
            c[x] = abs(a[x + 1] - a[x]); c[x - 1] = abs(a[x] - a[x - 1]);
            update(1, 1, n, x); if (x < n) updatex(1, 1, n - 1, x);
            if (x - 1) updatex(1, 1, n - 1, x - 1);
        } else {
            scanf("%d", &k); k ^= cnt;
            if (solve(x, y, k)) puts("Yes"), ++cnt;
            else puts("No");
        }
    }
    return 0;
}

Portal.

定义一下两种关系:

  • “补”表示与数 xx 相加为 ww
  • “等”表示与数 xx 相等。

记录每个数的补前驱,然后用线段树查询区间内补前驱的最大编号?当然可以,但是 1 5 5 5 5 5 这种修改 11 就可以直接炸掉:后面所有数的补前驱都将会变动。

令一个数的补前驱可以被记录,当且仅当它补前驱的位置在它等前驱右边,否则记录为 00。不难发现这样依次修改最多只会影响 55 个数:自身、原来 axa_x 的补后驱和等后驱、yy 的补后驱和等后驱。使用 set 加线段树维护即可。

查看代码
#include <bits/stdc++.h>
using namespace std;

int n, m, w;
int a[500005], T[2000005], pre[2000005];
set<int> s[500005];

void build(int o, int l, int r) {
    if (l == r) return T[o] = pre[l], void();
    int mid = l + r >> 1;
    build(o << 1, l, mid); build(o << 1 | 1, mid + 1, r);
    T[o] = max(T[o << 1], T[o << 1 | 1]);
}
void update(int o, int l, int r, int x) {
    if (l == r) return T[o] = pre[l], void();
    int mid = l + r >> 1;
    if (x <= mid) update(o << 1, l, mid, x);
    else update(o << 1 | 1, mid + 1, r, x);
    T[o] = max(T[o << 1], T[o << 1 | 1]);
}
int query(int o, int l, int r, int x, int y) {
    if (x <= l && r <= y) return T[o];
    int res = -1, mid = l + r >> 1;
    if (x <= mid) res = max(res, query(o << 1, l, mid, x, y));
    if (mid < y) res = max(res, query(o << 1 | 1, mid + 1, r, x, y));
    return res;
}

void calc(int x, int y) { // a[x] 改成 y,修改 x 的前驱
    auto i = s[y].find(x);
    if (i != s[y].begin()) {
        --i; // i 是 x 的等前驱
        auto j = s[w - y].lower_bound(x);
        if (j == s[w - y].begin()) pre[x] = 0;
        else {
            --j; // j 是 x 的补前驱
            if (*j >= *i) pre[x] = *j;
            else pre[x] = 0;
        }
    } else { // 没有等前驱
        auto j = s[w - y].lower_bound(x);
        if (j == s[w - y].begin()) pre[x] = 0;
        else pre[x] = *(--j);
    }
    update(1, 1, n, x);
}

int main(void) {
    scanf("%d%d%d", &n, &m, &w);
    for (int i = 1; i <= n; ++i) {
        scanf("%d", a + i);
        if (s[w - a[i]].size()) {
            int it = *(--s[w - a[i]].end());
            if (s[a[i]].empty() || it >= *(--s[a[i]].end())) pre[i] = it;
        }
        s[a[i]].insert(i);
    }
    build(1, 1, n); int cnt = 0;
    while (m--) {
        int op, x, y; scanf("%d%d%d", &op, &x, &y);
        if (op == 1) {
            auto k = s[a[x]].find(x); ++k; // k 为 a[x] 的等后驱
            auto l = s[w - a[x]].upper_bound(x); // l 为 a[x] 的补后驱
            s[a[x]].erase(x); s[y].insert(x);
            if (k != s[a[x]].end()) calc(*k, a[x]);
            if (l != s[w - a[x]].end()) calc(*l, w - a[x]);
            
            a[x] = y;
            k = s[a[x]].find(x); ++k; // k 为 y 的等后驱
            l = s[w - a[x]].upper_bound(x); // l 为 y 的补后驱
            if (k != s[a[x]].end()) calc(*k, a[x]);
            if (l != s[w - a[x]].end()) calc(*l, w - a[x]);

            calc(x, y);
        } else {
            x ^= cnt; y ^= cnt;
            if (query(1, 1, n, x, y) >= x) puts("Yes"), ++cnt;
            else puts("No");
        }
    }
    return 0;
}

[BJOI2019] 删数

Portal.

xx 的出现次数为 tt,那么其能覆盖 [xt+1,x][x-t+1,x] 的区间。答案是 [1,n][1,n] 中未被覆盖的个数。这样单点修改只会让两个数的出现次数更改,区间平移可以看作询问区间的平移,每次的移动距离也只有 11

线段树维护区间 00 的个数。由于有区间加的延迟标记,因此记录区间最小值和最小值的出现次数,可以在区间加的时候方便统计。

查看代码
#include <bits/stdc++.h>
using namespace std;
const int N = 450005;

int n, m, P = 150001; 
int a[150005], buc[N + 5]; 
int tag[N * 4 + 5], mn[N * 4 + 5], cnt[N * 4 + 5], ans[N * 4 + 5]; 

inline void pushup(int o) {
    mn[o] = min(mn[o << 1], mn[o << 1 | 1]); 
    cnt[o] = (mn[o] == mn[o << 1] ? cnt[o << 1] : 0) + (mn[o] == mn[o << 1 | 1] ? cnt[o << 1 | 1] : 0); 
    ans[o] = ans[o << 1] + ans[o << 1 | 1]; 
}
inline void maketag(int o, int k) {
    mn[o] += k; 
    ans[o] = (mn[o] == 0 ? cnt[o] : 0); 
    tag[o] += k; 
}
inline void pushdown(int o) {
    if (!tag[o]) return; 
    maketag(o << 1, tag[o]); maketag(o << 1 | 1, tag[o]);
    tag[o] = 0; 
}
void build(int o, int l, int r) {
    if (l == r) return ans[o] = cnt[o] = 1, void(); 
    int mid = l + r >> 1; 
    build(o << 1, l, mid); build(o << 1 | 1, mid + 1, r); 
    pushup(o); 
}
void update(int o, int l, int r, int x, int y, int k) {
    if (x <= l && r <= y) return maketag(o, k); 
    pushdown(o); int mid = l + r >> 1; 
    if (x <= mid) update(o << 1, l, mid, x, y, k);
    if (mid < y) update(o << 1 | 1, mid + 1, r, x, y, k);
    pushup(o); 
}
int query(int o, int l, int r, int x, int y) { // 查询 [x, y] 当中没有被覆盖的个数
    if (x <= l && r <= y) return ans[o]; 
    pushdown(o); int mid = l + r >> 1, res = 0; 
    if (x <= mid) res += query(o << 1, l, mid, x, y); 
    if (mid < y) res += query(o << 1 | 1, mid + 1, r, x, y); 
    return res; 
}
void change(int x, int c) {
    int k = x - buc[x] + 1 - (c > 0); 
    update(1, 1, N, k, k, c); 
    buc[x] += c; 
}

int main(void) {
    // 询问区间为 [1 + P, n + P]
    scanf("%d%d", &n, &m); build(1, 1, N); 
    for (int i = 1; i <= n; ++i) scanf("%d", a + i), change(a[i] += P, 1);
    while (m--) {
        int p, x; scanf("%d%d", &p, &x); 
        if (p > 0) { // 单点修改
            if (a[p] <= n + P) change(a[p], -1); 
            else --buc[a[p]]; 
            a[p] = x + P; 
            if (a[p] <= n + P) change(a[p], 1); 
            else ++buc[a[p]]; 
        } else {
            if (x > 0) { // 询问区间向左平移
                int pos = n + P; 
                if (buc[pos]) update(1, 1, N, pos - buc[pos] + 1, pos, -1); 
                --P;
            } else {
                ++P; 
                int pos = n + P; 
                if (buc[pos]) update(1, 1, N, pos - buc[pos] + 1, pos, 1); 
            }
        }
        printf("%d\n", query(1, 1, N, 1 + P, n + P)); 
    }
    return 0;
}

[GDOI2014] 吃

Portal.

可以将 [1,l)(r,n][1,l)\cup (r,n] 拆开,因此一次询问就变成了在 [1,l),[l,r][1,l),[l,r] 中各选一个数。

高效维护这个问题十分困难,发现值域很小,直接处理出所有数的因数再做考虑。离线,按照右端点升序排序。如果 preaipre_{a_i} 存在,那么询问的 ll(preai,i](pre_{a_i},i] 的范围内出现时是可以更新到 aia_i 的,一个区间修改单点查询的线段树就可以完成。时间复杂度为 O(nnlogn)O(n\sqrt{n}\log n)

查看代码
#include <bits/stdc++.h>
using namespace std;
const int N = 1e5; 

int n, m; 
int a[100005], pre[100005], nxt[100005]; 
vector<int> b[100005]; 
struct Query {
    int l, r, id; 
} Q[100005]; 
int ans[100005]; 
int T[400005], tag[400005]; 
inline void maketag(int o, int k) { 
    T[o] = max(T[o], k); tag[o] = max(tag[o], k); 
}
inline void pushdown(int o) {
    if (!tag[o]) return; 
    maketag(o << 1, tag[o]); maketag(o << 1 | 1, tag[o]); 
    tag[o] = 0; 
}
void update(int o, int l, int r, int x, int y, int k) {
    if (x <= l && r <= y) return maketag(o, k); 
    pushdown(o); int mid = l + r >> 1; 
    if (x <= mid) update(o << 1, l, mid, x, y, k); 
    if (mid < y) update(o << 1 | 1, mid + 1, r, x, y, k); 
    T[o] = max(T[o << 1], T[o << 1 | 1]); 
}
int query(int o, int l, int r, int x) {
    if (l == r) return T[o]; 
    pushdown(o); int mid = l + r >> 1; 
    if (x <= mid) return query(o << 1, l, mid, x); 
    return query(o << 1 | 1, mid + 1, r, x); 
}

int main(void) {
    for (int i = 1; i <= N; ++i) for (int j = i; j <= N; j += i) b[j].emplace_back(i); 
    scanf("%d", &n); 
    for (int i = 1; i <= n; ++i) scanf("%d", a + i); 
    scanf("%d", &m); 
    for (int i = 1; i <= m; ++i) scanf("%d%d", &Q[i].l, &Q[i].r), Q[i].id = i; 

    sort(Q + 1, Q + m + 1, [&](auto a, auto b) { return a.r < b.r; }); 
    for (int i = 1, j = 1; i <= n; ++i) {
        for (int x : b[a[i]]) {
            if (pre[x]) update(1, 1, N, pre[x] + 1, i, x); // l 在这部分时可以有答案
            pre[x] = i; 
        }
        while (j <= m && Q[j].r == i) {
            ans[Q[j].id] = max(ans[Q[j].id], query(1, 1, N, Q[j].l)); 
            ++j; 
        }
    }

    memset(T, 0, sizeof T); memset(tag, 0, sizeof tag); 
    sort(Q + 1, Q + m + 1, [&](auto a, auto b) { return a.l > b.l; }); 
    for (int i = n, j = 1; i >= 1; --i) {
        for (int x : b[a[i]]) {
            if (nxt[x]) update(1, 1, N, i, nxt[x] - 1, x); 
            nxt[x] = i; 
        }
        while (j <= m && Q[j].l == i) {
            ans[Q[j].id] = max(ans[Q[j].id], query(1, 1, N, Q[j].r)); 
            ++j; 
        }
    }

    for (int i = 1; i <= m; ++i) printf("%d\n", ans[i]); 
    return 0;
}

[CTT2012] 序列操作

Portal.

由于 cc 很小,因此直接将答案记录在线段树内。

唯一困难的是区间加,发现它所增加的贡献并不直观。比如:

[a1,ai][a1+c,ai+c][a_1,\cdots a_i]\rightarrow [a_1+c,\cdots a_i+c]

然后把他们乘起来,并展开,可以发现其实有规律的:

fi(leni0)×c0×fi+(len(i1)1)×c1×fi1+f_i\leftarrow \binom{len-i}{0}\times c^0 \times f_i+\binom{len-(i-1)}{1}\times c^1 \times f_{i-1}+\cdots

这个东西可以线性计算,那么整体就是好维护的了。

查看代码
#include <bits/stdc++.h>
#define REV 2000000000
using namespace std;
const int P = 19940417; 

int n, q, C[50005][25]; 
int a[50005]; 
int addv[200005]; bool rev[200005]; 
struct Node {
    int c[21]; 
    friend Node operator+ (const Node &a, const Node &b) {
        Node c; memset(c.c, 0, sizeof c.c); 
        for (int i = 0; i <= 20; ++i)
            for (int j = 0; i + j <= 20; ++j)
                c.c[i + j] = (c.c[i + j] + 1ll * a.c[i] * b.c[j]) % P; 
        return c; 
    }
} T[200005];

void build(int o, int l, int r) {
    if (l == r) return T[o].c[0] = 1, T[o].c[1] = a[l] % P, void();
    int mid = l + r >> 1; build(o << 1, l, mid); build(o << 1 | 1, mid + 1, r); 
    T[o] = T[o << 1] + T[o << 1 | 1]; 
}
int tmp[21]; 
inline void maketag(int o, int l, int r, int k) {
    if (k == REV) {
        rev[o] ^= 1; addv[o] = P - addv[o]; 
        for (int i = 1; i <= 20; ++i) if (i & 1) T[o].c[i] = P - T[o].c[i]; 
        return; 
    }
    addv[o] = (addv[o] + k) % P; 
    for (int i = tmp[0] = 1; i <= 20; ++i) tmp[i] = 1ll * tmp[i - 1] * k % P; 
    
    for (int i = min(r - l + 1, 20); i; --i)
        for (int j = 0; j < i; ++j)
            T[o].c[i] = (T[o].c[i] + 1ll * T[o].c[j] * tmp[i - j] % P * C[r - l + 1 - j][i - j]) % P; 
}
inline void pushdown(int o, int l, int r) {
    int mid = l + r >> 1; 
    if (rev[o]) {
        maketag(o << 1, l, mid, REV); 
        maketag(o << 1 | 1, mid + 1, r, REV); 
        rev[o] = 0; 
    }
    if (addv[o]) {
        maketag(o << 1, l, mid, addv[o]); 
        maketag(o << 1 | 1, mid + 1, r, addv[o]); 
        addv[o] = 0;  
    }
}
void update(int o, int l, int r, int x, int y, int k) {
    if (x <= l && r <= y) return maketag(o, l, r, k); 
    pushdown(o, l, r); int mid = l + r >> 1; 
    if (x <= mid) update(o << 1, l, mid, x, y, k); 
    if (mid < y) update(o << 1 | 1, mid + 1, r, x, y, k); 
    T[o] = T[o << 1] + T[o << 1 | 1]; 
}
Node query(int o, int l, int r, int x, int y) {
    if (x <= l && r <= y) return T[o];
    pushdown(o, l, r); int mid = l + r >> 1; 
    if (y <= mid) return query(o << 1, l, mid, x, y); 
    if (mid < x) return query(o << 1 | 1, mid + 1, r, x, y); 
    return query(o << 1, l, mid, x, y) + query(o << 1 | 1, mid + 1, r, x, y); 
}

int main(void) {
    scanf("%d%d", &n, &q); 
    for (int i = 0; i <= n; ++i)
        for (int j = C[i][0] = 1; j <= min(i, 20); ++j)
            C[i][j] = (C[i - 1][j] + C[i - 1][j - 1]) % P; 
    for (int i = 1; i <= n; ++i) scanf("%d", a + i); 
    build(1, 1, n); char op[5]; int l, r, k; 
    while (q--) {
        scanf("%s%d%d", op, &l, &r); if (op[0] != 'R') scanf("%d", &k); 
        if (op[0] == 'I') update(1, 1, n, l, r, k); 
        else if (op[0] == 'R') update(1, 1, n, l, r, REV); 
        else printf("%d\n", (query(1, 1, n, l, r).c[k] % P + P) % P); 
    }
    return 0; 
}

[Ynoi2015] 纵使日薄西山

Portal.

珂朵莉想让你维护一个长度为 nn 的正整数序列 a1,a2,,ana_1,a_2,\ldots,a_n,支持修改序列中某个位置的值。

每次修改后问对序列重复进行以下操作,需要进行几次操作才能使序列变为全 00(询问后序列和询问前相同,不会变为全 00):

选出序列中最大值的出现位置,若有多个最大值则选位置标号最小的一个,设位置为 xx,则将 ax1,ax,ax+1a_{x-1},a_x,a_{x+1} 的值减 11,如果序列中存在小于 00 的数,则把对应的数改为 00

1n,q1051\leq n,q\leq 10^51xin1\leq x_i\leq n1ai,yi1091\leq a_i,y_i\leq 10^9

考虑哪些数可以被减。如果我们开始减 aia_i,那么它一定会一直减下去(因为左右两个永远都比它小)。

将原序列进行单调极长划分,发现对于每个极长单调区间,答案一定是所有奇数位置或者所有偶数位置的和。使用一个 set 存储所有的极长单调区间分割点(称为极值点,令一个极值点代表极长单调区间的结束),修改一个数时最多只会影响到五个极值点(修改一个极值点可能使它右边的极值点不存在,进而影响右边第二个极值点,左边同理),复杂度可以接受。

根据此维护即可,细节很多。

查看代码
#include <bits/stdc++.h>
using namespace std;
typedef long long i64;

int n, m, a[100005];
i64 ans;
set<int> s;

struct Fenwick {
    #define lowbit(x) (x & -x)
    i64 C[100005];
    void add(int x, int k) {
        for (; x <= n; x += lowbit(x)) C[x] += k;
    }
    i64 sum(int x) {
        i64 res = 0;
        for (; x; x -= lowbit(x)) res += C[x];
        return res;
    }
} T[2];

void calc(set<int>::iterator l, set<int>::iterator r, int op) {
    for (; l != r; --r) {
        auto it = r; --it;
        if (a[*r] > a[*it]) { // 当前极长单调区间不受前一个影响
            int p = *r & 1;
            ans += (T[p].sum(*r) - T[p].sum(*it)) * op;
        } else {
            int p = *it & 1;
            ans += (T[p].sum(*r - 1) - T[p].sum(*it)) * op;
            // 要看 r 这个位置有没有被修改
            auto lt = r, rt = r;
            if (lt != s.begin()) --lt;
            ++rt; if (rt == s.end()) --rt;
            // 没有被 it 修改,没有被后面一个极长单调子区间修改
            if ((*r - *lt) % 2 == 0 && (*rt - *r) % 2 == 0)
                ans += a[*r] * op;
        }
    }
    if (a[*l] >= a[*l + 1]) return; // 此时 l 自己修改自己
    auto lt = r, rt = r;
    if (lt != s.begin()) --lt;
    ++rt; if (rt == s.end()) --rt;
    if ((*r - *lt) % 2 == 0 && (*rt - *r) % 2 == 0) ans += a[*r] * op;
}
void check(int x) {
    if ((a[x - 1] < a[x]) == (a[x] < a[x + 1])) s.erase(x);
    else s.insert(x);
}

void update(int x, int y) {
    auto it = s.lower_bound(x), l = it, r = it;
    --l; if (l != s.begin()) --l;
    ++r; if (r != s.end()) ++r; if (r == s.end()) --r;
    calc(l, r, -1);
    T[x & 1].add(x, y - a[x]); a[x] = y;
    check(x); if (x > 1) check(x - 1); if (x < n) check(x + 1);
    calc(l, r, 1);
}

int main(void) {
    scanf("%d", &n); s.insert(0); s.insert(n + 1);
    for (int i = 1, x; i <= n; ++i) scanf("%d", &x), update(i, x);
    scanf("%d", &m);
    while (m--) {
        int x, y; scanf("%d%d", &x, &y);
        update(x, y); printf("%lld\n", ans);
    }
    return 0;
}

[Code+#1] Yazid 的新生舞会

Portal.

给定一个长度为 n(n5×105)n(n\le 5\times 10^5) 的序列,问其中有多少个子区间存在出现次数严格超过子区间长度一半的众数。

考虑枚举每个种类的数分别计算,设当前选中的数为 wwSiS_i 为前 ii 个数中 ww 的个数。

对于一段区间 [l+1,r][l+1,r](方便差分),满足条件时有 SrSl>rl(SrSl)2Srr>2SllS_r-S_l>r-l-(S_r-S_l)\rightarrow 2S_r-r>2S_l-l,也就是在求 Pi=2SiiP_i=2S_i-i 的逆序对个数。

对于同一个 wwPiP_i 可以划分成若干个单调递减区间,总数在 O(n)O(n) 级别。同一个区间内是没有贡献的,只需要计算 ll 在前面区间内的贡献。

cic_i 代表 iiPP 中的出现次数(由于可能有负的,所以需要加上一个偏移量),TT 表示 cic_i 的前缀和,那么每一个 PiP_i 的贡献就是当前的 TPi1T_{P_i-1}。对于一段 [x,y][x,y],总贡献就是 i=x1y1Ti\sum\limits_{i=x-1}^{y-1}T_i,再求一个 TT 的前缀和 GG 即可。

这个东西可以使用树状数组维护,先将 cc 差分得到数组 dd(因为对于 cc 要进行区间修改),然后:

Gx=i=1xTi=i=1xj=1icj=i=1xj=1ik=1jdk=i=1x(x+2i)(x+1i)2di\begin{aligned} G_x&=\sum_{i=1}^{x} T_i\\ &=\sum_{i=1}^{x}\sum_{j=1}^{i}c_j\\ &=\sum_{i=1}^{x}\sum_{j=1}^{i}\sum_{k=1}^j d_k\\ &=\sum_{i=1}^x \frac{(x+2-i)(x+1-i)}{2} d_i \end{aligned}

就可以使用三个树状数组维护了。

查看代码
#include <bits/stdc++.h>
#define lowbit(x) (x & -x)
using namespace std;
typedef long long i64;

int n, type;
int a[500005];
vector<int> b[500005]; 

i64 C1[1000005], C2[1000005], C3[1000005]; 
i64 sum(int x) {
    i64 res = 0; 
    for (int i = x; i > 0; i -= lowbit(i)) 
        res += C1[i] * (x + 2) * (x + 1) - C2[i] * (2 * x + 3) + C3[i];
    return res;
}
void add(int x, i64 k) {
    for (int i = x; i <= 2 * n + 1; i += lowbit(i)) 
        C1[i] += k, C2[i] += k * x, C3[i] += k * x * x;
}

int main(void) {
    scanf("%d%d", &n, &type);
    for (int i = 1; i <= n; ++i) scanf("%d", a + i), b[a[i]].emplace_back(i);
    i64 ans = 0; const int N = n + 1; 
    for (int i = 0; i < n; ++i) {
        b[i].emplace_back(n + 1); int last = 0;
        for (int j = 0; j < b[i].size(); ++j) {
            int x = 2 * j - (b[i][j] - 1) + N, y = 2 * j - last + N;  
            ans += sum(y - 1) - sum(x - 2); 
            add(x, 1); add(y + 1, -1);
            last = b[i][j];
        }
        last = 0; 
        for (int j = 0; j < b[i].size(); ++j) {
            int x = 2 * j - (b[i][j] - 1) + N, y = 2 * j - last + N; 
            add(x, -1); add(y + 1, 1);
            last = b[i][j]; 
        }
    }
    printf("%lld\n", ans >> 1); return 0;
}

[RC-03] 记忆

Portal.

考虑使用动态规划解决这个问题的静态版本,一操作会导致 ansans 增大 cnt+1cnt+1cntcnt 增大 11。二操作会导致 ansans 增大 11cntcnt 清零。转移可以使用矩阵刻画,线段树维护时间轴即可。

查看代码
#include <bits/stdc++.h>
using namespace std;
typedef long long i64; 

struct Matrix {
    i64 a[3][3]; 
    Matrix() { memset(a, 0, sizeof a); }
    friend Matrix operator* (const Matrix &a, const Matrix &b) {
        Matrix c; 
        for (int i = 0; i < 3; ++i) for (int k = 0; k < 3; ++k) {
            i64 r = a.a[i][k]; 
            for (int j = 0; j < 3; ++j) c.a[i][j] += r * b.a[k][j]; 
        }
        return c; 
    }
} T[800005], A, B, C; 

void update(int o, int l, int r, int x, int k) {
    if (l == r) return T[o] = (k == 1 ? A : (k == 2 ? B : C)), void(); 
    int mid = l + r >> 1; 
    if (x <= mid) update(o << 1, l, mid, x, k); 
    else update(o << 1 | 1, mid + 1, r, x, k); 
    T[o] = T[o << 1] * T[o << 1 | 1]; 
}

int n; 
int op[200005], p[200005]; 
bool tag[200005]; 

int main(void) {
    A.a[0][0] = A.a[1][0] = A.a[1][1] = A.a[2][0] = A.a[2][1] = A.a[2][2] = 1; 
    B.a[0][0] = B.a[2][0] = B.a[2][1] = B.a[2][2] = 1; 
    C.a[0][0] = C.a[1][1] = C.a[2][2] = 1; 
    scanf("%d", &n); 
    for (int i = 1; i <= n * 4; ++i) T[i] = C; 
    for (int i = 1; i <= n; ++i) {
        scanf("%d", op + i); if (op[i] != 3) p[i] = i; else scanf("%d", p + i); 
        if (op[i] == 1) update(1, 1, n, i, 1); 
        else if (op[i] == 2) update(1, 1, n, i, 2);
        else {
            p[i] = p[p[i]]; 
            if (tag[p[i]]) tag[p[i]] = 0, update(1, 1, n, p[i], op[p[i]]); 
            else tag[p[i]] = 1, update(1, 1, n, p[i], 3); 
        }
        Matrix tmp; tmp.a[0][0] = tmp.a[0][1] = tmp.a[0][2] = 1; 
        printf("%lld\n", (tmp * T[1]).a[0][0]); 
    }
    return 0;
}

[省选联考 2020 A/B 卷] 冰火战士

Portal.

将温度离散化,那么求的就是冰人前缀和(IpI_p)和火人后缀和(前缀和记为 FpF_p)的最小值最大为多少。由于能力值不为负,因此只需要求出 IpFsumFp1I_p\le F_{sum}-F_{p-1} 的最大 ppIpFsumFp1I_p\ge F_{sum}-F_{p-1} 的最大 pp(但是 Fp1F_{p-1} 最小)。

前者好搞,但是后者怎么求?考虑将 FF 平移一位,条件一变成 IpFsumFpI_p\le F_{sum}-F_{p},条件二变成 IpFsumFpI_p\ge F_{sum}-F_{p}。当求出前面的 pp 后,取 pp+1p\leftarrow p+1,那么此时 pp 就是满足条件二的最小 pp,然后再次倍增出最大的 pp 即可。

查看代码
#include <bits/stdc++.h>
using namespace std; 
const int N = 2e6 + 5; 

int Q, n, b[N]; 
int op[N], t[N], x[N], y[N]; 

int ice[N], fire[N], FS; 
void add(int x, int k, int *c) { while (x <= n) c[x] += k, x += x & -x; }

int main(void) {
    scanf("%d", &Q); int tot = 0; 
    for (int i = 1; i <= Q; ++i) {
        scanf("%d", op + i); 
        if (op[i] == 1) scanf("%d%d%d", t + i, x + i, y + i), b[++tot] = x[i]; 
        else scanf("%d", t + i); 
    }
    sort(b + 1, b + tot + 1); n = unique(b + 1, b + tot + 1) - (b + 1); 
    for (int i = 1; i <= Q; ++i) x[i] = lower_bound(b + 1, b + n + 1, x[i]) - b; 

    for (int i = 1; i <= Q; ++i) {
        if (op[i] == 2) x[i] = x[t[i]], y[i] = -y[t[i]], t[i] = t[t[i]]; 
        if (t[i] == 0) add(x[i], y[i], ice); 
        else add(x[i] + 1, y[i], fire), FS += y[i]; 
        
        int I = 0, F = 0, p = 0; 
        for (int j = 20; j >= 0; --j) {
            p ^= 1 << j; 
            if (p > n || I + ice[p] > FS - F - fire[p]) p ^= 1 << j; 
            else I += ice[p], F += fire[p]; 
        }
        int tot = I; 
        if (p < n) {
            int x = p + 1, _F = 0; 
            for (; x; x -= x & -x) _F += fire[x]; 
            if (I <= FS - _F) {
                p = F = 0; tot = FS - _F; 
                for (int j = 20; j >= 0; --j) {
                    p ^= 1 << j; 
                    if (p > n || F + fire[p] > _F) p ^= 1 << j; 
                    else F += fire[p]; 
                }
            }
        }

        if (tot) printf("%d %d\n", b[p], tot * 2); 
        else puts("Peace"); 
    }
    return 0;
}

[CF187D] BRT Contract

Portal.

如果等了一个灯那么后面就都是一样的了,这个因此问题是如何找到第一个等的灯。

从开始到位置 ii 的距离模 m=(g+r)m=(g+r) 的余数为 pp,出发时间为 tt,如果 g(t+p)mod(g+r)g\le (t+p)\bmod (g+r) 就需要等这个红灯。如果要等红灯,tt 的取值有两种情况:

pg:t[gp,g+rp1]p>g:t[0,g+rp1][mp+g,m1]p\le g: t\in [g-p,g+r-p-1]\\ p>g: t\in [0,g+r-p-1]\cup [m-p+g,m-1]

那么搞一个区间染色单点查询的动态开点线段树就行。

查看代码
#include <bits/stdc++.h>
using namespace std;
typedef long long i64; 

struct Node {
    int ls, rs; 
    int val; 
} T[20000005]; 

int n, g, r, m, q, tot = 1, rt;
i64 d[100005], f[100005]; 
inline void pushdown(int o, int l, int r) {
    if (r - l + 1 > 1 && T[o].val) {
        if (!T[o].ls) T[o].ls = ++tot; if (!T[o].rs) T[o].rs = ++tot; 
        T[T[o].ls].val = T[T[o].rs].val = T[o].val;
        T[o].val = 0; 
    }
}
void update(int o, int l, int r, int x, int y, int k) {
    pushdown(o, l, r); 
    if (x <= l && r <= y) return T[o].val = k, void(); 
    int mid = l + r >> 1; 
    if (x <= mid) update(T[o].ls, l, mid, x, y, k);
    if (mid < y) update(T[o].rs, mid + 1, r, x, y, k);  
}
int query(int o, int l, int r, int x) {
    pushdown(o, l, r); 
    if (l == r) return T[o].val; 
    int mid = l + r >> 1; 
    if (x <= mid) return query(T[o].ls, l, mid, x); 
    return query(T[o].rs, mid + 1, r, x); 
}
inline i64 query(int t) {
    int p = query(rt, 0, m - 1, t % m);
    i64 ans = t + f[p] + d[p]; 
    if (p <= n) ans += m - (d[p] + t) % m; 
    return ans; 
}

int main(void) {
    scanf("%d%d%d", &n, &g, &r); m = g + r; 
    for (int i = 1; i <= n + 1; ++i) scanf("%lld", d + i), d[i] += d[i - 1]; 
    update(rt, 0, m - 1, 0, m - 1, n + 1);
    for (int i = n; i >= 1; --i) {
        int p = m - d[i] % m; f[i] = query(p) - d[i] - p; p = d[i] % m; 
        if (p <= g) update(rt, 0, m - 1, g - p, g + r - p - 1, i);    
        else update(rt, 0, m - 1, 0, g + r - p - 1, i), update(rt, 0, m - 1, m - p + g, m - 1, i); 
    }
    for (scanf("%d", &q); q--; ) {
        int t; scanf("%d", &t); 
        printf("%lld\n", query(t)); 
    }
    return 0;
}

线段树分裂与合并

几乎都是合并的题。

[POI2011] ROT-Tree Rotations

Portal.

直接 dfs 遍历这棵树,用线段树维护权值。交换左右子树只会对跨越这两棵子树的逆序对产生影响,统计这个就只剩线段树合并了。

查看代码
#include <bits/stdc++.h>
using namespace std;
typedef long long i64; 
const int N = 200000;

int n; i64 ans, u, v; 
struct Node {
    int ls, rs; 
    int siz; 
} T[22 * N];
int tot; 
void update(int &o, int l, int r, int x) {
    if (!o) o = ++tot; T[o].siz += 1; 
    if (l == r) return; int mid = l + r >> 1;
    if (x <= mid) update(T[o].ls, l, mid, x);
    else update(T[o].rs, mid + 1, r, x);
    return; 
}
int merge(int p, int q, int l, int r) {
    if (!p || !q) return p + q; 
    if (l == r) return T[p].siz += T[q].siz, p; 
    int mid = l + r >> 1; 
    u += 1ll * T[T[p].rs].siz * T[T[q].ls].siz; 
    v += 1ll * T[T[p].ls].siz * T[T[q].rs].siz; 
    T[p].ls = merge(T[p].ls, T[q].ls, l, mid);
    T[p].rs = merge(T[p].rs, T[q].rs, mid + 1, r);
    T[p].siz = T[T[p].ls].siz + T[T[p].rs].siz; 
    return p; 
}

int dfs(void) {
    int pos = 0, val; scanf("%d", &val);
    if (!val) {
        int ls = dfs(), rs = dfs(); u = v = 0; 
        pos = merge(ls, rs, 1, n); ans += min(u, v);
    } else update(pos, 1, n, val);
    return pos; 
}

int main(void) {
    scanf("%d", &n); dfs();
    printf("%lld\n", ans);
    return 0;
}

[湖南集训] 更为厉害

Portal.

如果 aabb 的祖先,那么可以随便选。否则只能选 bb 子树内的 cc(而且不能选 bb),线段树合并预处理出每棵子树以深度为值域的值域线段树。

查看代码
#include <bits/stdc++.h>
using namespace std;
typedef long long i64; 

int n, m, dep[300005], siz[300005]; 
vector<int> G[300005];

struct Node {
    int ls, rs; 
    i64 val; 
} T[9000005];
int tot, root[300005]; 

void update(int &o, int l, int r, int x, int k) {
    if (!o) o = ++tot; T[o].val += k; 
    if (l == r) return;
    int mid = l + r >> 1; 
    if (x <= mid) update(T[o].ls, l, mid, x, k); 
    else update(T[o].rs, mid + 1, r, x, k);
}
i64 query(int o, int l, int r, int x, int y) {
    if (!o) return 0; 
    if (x <= l && r <= y) return T[o].val; 
    int mid = l + r >> 1; i64 res = 0; 
    if (x <= mid) res += query(T[o].ls, l, mid, x, y); 
    if (mid < y) res += query(T[o].rs, mid + 1, r, x, y);
    return res; 
}
int merge(int p, int q, int l, int r) {
    if (p == 0 || q == 0) return p + q; 
    int mid = l + r >> 1, o = ++tot; 
    T[o].ls = merge(T[p].ls, T[q].ls, l, mid); 
    T[o].rs = merge(T[p].rs, T[q].rs, mid + 1, r); 
    T[o].val = T[p].val + T[q].val;
    return o; 
}

void dfs(int x, int fa) {
    dep[x] = dep[fa] + 1; siz[x] = 1; 
    for (int y : G[x]) if (y != fa) dfs(y, x), siz[x] += siz[y]; 
    update(root[x], 1, n, dep[x], siz[x] - 1);
    if (fa) root[fa] = merge(root[fa], root[x], 1, n);
}

int main(void) {
    scanf("%d%d", &n, &m); 
    for (int i = 1; i < n; ++i) {
        int u, v; scanf("%d%d", &u, &v); 
        G[u].emplace_back(v); G[v].emplace_back(u); 
    } dfs(1, 0); 
    while (m--) {
        int x, k; scanf("%d%d", &x, &k); 
        printf("%lld\n", query(root[x], 1, n, dep[x] + 1, dep[x] + k) + 1ll * min(k, dep[x] - 1) * (siz[x] - 1)); 
    }
    return 0;
}

[NOI2020] 命运

Portal.

考虑设 fx,yf_{x,y} 代表以 xx 为根的子树中已经全部满足,不满足的距离最多为 yy(从根节点向下开始)的方案数,答案为 f1,0f_{1,0}

考虑每次将 (x,y)(x,y) 合并进当前答案,分别考虑这条边填 1/01/0 的贡献:

fx,ij=0depxfx,ify,j+j=0ifx,ify,j+j=0i1fx,jfy,if'_{x,i} \leftarrow \sum_{j=0}^{dep_x} f_{x,i}f_{y,j}+\sum_{j=0}^{i} f_{x,i}f_{y,j}+\sum_{j=0}^{i-1} f_{x,j}f_{y,i}

gx,y=i=0yfx,ig_{x,y}=\sum_{i=0}^y f_{x,i},则:

fx,i=fx,i(gy,depx+gy,i)+fy,igx,i1f'_{x,i}=f_{x,i}(g_{y,dep_x}+g_{y,i})+f_{y,i}g_{x,i-1}

所有的转移位置都只与深度有关,因此直接线段树合并,维护区间乘法的修改。

查看代码
#include <bits/stdc++.h>
using namespace std;
const int P = 998244353; 

int n, q; 
vector<int> G[500005], p[500005]; 
int dep[500005]; 
int f[5005][5005], g[5005][5005]; 

struct Node {
    int ls, rs; 
    int dat, tag; 
    #define ls(x) T[x].ls
    #define rs(x) T[x].rs
    #define dat(x) T[x].dat
    #define tag(x) T[x].tag
} T[10000005]; 
int tot, rt[2000005]; 

void update(int &o, int l, int r, int x) {
    o = ++tot; dat(o) = tag(o) = 1; 
    if (l == r) return; 
    int mid = l + r >> 1; 
    if (x <= mid) update(ls(o), l, mid, x); 
    else update(rs(o), mid + 1, r, x); 
}
inline void pushdown(int o) {
    if (ls(o)) {
        dat(ls(o)) = 1ll * dat(ls(o)) * tag(o) % P; 
        tag(ls(o)) = 1ll * tag(ls(o)) * tag(o) % P; 
    }
    if (rs(o)) {
        dat(rs(o)) = 1ll * dat(rs(o)) * tag(o) % P; 
        tag(rs(o)) = 1ll * tag(rs(o)) * tag(o) % P; 
    }
    tag(o) = 1; 
}
int query(int o, int l, int r, int x) {
    if (!o || r <= x) return dat(o); pushdown(o); 
    int mid = l + r >> 1, ans = query(ls(o), l, mid, x); 
    if (mid < x) ans = (ans + query(rs(o), mid + 1, r, x)) % P; 
    return ans; 
}
int merge(int x, int y, int l, int r, int &s1, int &s2) { // s1 为 g(y, i),s2 为 g(x, i-1)
    if (!x && !y) return 0; 
    if (!x || !y) {
        if (!x) {
            s1 = (s1 + dat(y)) % P; 
            dat(y) = 1ll * dat(y) * s2 % P; 
            tag(y) = 1ll * tag(y) * s2 % P; 
            return y; 
        }
        s2 = (s2 + dat(x)) % P; 
        dat(x) = 1ll * dat(x) * s1 % P; 
        tag(x) = 1ll * tag(x) * s1 % P; 
        return x; 
    }
    if (l == r) {
        int tmp = dat(x); s1 = (s1 + dat(y)) % P; 
        dat(x) = (1ll * dat(x) * s1 + 1ll * dat(y) * s2) % P; 
        s2 = (s2 + tmp) % P; 
        return x; 
    }
    pushdown(x); pushdown(y); int mid = l + r >> 1; 
    ls(x) = merge(ls(x), ls(y), l, mid, s1, s2); 
    rs(x) = merge(rs(x), rs(y), mid + 1, r, s1, s2); 
    dat(x) = (dat(ls(x)) + dat(rs(x))) % P; 
    return x; 
}

void dfs(int x, int fa) {
    dep[x] = dep[fa] + 1; int mx = 0; 
    for (int i : p[x]) mx = max(mx, dep[i]); 
    update(rt[x], 0, n, mx); 
    for (int y : G[x]) if (y != fa) {
        dfs(y, x); 
        int s1 = query(rt[y], 0, n, dep[x]), s2 = 0; 
        rt[x] = merge(rt[x], rt[y], 0, n, s1, s2); 
    }
}

int main(void) {
    scanf("%d", &n); 
    for (int i = 1; i < n; ++i) {
        int x, y; scanf("%d%d", &x, &y); 
        G[x].emplace_back(y); G[y].emplace_back(x); 
    } scanf("%d", &q); 
    while (q--) {
        int x, y; scanf("%d%d", &x, &y); 
        p[y].emplace_back(x); 
    } dfs(1, 0); 
    return !printf("%d\n", query(rt[1], 0, n, 0));  
}

评论

若无法加载,请尝试刷新,欢迎讨论、交流和提出意见,支持 Markdown 与 LaTeX 语法(公式与文字间必须有空格)!