[BOJ] C++ 14501: 퇴사 - 백트래킹과 재귀

2023. 12. 18. 12:02알고리즘/문제풀이

 

14501번: 퇴사

첫째 줄에 백준이가 얻을 수 있는 최대 이익을 출력한다.

www.acmicpc.net


문제를 분석하다 가장 먼저 떠오른 풀이는 백트래킹이었다. N이 최대 15이므로 최대 가능한 상담이 2^15개로 계산된다. 그리고 각 경우의 수당 이익을 계산하는 것과 백트래킹 하는 과정은 상수 시간이 소모될 것 같아서 가능하다 판단하고 백트래킹으로 구현을 시작했다.

 

구현이 조금 막막하게 느껴질 수 있는데, 일단 모든 경우의 수를 판단해야 한다. 대충 트리를 만들어보자면 이런 형태를 띤다.

트리대로 구현하면서 가지치기를 해주면 될 것 같다. 이런 느낌의 완전탐색은 재귀로 구현하면 비교적 쉽다. 재귀로 구현하고 재귀식 내부에서 백트래킹을 해주는 걸로 하고 구현해 보자. 재귀를 구현할 땐 아래의 세 부분을 먼저 생각해 두고 구현하는 게 낫다.

  1. 함수 정의
    void func(int cnt) : 현재 cnt번째 상담에 대해 판단하는 함수이다.
  2. base condition
    if(cnt > n) 최대 이익 계산 후 종료
    일반적인 종료 조건은 위와 같이 설정할 수 있다. 그런데, 백트래킹의 한 부분 종료 조건으로 세팅할 수 있을 것 같다. 7일까지 일을 하는데, 5일 차에 3일짜리 상담을 하는 등의 케이스는 종료 조건으로 걸어둘 수 있다. 지금까지 재귀식에 상담을 수행했을 때 현재 며칠차인지 알아야 종료 조건을 설정할 수 있다. cur_day로 두자.
    if(cur_day > n+1) 이 케이스는 유효하지 않은 케이스이므로 그냥 종료, 위의 종료 조건보다 먼저 시행돼야 한다.
  3. 재귀식
    cnt일의 상담을 할 수 있는지 체크하고 상담을 한 경우와 하지 않은 경우로 나눠서 func(cnt+1)을 호출해 주면 된다.
    재귀식을 구현하는 부분도 헷갈려서 트리를 직접 그려보고 설정했다.


이제 설정해 둔 그대로 구현해 주면 된다! 함수 정의 부분을 꼼꼼하게 잘해놔서 구현은 쉬웠다.

/** BackTracking 14501 퇴사 **/
#include <iostream>
#include <vector>
#include <algorithm>
#include <string>
using namespace std;
int n;
int cur_day = 1, ans = 0, mx = 0;
pair<int, int> arr[20];

void func(int cnt)
{
    // base condition 1
    if (cur_day > n + 1)
    {
        return;
    }
    // base condition 2
    if (cnt > n)
    {
        ans = max(mx, ans);
        return;
    }

    // cnt일차 작업에 대해 판단
    // cnt일의 작업을 한다면 다음 작업은 cnt+arr[cnt].second+1부터 시작 가능
    // cnt일의 작업을 할 수 있는지 체크

    // 만약 지금이 2일이라 가정하면 2~7일차 작업 가능
    // 현재 스텝에선 cnt일의 작업을 할 수 있을지 체크
    // 못한다면 백트래킹
    if (cur_day <= cnt)
    {
        // cnt일의 작업 한 경우
        int tmp_day = cur_day;
        cur_day = cnt + arr[cnt].first;
        mx += arr[cnt].second;
        func(cnt + 1);

        // 다시 cnt일의 작업을 안한 경우로 롤백
        cur_day = tmp_day;
        mx -= arr[cnt].second;
        func(cnt + 1);
    }
    else
    {
        // cnt일 할 수 없음
        func(cnt + 1);
    }
}

int main()
{
    ios_base::sync_with_stdio(0);
    cin.tie(0);

    cin >> n;
    for (int i = 1; i <= n; i++)
        cin >> arr[i].first >> arr[i].second;

    func(1);
    cout << ans << '\n';
}