抱歉,您的浏览器无法访问本站
本页面需要浏览器支持(启用)JavaScript
了解详情 >

本文已经更新完毕,若有错误或需要补充的内容请在评论区留言。

在无向图中,生成树指一棵由全部顶点和组成的树,而当中边权之和最小的生成树称为最小生成树(Minimum Spanning Tree,MST)。本文会引导你学习 MST 的 Kruskal 和 Prim 算法。

最小生成树

常见的求解 MST 的方法有两种:Kruskal 和 Prim。模板

Kruskal

Kruskal 基于贪心的思想。Kruskal 先把 mm 条边进行排序,然后检查每条边 u,vu,v,如果 uuvv 在同一个连通分量中,那么加入后就会形成环,不能加入。若不在呢?那就直接加入,一定是最优的。证明可以使用反证法,这里略去。

实现上,排序直接用 sort,维护的过程可以采用并查集,参考代码如下:

查看代码
#include <bits/stdc++.h>
using namespace std; 

int n, m, fa[5005]; 
int find(int x) { return fa[x] == x ? x : fa[x] = find(fa[x]); }
struct edge {
	int u, v, w; 
	bool operator< (const edge &a) const { return w < a.w; }
} e[200005]; 

int main(void) {
	scanf("%d%d", &n, &m); 
	for (int i = 1; i <= n; ++i) fa[i] = i; 
	for (int i = 1; i <= m; ++i) scanf("%d%d%d", &e[i].u, &e[i].v, &e[i].w); 
	sort(e + 1, e + m + 1); int tot = 1, ans = 0; 
	for (int i = 1; i <= m; ++i) {
		int u = find(e[i].u), v = find(e[i].v); 
		if (u == v) continue; 
		++tot; fa[u] = v; ans += e[i].w; 
		if (tot == n) break; 
	}
	if (tot != n) puts("orz"); 
	else printf("%d\n", ans); 
	return 0; 
}

时间复杂度 O(mlogn+nlogn)\mathcal{O}(m\log n+n\log n)(一般认为 m>nm > n,所以写作 O(mlogn)\mathcal{O}(m\log n))。

Prim

Prim 同样基于贪心的思想,读者应该了解过 Dijkstra 算法,Prim 和 Dijkstra 大概就是相同的原理。

Prim 算法维护的是 MST 的一部分。最初,Prim 确定 11 号节点属于 MST(即将 11 作为根节点)。

设确定属于 MST 的点集为 TT,未确定为 SS。Prim 会找到边权最小的边 (u,v),uT,vS(u,v),u\in T,v\in S,然后将这条边加入 MST。

实现时开一个 dd 数组,当 iSi\in S 时,d[i]d[i] 代表与集合 TT 中节点之间权值最小的边的权值,最终答案就是 d\sum d

发没发现这一过程很像 Dijkstra?的确如此,代码如下:

查看代码
#include <bits/stdc++.h>
using namespace std; 

int n, m; 
int G[5005][5005], d[5005], v[5005]; 

void Prim(void) {
	memset(d, 0x3f, sizeof d); d[1] = 0; 
	for (int op = 1; op < n; ++op) {
		int x = 0; 
		for (int i = 1; i <= n; ++i) if (!v[i] && (d[i] < d[x])) x = i; 
		v[x] = 1; 
		for (int y = 1; y <= n; ++y) if (v[y] == 0) d[y] = min(d[y], G[x][y]); 
	}
}

int main(void) {
	scanf("%d%d", &n, &m); memset(G, 0x3f, sizeof G); 
	while (m--) {
		int u, v, w; scanf("%d%d%d", &u, &v, &w); 
		G[u][v] = G[v][u] = min(G[u][v], w); 
	} Prim(); int ans = 0; 
	for (int i = 1; i <= n; ans += d[i++])
		if (d[i] > 1e9) return puts("orz"), 0; 
	return !printf("%d\n", ans); 
}

Prim 的复杂度是 O(n2)\mathcal{O}(n^2),虽然和 Dijkstra 一样可以用优先队列优化到 O(mlogn)\mathcal{O}(m\log n),但是这时就不如直接用 Kruskal。所以 Prim 用于稠密图(尤其是完全图)的 MST 求解。

Boruvka

对于一个点 ii,其最小权值的临边必定在 MST 上。那么迭代 logn\log n 次,每次扫描每条边,然后合并连通块。

算法时间复杂度为 O(mlogn)O(m\log n),但是实际中并不常用。实用的是这个思想。比如给定一张 nn 个点的完全图,边权通过某种方式计算。这时可以使用 Boruvka 算法,利用数据结构快速计算不在当前连通块的最小边权。

其它生成树

生成树问题有一些变种,这里简单介绍一下:

最小瓶颈生成树

这类问题形如这样:给出一个带权无向图,求一棵生成树,使得最大边权值尽量小。

怎么求呢?我们肯定要把所有边都排序,然后求解。等等,这不就是 Kruskal 算法吗?的确如此。原图的最小生成树就一定是最小瓶颈生成树(但要注意最小瓶颈生成树不一定是最小生成树)。

最小瓶颈路

求带权无向图 uuvv 的一条路径,使得这条路径上的最大边权值最小,这样的路被称为最小瓶颈路。

怎么做呢?可以使用二分 + 01 BFS 来解决,但效率较低。可以求原图的 MST,然后所有路径必定在 MST 上。为什么可以这么做呢?可以用反证法,会证明出这样一个结论:如果存在一条路径不在 MST 上,那么这个 MST 一定是假的。

次小生成树

模板

这里只讨论严格次小生成树,非严格的同理。

由于 Kruskal 算法的过程,可以证明次小生成树只和最小生成树有一条边差距。接下来考虑如何替换:

设最小生成树为 TT,权值和为 SS,那么遍历每一条边,加入一条边后树上会出现一个环,再断掉这个环中边权最大的边(若加入的也是最大的,那么需要断掉次大的,由于原来已经是 MST,显然加入的只能大于等于),对上述所有生成的答案取 min\min 之后就可以得到答案。

现在的问题就是,如何高校维护 u,vu,v 路径上的最大值呢?

采用树上倍增法,类似于 ST 表,存储每个点向上 2i2^i 条边的最大值与次大值,查询的时候倍增查询即可。

查看代码
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <vector>
#include <cstring>

using namespace std;
typedef long long i64;
const int INF = 1e9;
const i64 INF64 = 2e18;

struct edge {
	int u, v, w;
	bool use;
	bool operator < (const edge &a) const {
		return w < a.w;
	}
} e[300005];

int n, m;
int bin[100005];
int find(int x) {
	if (bin[x] == x) return x;
	return bin[x] = find(bin[x]);
}

#define pii pair<int, int>
vector<pii> G[100005];

i64 ans0 = 0;
void Kruskal(void) 
{
	for (int i = 1; i <= n; ++i) bin[i] = i;
	sort(e + 1, e + m + 1);
    int tot = 0;
	for (int i = 1; i <= m; ++i) {
		int u = find(e[i].u), v = find(e[i].v);
		if (u != v) {
		    ans0 += e[i].w; bin[u] = v; e[i].use = true;
		    G[e[i].u].push_back({e[i].v, e[i].w});
		    G[e[i].v].push_back({e[i].u, e[i].w});
            ++tot;
        }
        if (tot == n - 1) break;
	}
}

int dep[100005], lg[100005];
int f[100005][20];
int mx[100005][20], mx2[100005][20];
void dfs(int x, int fa)
{
	dep[x] = dep[fa] + 1; f[x][0] = fa; mx2[x][0] = -INF;
	for (int i = 1; (1 << i) <= dep[x]; ++i) {
		f[x][i] = f[f[x][i - 1]][i - 1];
		int g[4] = {mx[x][i - 1], mx[f[x][i - 1]][i - 1], mx2[x][i - 1], mx2[f[x][i - 1]][i - 1]};
		sort(g, g + 4);
		mx[x][i] = g[3];
		int p = 2;
		while (p >= 0 && g[p] == g[3]) --p;
		mx2[x][i] = (p == -1 ? -INF : g[p]);
	}
	for (int i = 0; i < G[x].size(); ++i) {
		int y = G[x][i].first, w = G[x][i].second;
		if (y != fa) {
            mx[y][0] = w; 
            dfs(y, x);
        }
	}
}
int LCA(int x, int y) {
	if (dep[x] < dep[y]) swap(x, y);
	for (int i = lg[n]; i >= 0; --i)
		if (dep[f[x][i]] >= dep[y]) x = f[x][i];
	if (x == y) return x;
	for (int i = lg[n]; i >= 0; --i)
		if (f[x][i] != f[y][i]) x = f[x][i], y = f[y][i];
	return f[x][0];
}

i64 calc(int x, int y, int val)
{
	int res = -INF;
	for (int i = lg[n]; i >= 0; --i)
		if (dep[f[x][i]] >= dep[y]) {
            if (val != mx[x][i]) res = max(res, mx[x][i]);
            else res = max(res, mx2[x][i]);
            x = f[x][i];
        }
    return res;
}

int main(void)
{
    ios::sync_with_stdio(false);
	scanf("%d%d", &n, &m);
    for (int i = 2; i <= n; ++i) lg[i] = lg[i >> 1] + 1;
	for (int i = 1; i <= m; ++i) scanf("%d%d%d", &e[i].u, &e[i].v, &e[i].w);
	Kruskal();
	dfs(1, 0);
	i64 ans = INF64;
	for (int i = 1; i <= m; ++i)
		if (!e[i].use) {
            int lca = LCA(e[i].u, e[i].v);
            i64 tmpa = calc(e[i].u, lca, e[i].w);
            i64 tmpb = calc(e[i].v, lca, e[i].w);
            ans = min(ans, ans0 - max(tmpa, tmpb) + e[i].w);
        }
	if (ans != INF64) printf("%lld\n", ans);
    else puts("-1");
	return 0;
}

有向图中有一类生成树称为最小树形图。这个问题比较复杂,不在本文中讨论。感兴趣的同学可以自行了解。

还有一类问题称为 k 小生成树,但是这种问题的做法笔者暂时没有了解。据闻在大神刘汝佳的《算法艺术与信息学竞赛》P300 中有说明,感兴趣的读者可以自行挑战

Problemset

这里的题目都比较简单。

简单生成树

这是最基本的生成树问题。

[Luogu P1195] 口袋的天空

Portal.

这个 KK 是什么?不要紧,我们还是使用 Kruskal 算法,不过不一定要连成一棵树,我们只要把这些云连成 KK 个即可。也就是说,只需要连 NKN-K 条边。

查看代码
#include <iostream>
#include <cstdio>
#include <algorithm>

using namespace std;

int n, m, k;

struct edge
{
    int u, v, d;
    edge(int u = 0, int v = 0, int d = 0) :
        u(u), v(v), d(d) {}
    bool operator < (const edge &a) const
    {
        return d < a.d;
    }
}a[10005];

int fa[1005];
int find(int x)
{
    if (fa[x] == x) return x;
    return fa[x] = find(fa[x]);
}

int main(void)
{
    scanf("%d%d%d", &n, &m, &k);
    for (int i = 1; i <= n; ++i) fa[i] = i;
    for (int i = 1; i <= m; ++i)
        scanf("%d%d%d", &a[i].u, &a[i].v, &a[i].d);
    sort(a + 1, a + m + 1);
    int ans = 0, cnt = 0;
    for (int i = 1; i <= m; ++i)
    {
        int x = find(a[i].u), y = find(a[i].v);
        if (x == y) continue;
        fa[x] = y, ans += a[i].d, ++cnt;
        if (cnt == n - k)
        {
            printf("%d\n", ans);
            return 0;    
        }
    }   
    puts("No Answer");
    return 0;
}

[UVa 1395] Slim Span

Portal.

给定一个 n(n100)n(n\le 100) 个点的无向图,求最大边减最小边的值尽量小的生成树。

如果最小边确定,我们求出最小生成树,那么就可以求出这个值了。代码如下:

查看代码
#include <iostream>
#include <cstdio>
#include <vector>
#include <algorithm>

using namespace std;

inline int read(void)
{
    int x = 0, c = getchar();
    while (!isdigit(c)) c = getchar();
    while (isdigit(c)) x = (x<<3) + (x<<1) + (c^48), c = getchar();
    return x;
}

int fa[105];
int find(int x) { return x == fa[x] ? x : fa[x] = find(fa[x]); }

struct edge
{
    int u, v, d;
    edge(int u, int v, int d) :
        u(u), v(v), d(d) {}
    inline bool operator < (const edge &a) const
    {
        return d < a.d;
    }
};
vector <edge> e;

int n, m;

inline int kruskal(void)
{
    sort(e.begin(), e.end());
    int ans = 0x7fffffff;
    for (int L = 0; L < m; ++L)
    {
        for (int i = 1; i <= n; ++i) fa[i] = i;
        int cnt = 0;
        for (int R = L; R < m; ++R)
        {
            int a = find(e[R].u), b = find(e[R].v);
            if (a == b) continue;
            fa[a] = b;
            if (++cnt == n - 1)
            {
                ans = min(ans, e[R].d - e[L].d);
                break;
            }
        }
    }
    if (ans == 0x7fffffff) return -1;
    return ans;
}

int main(void)
{
    while (scanf("%d%d", &n, &m) == 2 && n)
    {
        e.clear();
        for (int i = 0; i < m; ++i) 
        {
            int u = read(), v = read(), d = read();
            e.push_back(edge(u, v, d));
        }
        printf("%d\n", kruskal());
    }
    return 0;
}

[Luogu P2700] 逐个击破

Portal.

我们现假设需要摧毁所有的边,然后按边权从大到小排序。如果两个点都不是敌人节点就连边,注意父亲也要设置为敌人节点(如果连接的点有敌人)。

查看代码
#include <iostream>
#include <cstdio>
#include <algorithm>

using namespace std;

struct edge
{
    int u, v, d;
    edge(int u = 0, int v = 0, int d = 0) :
        u(u), v(v), d(d) {}
    bool operator < (const edge &a) const
    {
        return d > a.d;
    }
}a[100005];

int n, k, s[100005];
int fa[100005];

int find(int x)
{
    if (fa[x] == x) return x;
    return fa[x] = find(fa[x]);
}

int main(void)
{
    scanf("%d%d", &n, &k);
    for (int i = 1; i <= k; ++i)
    {
        int x;
        scanf("%d", &x);
        s[x] = true;
    }
    long long ans = 0;
    for (int i = 1; i < n; ++i) 
    {
        scanf("%d%d%d", &a[i].u, &a[i].v, &a[i].d);
        ans += a[i].d;
    }
    for (int i = 1; i <= n; ++i) fa[i] = i;
    sort(a + 1, a + n + 1);
    for (int i = 1; i < n; ++i)
    {
        int x = find(a[i].u), y = find(a[i].v);
        if (s[x] && s[y]) continue;
        fa[x] = y;
        ans -= a[i].d;
        s[y] = (s[x] | s[y]);
    }
    printf("%lld\n", ans);
    return 0;
}

[USACO08OCT] Watering Hole G

Portal.

我们只需要增设一个水井点 00,让每一个牧场都与 00 连一条 WiW_i 的边,然后使用 Prim(因为是完全图,开大数据范围即可杀死 Kruskal)求解最小生成树。

查看代码
#include <iostream>
#include <cstdio>
#include <cstring>

using namespace std;

int n, d[305];
bool v[305];
int a[305][305];

void Prim(void)
{
    memset(d, 0x3f, sizeof(d));
    d[0] = 0;
    for (int i = 1; i <= n; ++i)
    {
        int x = -1;
        for (int j = 0; j <= n; ++j)
            if (!v[j] && (x == -1 || d[j] < d[x])) x = j;
        v[x] = true;
        for (int j = 0; j <= n; ++j)
            if (!v[j]) d[j] = min(d[j], a[x][j]);
    }
}

int main(void)
{
    scanf("%d", &n);
    for (int i = 1; i <= n; ++i)
    {
        int w;
        scanf("%d", &w);
        a[0][i] = a[i][0] = w;
    }
    for (int i = 1; i <= n; ++i)
        for (int j = 1; j <= n; ++j)
            scanf("%d", &a[i][j]);
    Prim();
    int ans = 0;
    for (int i = 0; i <= n; ++i)
        ans += d[i];
    printf("%d\n", ans);
    return 0;
}

[UVa 1151] Buy or Build

Portal.

通过二维枚举,我们可以轻松的把这玩意转化成图。对着图使用 Kruskal,得到 n1n-1 条边,就是可能成为最终答案的边。然后枚举购买哪些套餐即可。

查看代码
#include <iostream>
#include <cstdio>
#include <vector>
#include <algorithm>

using namespace std;

int read(void)
{
    int x = 0, c = getchar(), f = 1;
    while (!isdigit(c)) {if (c == '-') f = -1; c = getchar();}
    while (isdigit(c)) x = x * 10 + c - '0', c = getchar();
    return x * f;
}

int n, q, cost[8];
vector <int> sub[8];
int x[1005], y[1005];

struct edge
{
    int u, v, d;
    edge(int u = 0, int v = 0, int d = 0) :
        u(u), v(v), d(d) {}
    bool operator < (const edge &a) const
    {
        return d < a.d;
    }
};
vector <edge> e, es;

int fa[1005];
int find(int x)
{
    if (fa[x] == x) return x;
    return fa[x] = find(fa[x]);
}
inline void UnionFind_init(void)
{
    for (int i = 1; i <= n; ++i)
        fa[i] = i;
}

int kruskal(int cnt, const vector <edge> &G, bool flag)
{
    if (cnt == 1) return 0;
    int ans = 0;
    for (int i = 0; i < G.size(); ++i)
    {
        int x = find(G[i].u), y = find(G[i].v);
        if (x == y) continue;
        fa[x] = y;
        ans += G[i].d;
        if (flag) es.push_back(G[i]);
        --cnt;
        if (cnt == 1) break;
    }
    return ans;
}

int main(void)
{
    int T = read();
    while (T--)
    {
        n = read(), q = read();
        for (int i = 0; i < q; ++i)
        {
            int m = read(); cost[i] = read();
            sub[i].clear();
            while (m--) sub[i].push_back(read());
        }
        for (int i = 1; i <= n; ++i)
            x[i] = read(), y[i] = read();
        e.clear(), es.clear();
        for (int i = 1; i < n; ++i)
            for (int j = i + 1; j <= n; ++j)
                e.push_back(edge(i, j, (x[i]-x[j])*(x[i]-x[j]) + (y[i]-y[j])*(y[i]-y[j])));
        sort(e.begin(), e.end());

        UnionFind_init();
        int ans = kruskal(n, e, true);
        for (int i = 0; i < (1 << q); ++i)
        {
            UnionFind_init();
            int cnt = n, c = 0;
            for (int j = 0; j < q; ++j)
                if (i & (1 << j))
                {
                    c += cost[j];
                    for (int k = 1; k < sub[j].size(); ++k)
                    {
                        int x = find(sub[j][k]), y = find(sub[j][0]);
                        if (x != y) fa[x] = y, --cnt;
                    }
                }
            ans = min(ans, c + kruskal(cnt, es, false));
        }

        printf("%d\n", ans);
        if (T) putchar('\n');
    }   
    return 0;
}

[CF609E] Minimum spanning tree for each edge

Portal.

跟次小生成树的思路是一样的,在路径上找一条最大的边换下来即可。

查看代码
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <vector>

using namespace std;
typedef long long i64;

int n, m;
i64 ans = 0;

struct edge {
    int u, v, d, id;
    bool use;
    bool operator < (const edge &a) const {
        return d < a.d;
    }
} e[200005];
int bin[200005];
int find(int x) {
    if (bin[x] == x) return x;
    return bin[x] = find(bin[x]);
}

int f[20][200005], dep[200005], w[20][200005];
vector<pair<int, int>> G[200005];
void dfs(int x, int fa) {
    f[0][x] = fa; dep[x] = dep[fa] + 1;
    for (int i = 1; i <= 18; ++i) 
        f[i][x] = f[i - 1][f[i - 1][x]], w[i][x] = max(w[i - 1][x], w[i - 1][f[i - 1][x]]);
    for (int i = 0; i < G[x].size(); ++i) {
        int y = G[x][i].first, d = G[x][i].second;
        if (y != fa) {
            w[0][y] = d;
            dfs(y, x);
        }
    }
}
int LCA(int x, int y) {
    if (dep[x] < dep[y]) swap(x, y);
    int ans = 0;
    for (int i = 18; i >= 0; --i) 
        if (dep[f[i][x]] >= dep[y])
            ans = max(ans, w[i][x]), x = f[i][x];
    if (x == y) return ans;
    for (int i = 18; i >= 0; --i)
        if (f[i][x] != f[i][y]) {
            ans = max({ans, w[i][x], w[i][y]});
            x = f[i][x], y = f[i][y];
        }
    return max({ans, w[0][x], w[0][y]});
}

int main(void) {
    scanf("%d%d", &n, &m);
    for (int i = 1; i <= m; ++i)
        scanf("%d%d%d", &e[i].u, &e[i].v, &e[i].d), e[i].id = i;
    sort(e + 1, e + m + 1);
    for (int i = 1; i <= n; ++i) bin[i] = i;
    for (int i = 1; i <= m; ++i) {
        int u = e[i].u, v = e[i].v, d = e[i].d;
        int x = find(u), y = find(v);
        if (x != y) {
            bin[x] = y; e[i].use = true;
            G[u].push_back({v, d});
            G[v].push_back({u, d});
            ans += d;
        }
    }
    dfs(1, 0);
    static i64 p[200005];
    for (int i = 1; i <= m; ++i) {
        if (e[i].use) p[e[i].id] = ans;
        else p[e[i].id] = ans - LCA(e[i].u, e[i].v) + e[i].d;
    }
    for (int i = 1; i <= m; ++i)
        printf("%lld\n", p[i]);
    return 0;
}

[CF76A] Gift

Portal.

枚举能够使用的最大的 gg,如果一条边没能被选中则删掉这条边。时间复杂度 O(mnlogn)O(mn\log n)

查看代码
#include <bits/stdc++.h>
using namespace std;
typedef long long i64; 

int n, m, G, S, g[50005]; 
struct edge {
    int u, v, g, s; 
    bool operator < (const edge &a) const {
        return s < a.s; 
    }
} e[50005];
multiset<edge> E; 
i64 ans = 2e18; 

int f[205]; 
int find(int x) { if (f[x] == x) return x; return f[x] = find(f[x]); }

bool check(int mxg) {
    int cnt = 1; 
    for (int i = 1; i <= n; ++i) f[i] = i; 
    for (auto it = E.begin(); it != E.end(); ++it) {
        auto e = *it; if (e.g > g[mxg]) continue; 
        int x = find(e.u), y = find(e.v); 
        if (x == y) continue; 
        ++cnt; f[x] = y;
        if (cnt == n) break; 
    }
    return cnt == n; 
}

int main(void) {
    scanf("%d%d%d%d", &n, &m, &G, &S); 
    for (int i = 1; i <= m; ++i) scanf("%d%d%d%d", &e[i].u, &e[i].v, &e[i].g, &e[i].s), E.insert(e[i]), g[i] = e[i].g; 
    sort(g + 1, g + m + 1); int L = 0, R = m + 1; 
    while (L + 1 != R) {
        int mid = L + R >> 1; 
        if (check(mid)) R = mid; 
        else L = mid; 
    } if (R == m + 1) return puts("-1"), 0; 
    for (int mxg = R; mxg <= m; ++mxg) {
        int cnt = 1, now = 0; 
        for (int i = 1; i <= n; ++i) f[i] = i; 
        for (auto it = E.begin(); it != E.end(); ++it) {
            auto e = *it; if (e.g > g[mxg]) continue; 
            int x = find(e.u), y = find(e.v); 
            if (x == y) {
                auto id = it; --id; E.erase(it); 
                it = id; continue; 
            }
            ++cnt; f[x] = y; now = e.s;  
            if (cnt == n) break; 
        }
        if (cnt == n) ans = min(ans, 1ll * now * S + 1ll * g[mxg] * G); 
    }
    printf("%lld\n", ans); 
    return 0;
}

评论

若无法加载,请尝试刷新,欢迎讨论、交流和提出意见,支持 Markdown 与 LaTeX 语法(公式与文字间必须有空格)!