## Python, graph, minimum_spanning_tree_kruskal2.py

``````from __future__ import annotations

from typing import Generic, TypeVar

T = TypeVar("T")

class DisjointSetTreeNode(Generic[T]):
# Disjoint Set Node to store the parent and rank
def __init__(self, data: T) -> None:
self.data = data
self.parent = self
self.rank = 0

class DisjointSetTree(Generic[T]):
# Disjoint Set DataStructure
def __init__(self) -> None:
# map from node name to the node object
self.map: dict[T, DisjointSetTreeNode[T]] = {}

def make_set(self, data: T) -> None:
# create a new set with x as its member
self.map[data] = DisjointSetTreeNode(data)

def find_set(self, data: T) -> DisjointSetTreeNode[T]:
# find the set x belongs to (with path-compression)
elem_ref = self.map[data]
if elem_ref != elem_ref.parent:
elem_ref.parent = self.find_set(elem_ref.parent.data)
return elem_ref.parent

self, node1: DisjointSetTreeNode[T], node2: DisjointSetTreeNode[T]
) -> None:
# helper function for union operation
if node1.rank > node2.rank:
node2.parent = node1
else:
node1.parent = node2
if node1.rank == node2.rank:
node2.rank += 1

def union(self, data1: T, data2: T) -> None:
# merge 2 disjoint sets

class GraphUndirectedWeighted(Generic[T]):
def __init__(self) -> None:
# connections: map from the node to the neighbouring nodes (with weights)
self.connections: dict[T, dict[T, int]] = {}

def add_node(self, node: T) -> None:
# add a node ONLY if its not present in the graph
if node not in self.connections:
self.connections[node] = {}

def add_edge(self, node1: T, node2: T, weight: int) -> None:
# add an edge with the given weight
self.connections[node1][node2] = weight
self.connections[node2][node1] = weight

def kruskal(self) -> GraphUndirectedWeighted[T]:
# Kruskal's Algorithm to generate a Minimum Spanning Tree (MST) of a graph
"""
Details: https://en.wikipedia.org/wiki/Kruskal%27s_algorithm

Example:
>>> g1 = GraphUndirectedWeighted[int]()
>>> g1.add_edge(3, 5, 100) # Removed in MST
>>> assert 5 in g1.connections[3]
>>> mst = g1.kruskal()
>>> assert 5 not in mst.connections[3]

>>> g2 = GraphUndirectedWeighted[str]()
>>> g2.add_edge('C', 'E', 100) # Removed in MST
>>> assert 'E' in g2.connections["C"]
>>> mst = g2.kruskal()
>>> assert 'E' not in mst.connections['C']
"""

# getting the edges in ascending order of weights
edges = []
seen = set()
for start in self.connections:
for end in self.connections[start]:
if (start, end) not in seen:
edges.append((start, end, self.connections[start][end]))
edges.sort(key=lambda x: x[2])

# creating the disjoint set
disjoint_set = DisjointSetTree[T]()
for node in self.connections:
disjoint_set.make_set(node)

# MST generation
num_edges = 0
index = 0
graph = GraphUndirectedWeighted[T]()
while num_edges < len(self.connections) - 1:
u, v, w = edges[index]
index += 1
parent_u = disjoint_set.find_set(u)
parent_v = disjoint_set.find_set(v)
if parent_u != parent_v:
num_edges += 1