[ STUDY ]/CodingTest

[ 트리 ] 세그먼트 트리(백준 2042)

김강니 2024. 11. 14. 01:52

 세그먼트 트리

주어진 데이터의 구간 합과 데이터 업데이트를 빠르게 수행하기 위해 고안해낸 자료구조이다.

더 큰 범위로 '인덱스 트리'라고 불린다.

 

💡 세그먼트 트리의 핵심 이론

세그먼트 트리의 종류는 구간 합, 최대 · 최소 구하기로 나눌 수 있다.

 

1. 트리 초기화하기

리프 노드의 개수가 데이터의 개수(N)이상이 되도록 트리 배열을 만든다.

트리 배열 크기 구하기 : 2k ≥ N을 만족하는 k의 최솟값을 구하고 2k x 2를 트리 배열의 크기로 정의한다.

그리고 주어진 데이터들을 2k에 해당하는 인덱스부터 채운다.

 

만약, 8개의 데이터가 주어졌을 때는 {5, 8, 4, 3, 7, 2, 1, 6}

2의 제곱수 중 처음으로 8보다 크거나 같아지는 제곱수를 구한다 -> k=3

배열의 크기는 23 x 2 = 16

배열을 채우는 시작 인덱스는 23 = 8

아래 사진과 같이 배열을 채워준다.

  • 구간 합의 경우 자신의 부모인덱스에 자신의 값을 더해준다.
  • 최대의 경우 자신의 부모 인덱스 값이 자신보다 작으면 자신이 가지고 있는 값으로 대체한다.
  • 최소의 경우 자신의 부모 인덱스 값이 자신보다 크면 자신이 가지고 있는 값으로 대체한다.

 

 

2. 질의값 구하기

주어진 질의 인덱스를 세그먼트 트리의 리프 노드에 해당하는 인덱스로 변경한다.

세그먼트 트리 index = 주어진 질의 index + 2k -1

ex) 1-4까지 구간합을 구해주세요. -> 우리는 인덱스 8-11까지 구간 합을 구해야한다.

 

💡 질의값 구하는 과정

  1. start_idx % 2 == 1일 때 해당 노드를 독립노드로 선택한다. -> 시작 노드가 오른쪽 자식이라는 뜻
  2. end_idx % 2 == 0일 때 해당 노드를 독립노드로 선택한다. -> 끝 노드가 왼쪽 자식이라는 뜻
  3. start_idx depth 변경 : start_idx = (start_idx + 1) / 2 연산을 실행한다.
  4. end_idx depth 변경 : end_idx = (end_idx - 1) / 2 연산을 실행한다.
  5. 1-4를 반복하다가 end_idx < start_idx가 되면 종료한다.

 

3. 데이터 업데이트 하기

트리의 특성을 살려 부모노드로(/2) 가면서 올라가면서 데이터 업데이트

 

 

💡 구간 합 구하기 예제 (2~6 구간 합 구하기)

1. 리프 노드의 인덱스로 변경한다. 

start_idx = 2+8-1 = 9
end_idx = 6+8-1 = 13

 

 

2. 부모 노드로 이동한다.

start_idx % 2 = 9 % 2 = 1   -> 노드 선택(start는 1일때 선택)
end_idx % 2 = 13 % 2 = 1   -> 노드 미선택(end는 0일때 선택)
start_idx = (start_idx + 1) / 2 = 10 / 2 = 5
end_idx = (end_idx - 1) / 2 = 12 / 2 = 6

 

3. 한번 더 부모 노드로 이동한다.

start_idx % 2 = 5 % 2 = 1   -> 노드 선택(start는 1일때 선택)
end_idx % 2 = 6 % 2 = 0   -> 노드 선택(end는 0일때 선택)
start_idx = (start_idx + 1) / 2 = 6 / 2 = 3
end_idx = (end_idx - 1) / 2 = 5 / 2 = 2
-> start_idx와 end_idx가 교차됨. 이제 독립노드들을 더해준다.

 

 

 2042 : 구간 합 구하기

문제

 

어떤 N개의 수가 주어져 있다. 그런데 중간에 수의 변경이 빈번히 일어나고 그 중간에 어떤 부분의 합을 구하려 한다. 만약에 1,2,3,4,5 라는 수가 있고, 3번째 수를 6으로 바꾸고 2번째부터 5번째까지 합을 구하라고 한다면 17을 출력하면 되는 것이다. 그리고 그 상태에서 다섯 번째 수를 2로 바꾸고 3번째부터 5번째까지 합을 구하라고 한다면 12가 될 것이다.

 

입력

첫째 줄에 수의 개수 N(1 ≤ N ≤ 1,000,000)과 M(1 ≤ M ≤ 10,000), K(1 ≤ K ≤ 10,000) 가 주어진다. M은 수의 변경이 일어나는 횟수이고, K는 구간의 합을 구하는 횟수이다. 그리고 둘째 줄부터 N+1번째 줄까지 N개의 수가 주어진다. 그리고 N+2번째 줄부터 N+M+K+1번째 줄까지 세 개의 정수 a, b, c가 주어지는데, a가 1인 경우 b(1 ≤ b ≤ N)번째 수를 c로 바꾸고 a가 2인 경우에는 b(1 ≤ b ≤ N)번째 수부터 c(b ≤ c ≤ N)번째 수까지의 합을 구하여 출력하면 된다.

입력으로 주어지는 모든 수는 -263보다 크거나 같고, 263-1보다 작거나 같은 정수이다.

 

출력

첫째 줄부터 K줄에 걸쳐 구한 구간의 합을 출력한다. 단, 정답은 -263보다 크거나 같고, 263-1보다 작거나 같은 정수이다.


입력 출력
5 2 2
1
2
3
4
5
1 3 6
2 2 5
1 5 2
2 3 5
17
12

 

문제 풀이

  • N = 수의 개수, M = 수의 변경이 일어나는 횟수, K = 구간 합을 구하는 횟수
  • 입력 -> 수의 개수, 수의 개수만큼 데이터 받기, M+K만큼 판별(a=1: 업데이트, a=2: 구간 합)

  • 미친거...전부 다 long형으로 써야함!ㅋㅋ 테케 다 되는데 틀렷다고 난리난리......

 

실행 코드

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.Arrays;
import java.util.StringTokenizer;

public class P2042_indextree {
    static long[] tree;
    static long N, M, K;
    static long minK, square; //2^k
    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        StringTokenizer st = new StringTokenizer(br.readLine());

        N = Integer.parseInt(st.nextToken()); //데이터 갯수
        M = Integer.parseInt(st.nextToken()); //업데이트 횟수
        K = Integer.parseInt(st.nextToken()); //구간합 횟수
        minK = (int) Math.ceil(Math.log(N) / Math.log(2)); // 2^k ≥ N을 만족하는 최소값
        square = (int) Math.pow(2, minK); // 2^k
        tree = new long[(int) (square*2)]; //이진 트리 배열 , 배열의 크기 = 2^k * 2
        Arrays.fill(tree, 0);

        //데이터 받고 리프 노드에 채움
        for (int i = 0; i < N; i++) {
            tree[(int) (i+(square))] = Long.parseLong(br.readLine());
        }
        // 부모 노드들 채우기
        for(int i = tree.length-1; i > 1; i = (i-2)) {
            int parent = i/2;
            tree[parent] = tree[i]+tree[i-1];
        }

        //구간합 + 업데이트
        for(int i=0;i<M+K;i++){
            st = new StringTokenizer(br.readLine());
            int a = Integer.parseInt(st.nextToken()); //1: 업데이트, 2: 구간합
            int b = Integer.parseInt(st.nextToken());
            long c = Long.parseLong(st.nextToken());
            if(a==1)
                update(b, c);
            else if(a==2)
                sum(b, (int) c);
        }
    }

    private static void update(int b, long c) {
        long index = b + square - 1;
        long sub = c - tree[(int) index];
        tree[(int) index] = c;
        for(long idx = index/2;idx>0;idx /= 2)
            tree[(int) idx] += sub;
    }

    private static void sum(int b, int c) {
        long result = 0;
        b = (int) (b + square-1);
        c = (int) (c + square-1);
        while(b <= c){
            //시작노드가 오른쪽 노드이면
            if(b%2==1)
                result+= tree[b];
            //끝노드가 왼쪽 노드이면
            if(c%2==0)
                result+= tree[c];

            b = (b+1)/2;
            c = (c-1)/2;
        }
        System.out.println(result);
    }
}