취미로PS하는사람

[BOJ] Journey to TST set2 본문

PS/BOJ

[BOJ] Journey to TST set2

def_win 2024. 2. 13. 12:57

저번에 이어서 Journey to TST set2를 풀려고 한다.

 

https://psforhobby.tistory.com/57

 

[BOJ] Journey to TST set1

최근에 PS를 다시 하면서 뭘 풀어야 할지 모르겠다는 생각이 들었다. 둘러보다가 고등학교 친구가 만든 인생추천셋이 있길래 차근차근 풀어보려고 한다. https://dsstar.tistory.com/53 Journey to TST 문제

psforhobby.tistory.com

https://dsstar.tistory.com/53

 

Journey to TST 문제 셋 공유

BOJ 그룹(https://www.acmicpc.net/group/16794)에서 2주 동안 총 44(+2) 문제의 연습 셋을 공유했다. 11일 간 매일 선발고사 난이도에 맞추어 4문제씩 셋을 만들었다. 모든 문제에서 얻어갈 내용이 있으며, 내 9

dsstar.tistory.com

02/08의 문제 셋이고, 이번에도 내가 푼 순서대로 리뷰하려고 한다. 이번엔 좀 오래 걸렸다. 일주일 정도 푼 듯.

 

1. Towns and Roads(20640)

https://www.acmicpc.net/problem/20640

역시 세터의 의도답게 교육적인 문제이다. 조금 뻔하긴 해도 꼭 알아야 하는..

 

문제

간선마다 거리가 있는 크기 $N$의 트리가 주어지고 로봇은 초기에 1번 정점에 있다. 3종류의 쿼리를 총 $Q$개 처리해야 한다.

1 $x$: 로봇을 $x$번 정점으로 이동시킨다. 이 때 $x$는 이전 로봇의 정점과 인접한 정점이다.

2 $y$: $y$번 간선을 무효화한다.

3 $z$: $z$번 간선을 유효화한다.

 

각 쿼리 수행 후 가장 멀리 떨어진 정점의 목록을 출력하라.

$1 \leq N, Q \leq 2 \times10^5$이고 답의 총 개수는 $4 \times10^5$ 이하이다.

 

풀이

더보기

1번 정점을 루트로 오일러 투어 테크닉을 쓰고, 초기 세그를 각 정점마다 1번 정점으로부터 거리로 구성하자. 이후 각 쿼리를 다음과 같이 수행하면 항상 세그는 현재 로봇이 존재하는 정점으로부터 각 정점까지의 거리로, 끊어진 경우 음수로 구성되어 있다.

 

1. 1번 정점을 루트로 생각했을 때, 쿼리 행동의 결과 로봇이 올라가거나 내려갈 것이다. 올라간다면 간선의 가중치만큼 현재 정점의 서브트리 전체에 더하고 나머지에 빼준다. 내려간다면 이동하려는 정점에 대해 더하고 빼는 행동을 반대로 해주면 된다.

 

2. 끊는 간선이 연결하는 두 정점 중 자식쪽 서브트리에 현재 로봇이 있다면 그 서브트리를 제외한 정점들에 어떤 큰 상수를 빼고 그렇지 않다면 서브트리 정점들에 큰 상수를 뺀다.

 

3. 2에서 큰 상수를 빼는 대신 더하면 된다. (이렇게 하면 정확하게 복구가 가능하다.)

 

이후 가장 먼 정점의 목록은 해당 구간의 최댓값이 전체 최댓값과 일정 할 때만 내려가는 식으로 해서 찾아주면 된다. 높이가 $O(\log N)$이므로 총 답 개수$\times O(\log N)$밖에 안 걸린다. 쿼리 처리하는 시간복잡도도 동일. (초기 구성은 $O(N\log N)$)

 

구현이 약간 귀찮긴 하다. 그리고 큰 상수를 뺄 때 $10^{18}$같은 너무 큰 수를 빼면 오버플로우 날 수 있다. 주의하기.

 

코드

더보기
#include <bits/stdc++.h>
#define fi first
#define se second
#define eb emplace_back
#define all(v) (v).begin(), (v).end()

using namespace std;
typedef long long ll;
typedef pair <int, int> pii;
typedef pair <ll, ll> pll;

const ll LINF = 1e13;

struct SEG {
    vector <ll> tree, lazy;
    int sz;
    SEG(int n) : sz(n) {
        tree.resize(4*n+10);
        lazy.resize(4*n+10);
    }

    void update_lazy(int node, int s, int e) {
        tree[node] += lazy[node];
        if(s != e) {
            lazy[node<<1] += lazy[node];
            lazy[node<<1|1] += lazy[node];
        }
        lazy[node] = 0;
    }

    void update(int l, int r, ll val) {
        update(1, 1, sz, l, r, val);
    }
    void update(int node, int s, int e, int l, int r, ll val) {
        update_lazy(node, s, e);
        if(s > r || e < l) return;
        if(s >= l && e <= r) {
            lazy[node] += val;
            update_lazy(node, s, e);
            return;
        }
        int m = s + e >> 1;
        update(node<<1, s, m, l, r, val);
        update(node<<1|1, m+1, e, l, r, val);
        tree[node] = max(tree[node<<1], tree[node<<1|1]);
    }

    vector <int> largest(vector <int> &rvs) {
        update_lazy(1, 1, sz);
        ll val = tree[1];
        vector <int> ret;
        get(1, 1, sz, val, ret, rvs);
        sort(all(ret));
        return ret;
    }
    void get(int node, int s, int e, ll val, vector <int> &ret, vector <int> &rvs) {
        if(s == e) {
            if(tree[node] == val) ret.eb(rvs[s]);
            return;
        }
        int m = s + e >> 1;
        update_lazy(node<<1, s, m);
        update_lazy(node<<1|1, m+1, e);
        if(tree[node<<1] == val) get(node<<1, s, m, val, ret, rvs);
        if(tree[node<<1|1] == val) get(node<<1|1, m+1, e, val, ret, rvs);
    }
};

void dfs(int x, int pa, int &num, vector <int> &in, vector <int> &out, vector <vector <pll>> &g, SEG &ST) {
    in[x] = ++num;
    for(auto i : g[x]) {
        if(i.fi == pa) continue;
        dfs(i.fi, x, num, in, out, g, ST);
        ST.update(in[i.fi], out[i.fi], i.se);
    }
    out[x] = num;
}

int main() {
    ios::sync_with_stdio(false); cin.tie(0);

    int n;
    cin >> n;
    vector <vector <pll>> g(n+1);
    vector <pii> E(n+1);
    vector <int> in(n+1), out(n+1), rvs(n+1);
    for(int i = 1; i < n; i++) {
        int u, v; ll d;
        cin >> u >> v >> d;
        g[u].eb(v, d), g[v].eb(u, d);
        E[i] = pii(u, v);
    }
    for(int i = 1; i <= n; i++) {
        sort(all(g[i]));
    }
    SEG ST(n);
    int num = 0;
    dfs(1, 0, num, in, out, g, ST);
    for(int i = 1; i <= n; i++) {
        rvs[in[i]] = i;
    }

    int q;
    int cur = 1;
    cin >> q;
    while(q--) {
        int t, x;
        cin >> t >> x;
        if(t == 1) {
            ll val = lower_bound(all(g[cur]), pll(x, 0))->se;
            if(in[x] < in[cur]) {
                ST.update(in[cur], out[cur], val);
                ST.update(1, in[cur]-1, -val);
                ST.update(out[cur]+1, n, -val);
            }
            else {
                ST.update(in[x], out[x], -val);
                ST.update(1, in[x]-1, val);
                ST.update(out[x]+1, n, val);
            }
            cur = x;
        }
        if(t == 2) {
            int u = E[x].fi, v = E[x].se;
            if(in[u] > in[v]) swap(u, v); // now p[v] = u
            if(in[v] <= in[cur] && out[cur] <= out[v]) {
                ST.update(1, in[v]-1, -LINF);
                ST.update(out[v]+1, n, -LINF);
            }
            else ST.update(in[v], out[v], -LINF);
        }
        if(t == 3) {
            int u = E[x].fi, v = E[x].se;
            if(in[u] > in[v]) swap(u, v); // now p[v] = u
            if(in[v] <= in[cur] && out[cur] <= out[v]) {
                ST.update(1, in[v]-1, LINF);
                ST.update(out[v]+1, n, LINF);
            }
            else ST.update(in[v], out[v], LINF);
        }
        vector <int> ans = ST.largest(rvs);
        cout << ans.size() << ' ';
        for(int i : ans) cout << i << ' ';
        cout << '\n';
    }
}

 

2. Treatment Project(18885)

https://www.acmicpc.net/problem/18855

처음에 잘못 생각해서 거의 다 풀어놓고 오래 고민했다. 쉽게 생각하면 쉬운 문제인 듯.

 

문제

$N$개의 집에 각각 사람이 살고 있고 바이러스에 감염된 상태이다. 총 $M$개의 치료 계획이 있으며 $i$번째 치료 계획은 $T_i$날 저녁에 $[L_i, R_i]$ 구간의 사람들을 치료하며 비용은 $C_i$이다. 하루가 지날 때마다 아침에는 감염되지 않은 사람들 중 인접하며 감염된 사람이 있는 사람들은 다시 감염된다. 모든 사람들을 치료하기 위한 최소 비용을 구하고 없을 경우 $-1$을 출력하라.

$1 \leq N, T_i, C_i \leq 10^9$, $1 \leq M \leq 10^5$.

 

풀이

더보기

일단 $x$축을 사람, $y$축을 시간으로 두고 한 치료 계획을 직선으로 나타낸 뒤 이 치료가 2차원 상에서 어떤 영역을 커버하는지 생각해보자. 45도 기울어진 정사각형 모양임을 알 수 있다. $1$번이나 $N$번을 포함하는 치료 계획은 무한히 큰 것으로 생각할 수 있다.  (과거까지 생각) 2개의 정사각형을 합쳤을 때 그냥 합치면 안되고 양옆 경계선을 보며 추가로 연장선을 그어야 할 수도 있다는 것을 고려해야 하는데 이것만 고려한다면 결국 문제는 주어진 정사각형들을 잘 합쳐서 1번 집이 나타내는 직선으로부터 $N$번 집이 나타내는 직선까지 최소 비용으로 연결하는 문제가 된다.

 

이 모델링을 보면 쉬운 다익스트라 같지만 사각형을 합치며 다른 사각형들과의 연결 가능성이 바뀌어 까다롭다. (그래서 내가 처음에 헤맸다.)

 

생각을 바꿔 어차피 1번부터 빠짐없이 다 채워야 하니 1번 집부터 출발한다고 생각을 해보자. 편의를 위해 반시계 방향으로 45도 적당히 돌릴 것이다. (정확히는 두 좌표를 더하고 빼서 변환해야 정수로 계산 가능) 이제 각 구간은 우상향하는 선분들로 바뀌며, 여기서 왼쪽 아래 점을 시작점, 오른쪽 위 점을 끝점이라고 하겠다.

 

1번부터 차례대로 모두 치료한다고 생각을 해보면, 현재 구간 다음으로 고를 수 있는 구간은 45도 적당히 돌린 그래프에서 현재 구간의 끝점 기준 3사분면에 출발점이 있는 것들만 가능하다.(정확히는 구간끼리 1 차이를 허용하기 때문에 끝점 좌표에 1씩 더한 점 기준) 

 

이제 비용이 모두 양수인 것으로부터 다익스트라의 철학을 이용하면, 결국 현재 탐색한 구간들 중 도달 비용이 최소인 구간 다음으로 가능한 구간들은 최적비용이 (현재 구간 도달 비용) + (해당 구간 비용)이 된다. 그보다 비용이 작은 경우는 이후에 나올 수 없기 때문이다. 이렇게 모든 구간을 탐색한 후 $N$번을 채울 수 있는 구간들에 대해 도달 비용의 최소를 출력하면 된다. (그런 구간에 도달 못했으면 -1)

 

어떤 구간의 다음 구간 후보들의 경우 좌표압축+세그 비스무리한 것으로 각각 $O(\log M)$만에 추려낼 수 있다. (코드 참고) 다만 각 구간이 중복으로 꺼내질 수 있는데 적당히 체크해서 이미 꺼냈으면 무시해야 한다. 난 $x$좌표에 대해 $y$좌표를 내림차순으로 넣고 뒤에서부터 꺼내는 식으로 구현했다. (다음 구간 후보들을 모두 꺼내야 되므로 $x$좌표의 일정 범위에서 특정 값보다 $y$가 작은 모든 점을 꺼내야 해서 이렇게 했다.)

 

풀고 나서 알고리즘 태그를 보니 머지소트 트리로 보는 게 맞는 것 같기도 하고...

 

종합하면 총 시간 $O(M\log M)$에 답을 해결할 수 있게 된다. 

 

+) 그림을 그려가며 이해하면 쉽다. (그림을 첨부하면 좋겠지만 글쓴이가 귀찮은 관계로 pass)

 

코드

더보기

 

#include <bits/stdc++.h>
#define fi first
#define se second
#define eb emplace_back
#define em emplace
#define all(v) v.begin(), v.end()

using namespace std;
typedef long long ll;
typedef pair <int, int> pii;
typedef pair <ll, ll> pll;

const int INF = 1e9;
const ll LINF = 1e18;

struct SEG {
    vector <vector <pii>> tree;
    int sz;
    SEG(int n) {
        for(sz = 1; sz < n; sz <<= 1) ;
        tree.resize(sz<<1);
    }

    void update(int x, pii val) {
        int node = x + sz;
        tree[node].eb(val);
        while(node >>= 1) tree[node].eb(val);
    }

    vector <int> get(int x, int y) {
        vector <int> ret;
        for(int l = sz, r = x + sz; l <= r; l >>= 1, r >>= 1) {
            if(~r & 1) {
                while(!tree[r].empty() && tree[r].back().fi <= y) {
                    ret.eb(tree[r].back().se);
                    tree[r].pop_back();
                }
                r--;
            }
        }
        return ret;
    }
};

struct Line {
    int x, y, nx, ny, c;
    ll ans = LINF;
    bool st = false, ed = false;
};

int main() {
    ios::sync_with_stdio(false); cin.tie(0);

    int n, m;
    cin >> n >> m;
    vector <Line> v(m);
    vector <int> X;
    for(int i = 0; i < m; i++) {
        int t, l, r;
        cin >> t >> l >> r >> v[i].c;
        v[i].x = l - t, v[i].y = l + t, v[i].nx = r + 1 - t, v[i].ny = r + 1 + t;
        if(l == 1) v[i].st = true, v[i].ans = v[i].c;
        if(r == n) v[i].ed = true;
        X.eb(v[i].x), X.eb(v[i].nx);
    }
    
    sort(all(X));
    X.erase(unique(all(X)), X.end());
    sort(all(v), [](Line a, Line b) {
        return a.y > b.y;
    });

    SEG ST(X.size());
    priority_queue <pll, vector <pll>, greater <pll>> pq;
    for(int i = 0; i < m; i++) {
        v[i].x = lower_bound(all(X), v[i].x) - X.begin();
        v[i].nx = lower_bound(all(X), v[i].nx) - X.begin();
        if(v[i].st) pq.em(v[i].ans, i);
        else ST.update(v[i].x, pii(v[i].y, i));
    }

    while(!pq.empty()) {
        int cur = pq.top().se;
        pq.pop();
        vector <int> next = ST.get(v[cur].nx, v[cur].ny);
        for(int i : next) {
            if(v[i].ans < LINF) continue;
            v[i].ans = v[cur].ans + v[i].c;
            pq.em(v[i].ans, i);
        }
    }

    ll ans = LINF;
    for(int i = 0; i < m; i++) {
        if(v[i].ed) ans = min(ans, v[i].ans);
    }
    if(ans == LINF) cout << "-1\n";
    else cout << ans << '\n';
}

 

3. The Firm Knapsack Problem(21644)

https://www.acmicpc.net/problem/21644

난 이게 2번보다 훨씬 어렵다고 느끼긴 했다.. 도저히 모르겠어서 답지 까고 풀었음.

 

문제

일반적인 $N$개의 보석(무게 $w_i$, 비용 $c_i$)과 $W$의 용량이 주어진 냅색 문제에서, 최적해보다 비용 합을 크거나 같게 하며 용량이 $\frac{3}{2}W$ 이하가 되도록 하는 보석 조합을 출력하라.

$1 \leq N \leq 10^5$, $1 \leq W \leq 10^{12}$

 

풀이

더보기

우선 $\frac{W}{2}$보다 큰 보석이 최대 1개밖에 들어갈 수 없음을 관찰하자. 이제 각 보석들을 $\frac{W}{2}$보다 큰 것들과 아닌 것들로 나누자.(다만 무게가 $W$보다 큰 보석들은 버린다.) $\frac{W}{2}$보다 무게가 큰 보석들 각각을 고른 경우(또는 고르지 않은 경우)에 대해 $\frac{W}{2}$보다 무게가 작은 보석들을 적당히 고를 것이다.

 

우리가 실제 최적해와 동일하게 $\frac{W}{2}$보다 무게가 큰 보석을 골랐다고 해보자. (또는 고르지 않았다고 해보자.) 앞으로 고를 보석들은 무게가 $\frac{W}{2}$ 이하이다. 이제 다음 사실을 증명할 것이다.

 

남은 보석들 중 $\frac{cost_i}{w_i}$가 큰 순서대로 되는만큼 고르면 최적해보다 비용 합이 크거나 같다.

 

증명) 최적해 집합을 $opt$라고 하고, $j \in opt$가 위에서 고르지 않은 보석 중 가장 $\frac{cost}{w}$가 큰 보석이라 하자. $w_j \leq \frac{W}{2}$이므로 현재 고른 총 무게를 $W_{cur}$이라 할 때 $W_{cur} + w_j > \frac{3}{2}W$, 즉 $W_{cur}>W$가 된다. 현재 골라진 모든 보석들의 무게당 비용은 최적해에서 고르지 않은 모든 보석들의 $\frac{cost}{w}$보다 크고, 총 고른 비용도 $W$보다 크기 때문에 최적해보다 현재 고른 보석들의 비용 합이 더 크다는 것을 알 수 있다. 따라서 이렇게 고른 경우 언제나 최적해보다 비용 합이 크거나 같게 된다.

 

코드

더보기
#include <bits/stdc++.h>
#define fi first
#define se second
#define eb emplace_back
#define em emplace
#define all(v) v.begin(), v.end()

using namespace std;
typedef long long ll;
typedef pair <int, int> pii;
typedef pair <ll, ll> pll;

const int INF = 1e9;
const ll LINF = 1e18;

struct G {
    ll w, c; int idx;
    G(ll _w, ll _c, int _idx) : w(_w), c(_c), idx(_idx) {}
};

void solve() {
    int n; ll w;
    cin >> n >> w;
    vector <G> b, s;
    for(int i = 0; i < n; i++) {
        ll tw, tc; cin >> tw >> tc;
        if(tw <= w / 2) s.eb(tw, tc, i + 1);
        else if(tw <= w) b.eb(tw, tc, i + 1);
    }
    sort(all(b), [](G x, G y) {
        return x.w < y.w;
    });
    sort(all(s), [](G x, G y) {
        if(x.c * y.w == y.c * x.w) return x.c > y.c;
        return x.c * y.w > y.c * x.w;
    });

    w = (3 * w) / 2;
    int idx = 0, aidx = -1;
    ll ans = 0, sum = 0, csum = 0;
    while(idx < s.size() && sum + s[idx].w <= w) {
        sum += s[idx].w;
        csum += s[idx].c;
        idx++;
    }
    ans = csum;
    for(int i = 0; i < b.size(); i++) {
        while(idx > 0 && sum + b[i].w > w) {
            idx--;
            sum -= s[idx].w;
            csum -= s[idx].c;
        }
        if(ans < csum + b[i].c) {
            ans = csum + b[i].c;
            aidx = i;
        }
    }
    
    if(aidx < 0) {
        while(idx < s.size() && sum + s[idx].w <= w) {
            sum += s[idx].w;
            idx++;
        }
        cout << idx << '\n';
        for(int i = 0; i < idx; i++) {
            cout << s[i].idx << ' ';
        }
        cout << '\n';
    }
    else {
        while(idx < s.size() && sum + s[idx].w + b[aidx].w <= w) {
            sum += s[idx].w;
            idx++;
        }
        cout << idx + 1 << '\n';
        cout << b[aidx].idx << ' ';
        for(int i = 0; i < idx; i++) {
            cout << s[i].idx << ' ';
        }
        cout << '\n';
    }
}

int main() {
    ios::sync_with_stdio(false); cin.tie(0);

    int q;
    cin >> q;
    while(q--) {
        solve();
    }
}

 

4. 트리와 쿼리 14(18372)

https://www.acmicpc.net/problem/18372

옛날에 어디서 본 적 있는 문제라고 생각했는데 친구한테 물어보니 계절학교에 나왔던 문제라고 한다. 처음에 잘못 풀었다가 친구에게 2개를 합치면 된다는 힌트를 받고 풀어내었다. 풀고 보니 쉬운 문제로 느껴지는데 아이디어를 잘못 잡기가 쉬워서 어려운 듯.

 

문제

정점이 $N$개 있고 간선 가중치가 모두 1인 트리가 주어진다. 다음 쿼리를 수행해야 한다.

$k v_1 r_1 \cdots v_k r_k$: 정점 $v_i$로부터 거리 $r_i$ 이내에 있다는 것을 $i$번 조건을 만족한다고 할 때, 주어진 $k$개의 조건 중 $k-1$개 이상의 조건을 만족하는 정점의 개수를 출력하라.

$1 \leq N \leq 10^5$, $\Sigma k \leq 3 \times 10^5$

 

풀이

더보기

정점들을 잇는 간선 사이에 정점들을 추가해서 총 $2N-1$ 크기의 트리를 새로 구성하자. 이제 주어진 조건들을 적당히 바꿔 이 트리에서 생각한다면 두 조건의 교집합은 결국 새로운 하나의 조건으로 표현할 수 있음을 알 수 있다. (하나가 다른 것에 포함되는 경우, 교집합이 없는 경우, 둘 다 아닌 경우로 나눠 생각해보면 자명하다.)

 

이제 포함과 배제를 할 것이다. $k$개의 조건에 대해 누적으로 오른쪽/왼쪽부터 합친 결과를 구했다면 $i$번째 조건을 제외한 조건을 구할 수 있다. ($i$번째를 제외한 조건들을 만족하는 정점 수)를 모두 더하고 (모든 조건들을 만족하는 정점 수) $\times (k-1)$을 빼주면 답이라는 것을 알 수 있다.

 

이제 한 조건을 만족하는 기존 트리의 정점 개수를 빠르게 세기만 하면 된다. 이는 센트로이드 디컴퍼지션으로 가능하다. 각 센트로이드에 대해 그 센트로이드가 관할하는 영역(?)에서 거리별로 기존 트리 정점 개수를 누적합(총 누적합이라 하겠다)으로 구해주고 또 센트로이드로 분할되는 각 서브트리들에 대해서도 구해주자(개별 누적합이라 하겠다). 만약 현재 센트로이드 정점이 조건의 정점과 동일하다면 그냥 총누적합을 더해주기만 하면 된다. 아니라면 조건의 정점이 속한 서브트리를 알아내고, 조건의 정점과 현재 센트로이드 정점 사이의 거리를 구해주자. 이 거리가 조건의 거리보다 크거나 같을 경우 이 거리에 해당하는 총누적합에서 조건의 정점이 포함된 서브트리의 개별 누적합을 뺀 것을 더해준 뒤 그 서브트리로 내려가서 동일한 과정을 반복하면 답이 구해진다. 따라서 $O(\log N)$만에 구할 수 있다.

 

합칠 때 lca를 이용해야 되고 센트로이드 디컴퍼지션에 $O(N\log N)$이 필요해 총 시간복잡도는 $O((N+\Sigma k)\log N)$이다.

 

센트로이드 디컴퍼지션이라 구현을 열심히 잘 해야 된다. 근데 오랜만에 했는데도 생각보다 깔끔하게 잘 한 듯.

이 문제 역시 그림으로 그려가며 생각하면 이해가 쉬울 것이다.

 

코드

더보기

 

#include <bits/stdc++.h>
#define fi first
#define se second
#define eb emplace_back
#define em emplace
#define all(v) v.begin(), v.end()

using namespace std;
typedef long long ll;
typedef pair <int, int> pii;
typedef pair <ll, ll> pll;

const int MAX = 202020;
const int INF = 1e9;
const ll LINF = 1e18;

int sz[MAX], p[18][MAX], in_ct[20][MAX], d_ct[18][MAX], dep[MAX], rt;
bool chk[MAX], nchk[MAX];
vector <int> g[MAX];
vector <int> dcnt[MAX], dcnt_tot[MAX];

void dfs(int x, int pa) {
    sz[x] = 1;
    p[0][x] = pa;
    dep[x] = dep[pa] + 1;
    for(int i = 1; i < 18; i++) {
        p[i][x] = p[i-1][p[i-1][x]];
    }
    for(int i : g[x]) {
        if(i == pa) continue;
        dfs(i, x);
        sz[x] += sz[i];
    }
}

int find_cen(int x) {
    int csz = 1, mx = 0, y = 0;
    for(auto i : g[x]) { 
        if(chk[i]) continue;
        csz += sz[i];
        if(sz[i] > mx) {
            mx = sz[i];
            y = i;
        }
    }
    if(mx <= csz / 2) return x;
    sz[x] = csz - mx;
    return find_cen(y);
}

void ndfs(int x, int &pa, int &cen, int &y, int &lv, int d) {
    in_ct[lv][x] = y;
    d_ct[lv][x] = d;
    if(dcnt[y].size() == d) dcnt[y].eb(0);
    if(dcnt_tot[cen].size() == d) dcnt_tot[cen].eb(0);
    if(x & 1) {
        dcnt[y][d]++;
        dcnt_tot[cen][d]++;
    }
    for(int i : g[x]) {
        if(chk[i] || i == pa) continue;
        ndfs(i, x, cen, y, lv, d + 1);
    }
}

int make_ct(int x, int lv = 0) {
    x = find_cen(x);
    chk[x] = true;
    dcnt_tot[x].eb(x & 1);
    for(auto i : g[x]) {
        if(chk[i]) continue;
        int y = make_ct(i, lv + 1);
        dcnt[y].eb(0);
        ndfs(i, x, x, y, lv, 1);
        for(int j = 1; j < dcnt[y].size(); j++) {
            dcnt[y][j] += dcnt[y][j-1];
        }
    }
    for(int j = 1; j < dcnt_tot[x].size(); j++) {
        dcnt_tot[x][j] += dcnt_tot[x][j-1];
    }
    chk[x] = false;
    return x;
}

int query(int x, pii val, int lv = 0) {
    if(x == val.fi) return dcnt_tot[x][min((int)dcnt_tot[x].size()-1, val.se)];
    int y = in_ct[lv][val.fi];
    int ans = 0, td = val.se - d_ct[lv][val.fi];
    if(td >= 0) {
        ans = dcnt_tot[x][min((int)dcnt_tot[x].size()-1,td)]
                 - dcnt[y][min((int)dcnt[y].size()-1, td)];
    }
    return ans + query(y, val, lv + 1);
}
int query(pii val) {
    if(val.se < 0) return 0;
    return query(rt, val);
}

int lca(int u, int v) {
    if(dep[u] < dep[v]) swap(u, v);
    for(int i = 17; i >= 0; i--) {
        if(dep[u] - dep[v] >= (1 << i)) u = p[i][u];
    }
    if(u == v) return u;
    for(int i = 17; i >= 0; i--) {
        if(p[i][u] != p[i][v]) u = p[i][u], v = p[i][v];
    }
    return p[0][u];
}

int kth(int x, int k) {
    for(int i = 0; i < 18; i++) {
        if(k >> i & 1) x = p[i][x];
    }
    return x;
}

pii merge(pii x, pii y) {
    int l = lca(x.fi, y.fi);
    int dist = dep[x.fi] + dep[y.fi] - 2 * dep[l];
    if(dist + x.se <= y.se) return x;
    if(dist + y.se <= x.se) return y;
    if(y.se + x.se < dist) return pii(x.fi, -INF);
    int xd = (x.se + dist - y.se) / 2;
    int cen;
    if(xd <= dep[x.fi] - dep[l]) cen = kth(x.fi, xd);
    else cen = kth(y.fi, dist - xd);
    return pii(cen, (x.se + y.se - dist) / 2);
}

int main() {
    ios::sync_with_stdio(false); cin.tie(0);
    
    int n;
    cin >> n;
    for(int i = 1; i < n; i++) {
        int u, v; cin >> u >> v;
        u = 2 * u - 1, v = 2 * v - 1;
        g[u].eb(2*i);
        g[2*i].eb(u);
        g[2*i].eb(v);
        g[v].eb(2*i);
    }
    
    dfs(1, 0);
    rt = make_ct(1);
    
    int q;
    cin >> q;
    while(q--) {
        int k;
        cin >> k;
        vector <pii> v(k), L(k), R(k);
        for(int i = 0; i < k; i++) {
            cin >> v[i].fi >> v[i].se;
            v[i].fi = 2 * v[i].fi - 1;
            v[i].se *= 2;
        }
        if(k == 1) {
            cout << n << '\n';
            continue;
        }
        L[0] = v[0], R[k-1] = v[k-1];
        for(int i = 1; i < k; i++) {
            L[i] = merge(L[i-1], v[i]);
            R[k-i-1] = merge(R[k-i], v[k-i-1]);
        }
        ll ans = query(L[k-2]) + query(R[1]);
        for(int i = 1; i + 1 < k; i++) {
            ans += query(merge(L[i-1], R[i+1]));
        }
        ans -= (ll)(k - 1) * query(R[0]);
        cout << ans << '\n';
    }
    
    return 0;
}

 

'PS > BOJ' 카테고리의 다른 글

[BOJ] Journey to TST set4  (0) 2024.03.26
[BOJ] Journey to TST set3  (0) 2024.03.02
[BOJ] Journey to TST set1  (0) 2024.02.04
[BOJ 17694] Sparklers  (0) 2022.01.27
[BOJ 9022] Stains  (0) 2022.01.26
Comments