본문 바로가기
알고리즘/자료구조

[백준 2042] 구간 합 구하기 / 자바 / 자료구조(세그먼트 트리)

by 순원이 2024. 8. 7.

#문제         

레벨: G1
알고리즘: 자료구조(세그먼트 트리) 

풀이시간: 1시간
힌트 참조 유무: 유

https://www.acmicpc.net/problem/2042


#문제 풀이        

문제를 처음 봤을 때 누적합을 떠올렸다. 앞서 말하지만 누적합은 정답이 아니다. 배열의 1 ~i 까지의 합을 적어 놓은 것이다. 그러면 A ~ B 까지의 합을 구할 때는 A ~ B 까지 순회하는 것이 아닌 arr[B] - arr[A]를 하면 되는 것이다. 그러나 여기서 문제인 것은 값을 바꿀 때이다. 만약 2 번째 숫자를 2 -> 5를 바꿨다면 2 번째 뒤에 있는 값들은 다 업데이트 해야 한다. 왜냐하면 1 ~ i 까지의 숫자합이기 때문에 새로 바뀐 2의 숫자로 더해줘야 하기 때문이다. 답을 출력하는 것 O(1)이지만, 값을 업데이트 하는 건 O(n)이다. 최대 O(N x M) 이기 때문에 O(10,000,000,000) = 백억이므로 실패다. 

이 문제에서 알아야 할 것은 세그먼트 트리이다. 

세그먼트 트리에 대해서 모른다면 아래 글을 보고 오기 바란다. 이 문제는 세그먼트 트리를 구현하기만 하면 되는 문제이기 때문이다.
https://blog.naver.com/ndb796/221282210534

 

41. 세그먼트 트리(Segment Tree)

이번 시간에 다룰 내용은 여러 개의 데이터가 연속적으로 존재할 때 특정한 범위의 데이터의 합을 구하는 ...

blog.naver.com

난 처음 코드를 봤을 때 이해가 안됐었다. 노드 번호, start, end가 무슨 연관관계가 있지? 위 글을 보고 왔다면 어느 정도 이해는 했겠지만, 내가 쉽게 설명해보겠다. 세그먼트 트리의 노드번호는 무시해도 된다. 단 이걸 알고 있어야 한다. 왼쪽 자식 노드의 노드 번호는 해당 노드의 *2이고 오른쪽 자식 노드의 노드 번호는 해당 노드의 *2+1이다. 해당 노드의 값이 배열 0 ~ 5까지의 합이라면  왼쪽 자식 노드의 값은 0~ 2의 합이고 오른쪽 자식 노드는 3 ~ 5의 합이다. 더 극단적으로 쉽게 이야기 하자면 노드 번호는 자식 노드로 내려가기 위해 존재하는 거다.

또 하나 말하자면 sum함수의 두 번째 if문에 대해서 설명하겠다. 6개의 숫자 배열이 있고 2(3번째 숫자) ~ 5(6번째 숫자) 의 배열을 더해야 한다면, (2~2) (3~5)가 선택 되는 것이다. 

빨간색: A~B까지의 합을 나타낸다.

파란색: 노드번호


#풀이 코드      

import java.io.*;
import java.util.*;

public class Main {
    static long[] tree;
    static int N;

    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(System.out));
        StringTokenizer st = new StringTokenizer(br.readLine());

        N = Integer.parseInt(st.nextToken());
        int M = Integer.parseInt(st.nextToken());
        int K = Integer.parseInt(st.nextToken());

        // 세그먼트 트리의 크기를 계산
        int h = (int) Math.ceil(Math.log(N) / Math.log(2));
        int treeSize = (1 << (h + 1));
        tree = new long[treeSize];

        // 초기 배열 입력 받기
        for (int i = 0; i < N; i++) {
            long num = Long.parseLong(br.readLine());
            update(1, 0, N - 1, i, num);
        }

        // 쿼리 처리
        for (int i = 0; i < M + K; i++) {
            st = new StringTokenizer(br.readLine());
            int a = Integer.parseInt(st.nextToken());
            int b = Integer.parseInt(st.nextToken());
            long c = Long.parseLong(st.nextToken());

            if (a == 1) {
                long diff = c - query(1, 0, N - 1, b - 1, b - 1);
                update(1, 0, N - 1, b - 1, diff);
            } else {
                long sum = sum(1, 0, N - 1, b - 1, (int)c - 1);
                bw.write(sum + "\n");
            }
        }

        bw.flush();
        bw.close();
        br.close();
    }

    // 세그먼트 트리 업데이트
    // 해당 node는 start ~ end까지의 합을 의미
    static void update(int node, int start, int end, int index, long diff) {
        if (index < start || index > end) return;

        tree[node] += diff;

        if (start != end) {
            int mid = (start + end) / 2;
            update(node * 2, start, mid, index, diff);
            update(node * 2 + 1, mid + 1, end, index, diff);
        }
    }

    // 구간 합 쿼리
    // 해당 node는 start ~ end까지의 합을 의미
    static long sum(int node, int start, int end, int left, int right) {
        if (left > end || right < start) return 0;
        if (left <= start && end <= right) return tree[node];

        int mid = (start + end) / 2;
        return sum(node * 2, start, mid, left, right) + 
               sum(node * 2 + 1, mid + 1, end, left, right);
    }
}

#비슷한 유형      

 

최솟값 최대값 

https://www.acmicpc.net/problem/2357

최솟값

https://www.acmicpc.net/problem/10868

구간 곱 구하기

https://www.acmicpc.net/problem/11505