최소 신장 트리(MST)

Minimum Spanning Tree : 최소 신장 트리

  • 사용된 간선들의 가중치 합이 최소인 트리를 뜻합니다.
  • 그래프 내의 모든 정점을 포함하는 트리입니다.
  • 최소 신장 트리라고 하며, 그래프의 최소 연결 부분 그래프입니다.
    • 간선의 수가 가장 적음
    • n개의 정점을 가지는 그래프는 (n-1)개의 간선으로 연결 됨
  • 그래프에서 일부 간선을 선택해서 만드는 트리입니다.

최소 신장 트리의 특징

  • DFS, BFS를 사용하여 그래프에서 신장 트리를 탐색할 수 있습니다.
  • 하나의 그래프에는 많은 신장 트리가 존재하며, 그 중 가장 작은 가중치의 트리가 최소 신장 트리입니다.
  • 사이클이 포함되어서는 안됩니다.
  • 이전의 신장 트리와는 상관없이 무조건 최소 간선만을 선택합니다.

최소 신장 트리의 사용 사례

  • 도로 건설, 전기 회로, 통신, 배관 등 최소의 길이로 연결하기 위한 문제를 해결할 때 사용합니다.

MST 구현 방법

1. Prim MST 알고리즘

  • 시작 정점부터 출발하여 신장트리 집합을 단계적으로 확장 해나가는 방법
    • 정점 선택을 기반으로 하며, 이전 단계에서 만들어진 신장 트리를 확장합니다.
    • 그리디 알고리즘의 일종입니다.

  • 구현 방법
    1. 임의의 시작 정점을 하나 정합니다.
    2. 시작 정점만 포함된 신장 트리 집합을 만듭니다.
    3. n개의 정점이 모두 선택될 때까지 반복합니다.
      1. 신장 트리 집합에 포함된 정점에 인접했으며, 아직 방문하지 않은 정점 중 최소 비용 간선을 선택합니다.
      2. 선택한 정점을 신장 트리 집합에 표시합니다.
from math import inf
def prim(start):
    global n, adj_mat
    # 현재 방문한 정점들의 집합
    visited_set = set()
    visited_set.add(start)
    distance = 0

    # n-1 개의 간선을 선택할 때까지 반복
    for _ in range(n - 1):
        # min Dist : 현재 방문한 정점에서 갈 수 있는 간선의 최단 거리
        # next Node : 현재 방문한 정점에서 최단 거리로 갈 수 있는 정점
        min_dist, next_node = inf, -1

        # 방문한 모든 정점 확ㅇ니
        for node in visited_set:
            # 해당 정점과 연결되어 있고 아직 방문하지 않은 정점 중 소모 비용이 적은 정점 탐색
            for j in range(1, n + 1):
                if j not in visited_set and 0 < adj_mat[node][j] < min_dist:
                    min_dist = adj_mat[node][j]
                    next_node = j
        distance += min_dist
        visited_set.add(next_node)

    return distance

n,m = map(int, input().split())
adj_mat = [[0] * (n + 1) for _ in range(n + 1)]
for _ in range(m):
    x, y, value = map(int, input().split())
    adj_mat[x][y] = value
    adj_mat[y][x] = value
print(prim(1))

2. Kruskal MST 알고리즘

  • 비용에 따라 정렬 된 간선을 하나씩 선택하는 방법
  • 프림 알고리즘과 마참가지로, 그리드 알고리즘의 일종입니다.
    • 그래프의 간선들의 가중치를 기준으로 오름차순 정렬
    • n-1개의 간선이 선택될 때까지 반복
      • 가중치가 낮은 간선부터 사이클을 형성하지 않도록 선택
      • 사이클 형성 여부는 union-find 알고리즘을 활용

  • 구현 방법
import sys
input = sys.stdin.readline

v, e = map(int, input().split())
edges = []
for _ in range(e):
    a, b, c = map(int, input().split())
    edges.append((a, b, c))
edges.sort(key=lambda x: x[2])  # 간선이 작은 순
parent = [i for i in range(v + 1)] #부모를 기준으로 union

def find(x):
    if parent[x] != x: # 자기 자신이 루트 노드가 아니라면
        parent[x] = find(parent[x]) # 부모 노드를 찾아서 갱신
    return parent[x] # 자기 자신이 루트 노드라면 그대로 반환

# x가 속해있는 집합과 y가 속해있는 집합 합치기
def union(x, y):
    x = find(x)  # x의 루트
    y = find(y)  # y의 루트

    # 서로 루트가 다르면 주 집합 합치기
    if x < y:
        parent[y] = x
    else:
        parent[x] = y
answer = 0
for v,e,h in edges:
    if find(v) != find(e):
        union(v,e)
        answer+=h
print(answer)

참고