[알고리즘] 최소 신장 트리 - 크루스칼 알고리즘과 Union Find, 백준 1197번 C++ 풀이

2024. 9. 9. 22:27알고리즘/알고리즘 지식

 

[실전 알고리즘] 0x1B강 - 최소 신장 트리

안녕하세요, 오늘 다룰 주제는 최소 신장 트리(Minimum Spanning Tree)라는 개념입니다. 보통 코딩 좀 치는 사람들 사이에서는 MST라고 많이들 부릅니다. 그런데 Spanning이나 신장 이 단어가 너무 낯설지

blog.encrypted.gg

바킹독님 블로그를 참조하여 쓴 글입니다.

 

신장 트리란?

신장 트리는 무방향 그래프의 부분 그래프들 중에서 모든 정점을 포함하는 트리이다. 

사진의 왼쪽 그래프에서 오른쪽 그래프는 모두 신장 트리이다. 신장 트리라 하면 당연히 연결 그래프가 되겠고, 트리이므로 당연히 사이클이 존재해선 안된다.

 

최소 신장 트리는 신장 트리 중에서 간선의 합이 최소인 트리를 말한다. 최소 신장 트리는 동일한 그래프에서 여러 개 존재할 수 있다.

 

이제 최소 신장 트리를 구하는 방법을 알아보자.


크루스칼 알고리즘

  1. 간선을 크기 기준 오름차순으로 정렬하고 제일 낮은 비용의 간선을 선택한다.
  2. 현재 선택한 간선이 정점 u, v를 연결하는 간선이라고 할 때 만약 u, v가 같은 그룹이라면 continue; u, v가 다른 그룹이라면 같은 그룹으로 만들고 현재 선택한 간선을 최소 신장 트리에 추가.
  3. 최소 신장 트리에 V-1개의 간선을 추가시켰다면 과정을 종료, 그렇지 않다면 그다음으로 비용이 작은 간선을 선택한 후 2번 과정을 반복한다.

크루스칼 알고리즘은 Union Find 알고리즘을 알아야 효율적으로 구현할 수 있다. 


Union Find 알고리즘

Union Find를 간단하게 알아보자면 서로서 집합을 관리하는 자료구조이다. 주로 특정 원소가 속한 집합을 찾거나 원소가 속한 집합을 조작하는 자료구조이다. 

 

주로 두 가지 기본 연산을 지원한다.

  1. Find(x) : x가 속한 집합의 대표자를 반환한다.
  2. Union(x, y) : x와 y가 속한 두 집합을 하나로 합친다.

 

효율적인 Union Find 알고리즘을 위해 두 가지 최적화 기법이 사용된다.

  1. 경로 압축 : Find 연산에서 탐색하는 노드의 부모를 해당 집합의 대표 노드를 가리키게 하여 이후 Find 연산의 경로를 최적화한다.
  2. Union By Rank : 두 집합을 합칠 때 더 작은 집합을 더 큰 집합에 병합한다. 이는 트리의 높이를 줄일 수 있다.

간단하게 구현해보자.

#include <iostream>
#include <vector>

using namespace std;

// 부모 노드를 저장할 벡터
vector<int> parent;

// Find 연산: 경로 압축 적용
int find(int x) {
    if (parent[x] != x) {
        parent[x] = find(parent[x]);  // 경로 압축
    }
    return parent[x];
}

// Union 연산: 두 집합을 병합
void unionSets(int x, int y) {
    int rootX = find(x);
    int rootY = find(y);

    if (rootX != rootY) {
        parent[rootY] = rootX;  // rootY를 rootX에 병합
    }
}

이 정도로 간단하게 알아보고 다시 크루스칼을 구현해 보자.


 

크루스칼 알고리즘에서는 특정 두 정점이 같은 그룹인지, 다른 그룹인지 판단해야 한다. Union Find를 사용하지 않으면 Flood Fill을 이용해 판단할 수 있다. 현재까지 만든 최소 신장 트리에서 A에서 B로 방문이 가능한지 Flood Fill을 돌렸을 때 방문이 가능하다면 같은 그룹, 아니라면 다른 그룹이다. 하지만 이 경우에 시간복잡도가 O(V+E)가 되고 최대 E번 판단이 필요하므로 O(VE)가 된다. 

 

그렇게 비효율적인가 싶겠지만 Flood Fill 대신 Union Find를 사용하면 상수 시간에 특정 두 정점의 그룹을 판단할 수 있기에 굉장히 효율적으로 구현할 수 있다.

 

바로 코드로 보자.

int v, e;
tuple<int, int, int> edge[100500];

sort(edge, edge+e);
int cnt=0;
for(int i=0; i<e; i++){
    int cost, a, b;
    tie(cost, a, b) = edge[i];
    if(!is_diff_group(a, b)) continue;
    cout << cost << ' ' << a << ' ' << b;
    cnt++;
    if(cnt == v-1) break;
}

is_diff_group 함수는 특정 두 정점이 같은 그룹인지 다른 그룹인지 판단하는 Union Find 함수라고 생각하자. 또, tuple의 대소 비교는 제일 앞의 값부터 이루어지기에 비용, 정점 1, 정점 2 순으로 저장해야 sort가 정상적으로 작동한다.

 

https://www.acmicpc.net/problem/1197

바로 최소 스패닝 트리 문제를 풀어보자. 문제부터 최소 스패닝 트리 간선의 합을 구하라 했으니 최소 스패닝 트리의 간선을 구하면 된다. 사실상 위에서 짠 크루스칼에 Union Find를 구현하면 되는데 경로 압축만 해서 간단하게 구현해 보았다.

int find(int x)
{
    // x가 대표 노드인 경우
    if (p[x] < 0)
        return x;
    // 경로 압축
    return p[x] = find(p[x]);
}

bool is_diff_group(int u, int v)
{
    u = find(u);
    v = find(v);
    if (u == v)
        return 0;

    p[v] = u;

    return 1;
}

Union Find 부분이다. find로 경로 압축을 하면서 대표자 노드를 찾는 과정은 떠오르는 대로 구현하면 되고, is_diff_group가 union이 되겠다. 두 정점이 같은 그룹인지 판단한 뒤 다른 그룹이라면 어차피 병합해야 하므로 여기서 병합을 진행해 주었다. p [v] = u; 가 실질적으로 병합을 하는 과정인데 간단하게 v의 대표자 노드를 u로 바꿔주면 된다.

 

/** MST 1197 최소 스패닝 트리 **/
#include <iostream>
#include <vector>
#include <algorithm>
#include <string>
#include <tuple>
using namespace std;
int v, e;
vector<int> p(10100, -1);
tuple<int, int, int> edge[100005];

int find(int x)
{
    // x가 대표 노드인 경우
    if (p[x] < 0)
        return x;
    // 경로 압축
    return p[x] = find(p[x]);
}

bool is_diff_group(int u, int v)
{
    u = find(u);
    v = find(v);
    if (u == v)
        return 0;

    p[v] = u;

    return 1;
}

int main()
{
    ios_base::sync_with_stdio(0);
    cin.tie(0);

    cin >> v >> e;
    for (int i = 0; i < e; i++)
    {
        int a, b, cost;
        cin >> a >> b >> cost;
        edge[i] = {cost, a, b};
    }

    sort(edge, edge + e);
    int cnt = 0, ans = 0;
    for (int i = 0; i < e; i++)
    {
        int a, b, cost;
        tie(cost, a, b) = edge[i];
        if (!is_diff_group(a, b))
            continue;
        ans += cost;
        cnt++;
        if (cnt == v - 1)
            break;
    }
    cout << ans;
}

 

코드 전체를 보면 간단하게 Union Find를 사용해서 크루스칼을 순서대로 구현해 주면 된다. 다시 복습하자면 크루스칼 알고리즘은

  1. 간선을 가중치 기준으로 오름차순 정렬.
  2. 가장 낮은 간선부터 해당 간선이 a부터 b라면 a, b가 같은 그룹인지 판단.
  3. 같은 그룹이라면 continue; 아니라면 MST의 간선!

위의 알고리즘 순서대로 구현하면 된다.