[백준] C++ 1647번: 도시 분할 계획 - MST 입니다. 그런데 이제 두 개로 분할을 곁들인,,,

2024. 10. 4. 08:45알고리즘/문제풀이

https://www.acmicpc.net/problem/1647

예제

7 12
1 2 3
1 3 2
3 2 1
2 5 2
3 4 4
7 3 6
5 1 5
1 6 2
6 4 1
6 5 3
4 5 3
6 7 4

ans : 8

문제의 요구사항은 다음과 같다.

  1. 도시를 2개로 분할하라.
  2. 분할된 2개의 도시 사이에는 간선이 존재하지 않는다.
  3. 각 도시의 마을은 MST로 구성되어 있다.

일단 3번 조건을 보면 MST를 구현해야하긴 하는데 어떻게 2개로 분할해야 가장 Minimum한 MST를 만들 수 있을지가 고민된다.

N이 100개정도면 완전탐색으로 구현해도 될테지만 N은 10만... 완전탐색은 아니다.

 

만약 도시 분할 조건이 없다면 바로 MST를 구해도 된다. 그렇다면 MST를 구하는 과정에 Union-Find를 하게 되고 이 부분을 응용해서 2개의 집합을 구할 수 있지 않을까?

Union-Find 하는 과정을 생각해보면 2개의 집합을 유지한다는 것 자체가 복잡하고 Greedy하게 집합을 나누는 방법밖에 떠오르지 않는데 Greedy한게 Minimum하다는 증명도 없다.

 

그런데 조금 다른 방식으로 생각해보면 MST에서 간선 하나만 제거해도 두 개의 집합이 된다. 가장 큰 간선을 제거하면 그 MST 안에서는 가장 큰 간선을 제거한 두 집합이 Minimum하다.

그럼 원래 input graph에서도 그 간선을 제거한게 Minimum할까? 그러니까 이 알고리즘에서 걱정되는 부분은 전체 그래프를 기준으로 생성한 MST에서 가장 큰 간선을 제거한 트리보다 input 그래프에서 특정 간선을 제거하여 새로 만든 MST가 더 Minimum한 경우는 없을까? 

조금 생각을 해보자.

 

input 그래프에서 특정 간선을 제거하여 새로 만든 MST가 더 Minimum 하려면 당장 생각나는건 input 그래프에서 특정 간선을 제거하여 두 집합이 생기는 경우이다.

아래의 그림에서 간선 하나만 있다고 생각해보면 그 간선은 이미 MST에 포함되어 있고, 그래서 이 케이스는 고려할 필요가 없다. 좀 더 일반적인 경우는 어떨까?

 

input 그래프는 어떠한 형태로던 아래와 같은 두 집합으로 분리할 수 있다.

두 집합을 연결하는 간선들 중 가장 작은 간선을 제외한 간선들은 MST를 만드는 과정에서 제거되고, 가장 작은 간선은 MST에 포함되기에 위에서 말한 케이스와 같다.

 

정리하자면

  • MST는 이미 전체 그래프에서 가능한 가장 작은 가중치의 간선들로 구성되어 있다.
  • MST에서 가장 큰 간선을 제거하면, Minimum한 두 개의 트리로 분리된다.
  • 이 두 트리는 원래 그래프에서 가능한 모든 분할 중 가장 효율적인 분할이다. 왜냐하면, 제거된 간선은 두 트리를 연결하는 간선 중 가장 큰 것이므로, 다른 어떤 분할도 이보다 작은 간선을 제거할 수 없다.

증명이 좀 길었는데, 내가 수학적으로 접근한 것은 아니고 떠올린 알고리즘을 최대한 논리적으로 증명한 것이다. 좀 더 수학적이고 명확한 증명이 필요해보이지만 문제풀이에서는 그렇지 않다. 일단 이대로 짜보고, 제출해보고 틀리면 그 때 찾아보면 된다. 한 번 구현해보자!

 

크루스칼 알고리즘으로 구현한 전체 알고리즘은 다음과 같다. 

  1. 모든 간선을 오름차순으로 정렬한다.
  2. 정렬된 간선을 순회하며 두 정점이 다른 집합에 속하면 MST에 추가하고, Union 한다.
  3. 마지막으로 연결된 간선을 제거한다.

전체 코드는 다음과 같다.

/** MST 1647 도시 분할 계획 **/
#include <iostream>
#include <vector>
#include <algorithm>
#include <string>
#include <tuple>
using namespace std;

int n, m;
vector<int> p(100010, -1);
tuple<int, int, int> edge[1000010];

int find(int x)
{
    if (p[x] < 0)
        return x;
    return p[x] = find(p[x]);
}

bool is_diff_group(int u, int v)
{
    u = find(u);
    v = find(v);
    // 같은 그룹
    if (u == v)
        return false;

    if (p[u] == p[v])
        p[u]--;
    // u가 더 깊은 경우(p가 작아야 함.)
    if (p[u] < p[v])
        p[v] = u;
    else
        p[u] = v;
    return true;
}

int main()
{
    ios_base::sync_with_stdio(0);
    cin.tie(0);

    cin >> n >> m;
    for (int i = 0; i < m; i++)
    {
        int a, b, c;
        cin >> a >> b >> c;
        edge[i] = {c, a, b};
    }

    sort(edge, edge + m);
    int cnt = 0, ans = 0;

    for (int i = 0; i < m; i++)
    {
        int u, v, cost;
        tie(cost, u, v) = edge[i];
        if (!is_diff_group(u, v))
            continue;
        cnt++;
        if (cnt == n - 1)
        {
            // 마지막 iteration
            break;
        }
        ans += cost;
    }

    cout << ans;
}