개발일지

Algorithm in A..Z - 백준 3830 교수님은 기다리지 않는다 본문

Problem Solving

Algorithm in A..Z - 백준 3830 교수님은 기다리지 않는다

강태종 2021. 4. 20. 20:30

 

www.acmicpc.net/problem/3830

 

3830번: 교수님은 기다리지 않는다

교수님의 질문 (? a b)이 입력으로 들어올 때 마다, 지금까지 측정한 결과를 바탕으로 a와 b의 무게 차이를 계산할 수 있다면, b가 a보다 얼마나 무거운지를 출력한다. 무게의 차이의 절댓값이 1,000,

www.acmicpc.net


접근

처음에 Disjoint-Set과 LCA를 고민했다.

Disjoint-Set을 고민한 이유는 간선을 추가할(샘플을 측정) 때 Disjoint-Set의 Union연산을 통해 하나의 Set으로 만들고, Find연산으로 쉽게 UNKNOWN인지 알 수 있기 때문이다.

LCA를 고민한 이유는 무게차이를 구할 때 공통 조상 노드를 찾아서 거리를 구하면 logN만에 쉽게 거리를 찾을 수 있기 때문이다.

하지만 LCA를 간선을 추가할(샘플을 측정) 때마다 업데이트 할 수 없기 때문에 아닌거 같았고(LCA를 초기화 할 때 DFS를 수행하기 때문에 간선을 추가할 때마다 DFS를 돌리면 시간초과) Disjoint-Set을 응용해서 풀기로 생각했다.


풀이

Disjoint-Set을 사용하여 UNKNOWN을 구분할 수 있다. (같은 Set이 아니면 Unknown)


거리는 distance라는 배열을 선언하고 Distance에 값은 각 인덱스와 Root의 거리를 저장하며 a와 b의 거리는 distance[a] - distance[b]로 구할 수 있고 distance배열은 Find연산과 Union연산에서 업데이트가 된다.

Find연산에서 B와 Root의 거리는 X + Y이고 X는 distance[disjoint[b]]로 구할 수 있다.

: 재귀함수를 통해 distance[disjoint[b]]를 구할 때 B의 부모 A는 이미 업데이트 된 상태이고 A가 부모를 가진다고 해도 X는 A와 Root와의 거리가 된다.

Union연산에서 b가 속한 Set을 a가 속한 Set에 합칠 때 위와 같은 과정이 발생하며 distance[B] = distance[a] - distance[b] + w(a와 b의 거리)로 구할 수 있다.


코드

C++

#include <bits/stdc++.h>
using namespace std;

vector<int> disjoint;
vector<long long> dist;

int find(int index) {
    if (disjoint[index] < 0) {
        return index;
    }

    int parent = find(disjoint[index]);

    dist[index] += dist[disjoint[index]];
    return disjoint[index] = parent;
}

void merge(int i, int j, long long dis) {
    int ii = find(i);
    int jj = find(j);

    if (ii == jj) {
        return;
    }

    disjoint[ii] += disjoint[jj];
    disjoint[jj] = ii;
    dist[jj] = dist[i] - dist[j] + dis ;
}


int main() {
    ios::sync_with_stdio(false); cin.tie(nullptr); cout.tie(nullptr);

    while (true) {
        int n, m;
        cin >> n >> m;
        if (n == 0 && m == 0) {
            break;
        }

        disjoint.assign(n + 1, -1);
        dist.assign(n + 1, 0LL);

        while (m--) {
            char op;
            cin >> op;
            if (op == '!') {
                int i, j;
                long long dis;
                cin >> i >> j >> dis;

                merge(i, j, dis);
            } else {
                int i, j;
                cin >> i >> j;

                int ii = find(i);
                int jj = find(j);

                if (ii == jj) {
                    cout << dist[j] - dist[i] << "\n";
                } else {
                    cout << "UNKNOWN" << "\n";
                }
            }
        }
    }
}

Kotlin

import java.util.*

lateinit var disjoint: IntArray
lateinit var dist: LongArray

fun find(index: Int): Int {
    if (disjoint[index] < 0) {
        return index
    }

    val root = find(disjoint[index])

    dist[index] += dist[disjoint[index]]
    disjoint[index] = root
    return root
}

fun merge(i: Int, j: Int, w: Long) {
    val ii = find(i)
    val jj = find(j)

    if (ii == jj) {
        return
    }

    disjoint[ii] += disjoint[jj]
    disjoint[jj] = ii

    dist[jj] = dist[i] - dist[j] + w
}

fun main() {
    System.`in`.bufferedReader().use { br ->
        System.out.bufferedWriter().use { bw ->
            while (true) {
                val (n, m) = br.readLine().split(" ").map { it.toInt() }
                if (n == 0 && m == 0) {
                    break
                }

                disjoint = IntArray(n + 1) { -1 }
                dist = LongArray(n + 1) { 0L }
                repeat(m) {
                    val st = StringTokenizer(br.readLine())
                    when(st.nextToken()) {
                        "!" -> {
                            val i = st.nextToken().toInt()
                            val j = st.nextToken().toInt()
                            val dis = st.nextToken().toLong()

                            merge(i, j, dis)
                        }
                        "?" -> {
                            val i = st.nextToken().toInt()
                            val j = st.nextToken().toInt()
                            val ii = find(i)
                            val jj = find(j)

                            if (ii == jj) {
                                bw.appendLine((dist[j] - dist[i]).toString())
                            } else {
                                bw.appendLine("UNKNOWN")
                            }
                        }
                    }
                }
            }
        }
    }
}

 

Comments