[백준] C++ 1368번: 물대기 - MST 변형, 프림 알고리즘

2024. 9. 12. 14:38알고리즘/문제풀이

 

예제

4
5
4
4
3
0 2 2 2
2 0 3 3
2 3 0 4
2 3 4 0

출력 : 9

MST 문제인데 MST 문제는 구현도 어렵지만 왜 MST인지 잘 판단해야 한다. 논을 파는 방법은 직접 논에 우물을 파거나 다른 논에서 끌어오는 방법이 있다.

당장 떠오르는 방법은 그리디 하게 가장 싼 우물을 파되, 불필요하게 사이클이 생겨서는 안 되고 최소한의 개수만큼 파야 한다. 사이클이 생겨서는 안 되고 최소한의 개수만큼 파야한다는 점에서 MST를 떠올렸고 프림 알고리즘을 선택했다.

 

알고리즘은 다음과 같다.

  1. 우물을 파는 비용을 모두 우선순위 큐에 넣는다.
  2. 우선순위 큐에서 논을 하나 빼서 MST에 포함시킨다.
  3. MST에 포함된 논의 간선 중 MST가 아닌 논과 연결된 간선을 우선순위 큐에 넣는다.
  4. N개의 논에 대해 N-1개의 간선이 MST에 포함될 때 까지 3번을 반복한다.
priority_queue<tuple<int, int, int>, vector<tuple<int, int, int>>, greater<tuple<int, int, int>>> pq;

이렇게 우선순위 큐를 선언했고 1, 2를 구현해보자.

int cost_, a_, b_;
tie(cost_, a_, b_) = pq.top();
pq.pop();
chk[a_] = true;
ans += cost_;

for (int nxt = 1; nxt <= n; nxt++)
{
    if (nxt == a_)
        continue;
    pq.push({adj[a_][nxt], a_, nxt});
}

tuple 우선순위 큐를 다루는 부분이 좀 헷갈린다. nxt == a_를 체크하는 이유는 1, 1과 같은 논의 adj 값이 0이기 때문에 우선순위 큐에 넣으면 안 된다.

 

이제 3, 4번을 구현해 보자.

 while (cnt < n - 1)
    {
        int cost, a, b;
        tie(cost, a, b) = pq.top();
        pq.pop();
        if (chk[b])
            continue;

        chk[b] = true;
        ans += cost;
        cnt++;

        for (int nxt = 1; nxt <= n; nxt++)
        {
            if (!chk[nxt] && nxt != b)
                pq.push({adj[b][nxt], b, nxt});
        }
    }

a는 MST에 포함된 노드이고 b는 포함되지 않은 노드이다. 따라서  pq.push({adj [b][nxt], b, nxt}); 방식으로 넣어야 한다. 

/** MST 1368 물대기 **/
#include <iostream>
#include <vector>
#include <algorithm>
#include <string>
#include <queue>
#include <tuple>
using namespace std;
int n;
int adj[400][400];
int water[400];
bool chk[400];
int cnt = 0, ans = 0;

priority_queue<tuple<int, int, int>, vector<tuple<int, int, int>>, greater<tuple<int, int, int>>> pq;

int main()
{
    ios_base::sync_with_stdio(0);
    cin.tie(0);

    cin >> n;
    for (int i = 1; i <= n; i++)
    {
        cin >> water[i];
        pq.push({water[i], i, i});
    }

    for (int i = 1; i <= n; i++)
        for (int j = 1; j <= n; j++)
            cin >> adj[i][j];

    int cost_, a_, b_;
    tie(cost_, a_, b_) = pq.top();
    pq.pop();
    chk[a_] = true;
    ans += cost_;

    for (int nxt = 1; nxt <= n; nxt++)
    {
        if (nxt == a_)
            continue;
        pq.push({adj[a_][nxt], a_, nxt});
    }

    while (cnt < n - 1)
    {
        int cost, a, b;
        tie(cost, a, b) = pq.top();
        pq.pop();
        if (chk[b])
            continue;

        chk[b] = true;
        ans += cost;
        cnt++;

        for (int nxt = 1; nxt <= n; nxt++)
        {
            if (!chk[nxt] && nxt != b)
                pq.push({adj[b][nxt], b, nxt});
        }
    }
    cout << ans;
}