본문 바로가기
알고리즘/개념

최소 신장 트리 (MST)

by 도쿠니 2022. 5. 11.

최소 신장 트리 (MST, Minimum Spanning Tree) 란?

  • 그래프 상의 모든 노드들을 최소 비용으로 연결하는 방법

 

종류

  • 크루스칼 (Kruskal)
    • 간선 중 최소 값을 가진 간선부터 순차적으로 연결
      • 사이클 발생 시 다른 간선을 선택
        • 사이클 발생 여부는 Union-Find를 통해서 체크
    • 주로 간선 수가 적을 때 사용
    • 시간복잡도 : O(ElogE)

출처 : zerobase

import java.util.Arrays;
import java.util.Comparator;

public class Kruskal {

    // Union-Find를 위한 배열
    static int[] parents;


    public static int kruskal(int[][] data, int v, int e) {
        int weightSum = 0;

        // 간선 중 최소값부터 순차적으로 연결해야하기 때문에 정렬이 필요
        Arrays.sort(data, Comparator.comparingInt(o -> o[2]));
        
        parents = new int[v + 1];

        // 부모 배열을 자기 자신의 노드 번호로 초기화
        for (int i = 1; i < v + 1; i++) {
            parents[i] = i;
        }

        // 간선 수 만큼 반복하여 실행
        for (int i = 0; i < e; i++) {
            // 서로 부모가 다르면 (연결되지 않은 노드라면)
            if (find(data[i][0]) != find(data[i][1])) {
                union(data[i][0], data[i][1]);
                weightSum += data[i][2];
            }
        }

        return weightSum;
    }

    // 두 개의 노드를 같은 집합으로 묶어주는 메소드
    public static void union(int a, int b) {
        int aParent = find(a);
        int bParent = find(b);

        if (aParent != bParent) {
            parents[aParent] = bParent;
        }
    }

    // 노드가 어느 노드에 연결되어있는지 찾아주는 메소드
    public static int find(int a) {
        if (a == parents[a]) {
            return a;
        }
        return parents[a] = find(parents[a]);
    }

    public static void main(String[] args) {
        //Test
        int v = 7;
        int e = 10;

        int[][] graph = {{1,3,1},{1,2,9},{1,6,8},{2,4,13},{2,5,2},{2,6,7},{3,4,12},{4,7,17},{5,6,5},{5,7,20}};

        System.out.println(kruskal(graph, v, e));
    }

}

 

  • 프림 (Prim)
    • 임의의 노드에서 시작하여 연결된 노드들의 간선 중 가장 낮은 가중치를 갖는 간선을 선택
      • 이미 방문했던 노드는 방문하지 않음 -> 방문을 체크할 배열 필요
      • 우선순위 큐 이용하여 가장 낮은 가중치의 간선을 뽑아온다.
    • 간선의 개수가 많을 때 크루스칼 보다 유리
    • 시간복잡도 : O(ElogV)

import java.util.ArrayList;
import java.util.Comparator;
import java.util.PriorityQueue;

public class Prim {

    static class Node {
        int to;
        int weight;

        public Node(int to, int weight) {
            this.to = to;
            this.weight = weight;
        }
    }

    public static int prim(int[][] data, int v, int e) {
        int weightSum = 0;

        ArrayList<ArrayList<Node>> graph = new ArrayList<>();

        for (int i = 0; i < v + 1; i++) {
            graph.add(new ArrayList<>());
        }

        for (int i = 0; i < e; i++) {
            graph.get(data[i][0]).add(new Node(data[i][1], data[i][2]));
            graph.get(data[i][1]).add(new Node(data[i][0], data[i][2]));
        }

        boolean[] visited = new boolean[v + 1];

        PriorityQueue<Node> pq = new PriorityQueue<>(Comparator.comparingInt(x -> x.weight));

        pq.add(new Node(1, 0));

        int cnt = 0;
        while (!pq.isEmpty()) {
            Node cur = pq.poll();

            cnt += 1;

            if (visited[cur.to]) {
                continue;
            }

            visited[cur.to] = true;

            weightSum += cur.weight;

            // 연결된 간선 수가 노드 수 - 1 만큼이면 다 연결되었다는 소리이니 종료
            if (cnt == v) {
                return weightSum;
            }

            for (int i = 0; i < graph.get(cur.to).size(); i++) {
                Node adjNode = graph.get(cur.to).get(i);
                if (visited[adjNode.to]) {
                    continue;
                }

                pq.offer(adjNode);
            }
        }


        return weightSum;
    }

    public static void main(String[] args) {
        //Test
        int v = 7;
        int e = 10;

        int[][] graph = {{1, 3, 1}, {1, 2, 9}, {1, 6, 8}, {2, 4, 13}, {2, 5, 2}, {2, 6, 7}, {3, 4, 12}, {4, 7, 17}, {5, 6, 5}, {5, 7, 20}};

        System.out.println(prim(graph, v, e));
    }
}

댓글