취미로PS하는사람

[2019.12.25] POI 2005/2006 Stage 2 3번 Subway [BOJ 8128] 본문

PS/Once apon a time

[2019.12.25] POI 2005/2006 Stage 2 3번 Subway [BOJ 8128]

def_win 2021. 12. 20. 18:33

https://www.acmicpc.net/problem/8128

나의 접근 방식은 이러하다.

 

1. 리프 노드 사이로만 길을 놔야 한다. 이는 너무 당연하다.

 

2. 가장 점이 많이 포함되도록 길을 놓았을 때 그 점들의 집합은 끊어져 있지 않다.

만약 끊어져있다면 한 쪽의 한 경로와 다른 쪽의 한 경로의 한 끝점을 맞바꾸면 떨어진 두 컴포넌트 사이의 모든 정점도 포함되기 때문에 언제나 최적의 상태에서 위 조건을 만족하지 않을 수 없다는 것을 알 수 있다.

 

위 두 조건으로부터 최적 정점 집합은 리프 노드가 총 2*k개인 서브그래프라는 것을 알 수 있다. 이 때 포함된 정점 개수를 최대화 해야 한다.

 

3. 최적 점의 집합은 트리의 지름을 포함한다.

몇 번 그려보면 왠지 그럴 것 같다는 느낌을 받을 수 있다. 지름 끝점을 포함하지 않는 집합에서 지름 끝점을 리프노드로 하고 다른 어떤 리프노드가 생기지 않도록 바꾸면 점의 개수가 증가한다는 것을 알 수 있다.

리프노드를 없앨 때 일자 형태로 없애야 한다. 이 길이가 원래 정답으로 가정했던 집합과 연결된 점으로부터 나오는데. 때문에 없애야 하는 길이는 연결점으로부터 집합의 지름 끝점사이의 거리보다 작거나 같다. 그런데 지름 성분이 연결된 곳이 이 연결점보다 한 끝점에서 같거나 큰 거리만큼 떨어져있다면 지름 성분의 길이는 무조건 방금 언급한 리프를 없애기 위한 길이보다 크다. 즉, 이 리프를 없애고 지름성분을 추가하면 점 개수는 같거나 커진다. 만약 그러한 리프가 없다면 지름 성분과의 연결점부터 집합의 지름 끝점까지 없애고 지름성분을 추가하면 된다. 때문에 무조건 지름의 끝점을 포함한다.

 

이제 이를 이용하여 지름의 한 쪽 끝점에서 그리디하게 2*k-1개의 리프를 고르면 된다. 이 때 고르는 방법은 우선 가장 멀리 있는 것을 선택하고, 선택한 정점들로부터 가장 멀리 있는 것을 선택하고.. 를 반복하면 된다. 그런데 이게 너무 어려웠다... 처음에는 트리에서 큰 거 작은 거로 합치는 방식을 썼지만 POI 특유의 메모리 제한 덕분에 MLE를 받고 광광 울었다. 다른 블로그도 참고해보고 생각해보니 현재 정점에서 가장 멀리 있는 정점까지의 거리만 +1해주고 나머지는 거기서 멈추면 되기 때문에 트리 DP?처럼 하면 각 거리에 따라 선택할 수 있는 개수를 구할 수 있다. 중간에 끊겨 있는 길도 있을텐데, 그러한 것들을 선택하기 이전에 그 것과 연결된 길을 선택하는 것이 이득이므로 걱정하지 않아도 된다.

 

코드

더보기
#include <bits/stdc++.h>
#define fi first
#define se second
#define eb emplace_back
#define em emplace
#define all(v) v.begin(), v.end()
 
using namespace std;
typedef long long ll;
typedef pair <int, int> pii;
typedef pair <ll, ll> pll;
 
const int MAX = 1010101;
const int INF = 1 << 30;
const ll LINF = 1LL << 60;
 
vector <int> g[MAX];
int d[MAX], idx[MAX], cnt[MAX];
 
void dfs(int x, int pa) {
    d[x] = d[pa] + 1;
    for(auto i : g[x]) {
        if(i == pa) continue;
        dfs(i, x);
    }
}
 
int ndfs(int x, int pa) {
    int mx = 0;
    for(auto i : g[x]) {
        if(i == pa) continue;
        int temp = ndfs(i, x);
        mx = max(mx, temp);
        cnt[temp]++;
    }
    if(pa != 0) cnt[mx]--;
    return mx + 1;
}
 
int main() {
    ios::sync_with_stdio(false); cin.tie(0);
 
    int n, l;
    cin >> n >> l;
    for(int i = 1; i < n; i++) {
        int u, v;
        cin >> u >> v;
        g[u].eb(v), g[v].eb(u);
    }
    if(l == 0) return !(cout << 0);
    
    int d1 = 0, d2 = 0, ans = 0;
    dfs(1, 0);
    for(int i = 1; i <= n; i++) {
        idx[i] = i;
        if(d[d1] < d[i]) d1 = i;
    }
    ndfs(d1, 0);
    l = 2 * l - 1;
    for(int i = n; i >= 1; i--) {
        if(cnt[i] <= l) ans += i * cnt[i], l -= cnt[i];
        else {
            ans += i * l;
            break;
        }
    }
    cout << ans + 1;
}​
Comments