Minimum Spanning Trees: Kruskal's Algorithm

Set Up

# python
from __future__ import annotations
from pprint import pprint

# pypi
from expects import be, be_true, contain_only, equal, expect

A Vertex

class Vertex:
    """A graph node that also carries its disjoint-set bookkeeping

    Args:
     identifier: label used to tell this node apart from the others
     parent: the node this one points to (None until one is assigned)
     rank: the height of the node
    """
    def __init__(self, identifier, parent: Vertex | None = None,
                 rank: int = 0) -> None:
        self.identifier, self.parent, self.rank = identifier, parent, rank

    @property
    def is_root(self) -> bool:
        """Whether this vertex sits at the top of its tree

        Returns:
         True only when the vertex is its own parent
        """
        return self is self.parent

    def __repr__(self) -> str:
        # a vertex prints as just its identifier
        return format(self.identifier)

Edges

class Edge:
    """A connection between two nodes that carries a weight

    Args:
     node_1: the vertex on one side of the edge
     node_2: the vertex on the opposite side
     weight: the weight attached to the edge
    """
    def __init__(self, node_1: Vertex, node_2: Vertex, weight: float) -> None:
        self.node_1, self.node_2 = node_1, node_2
        self.weight = weight

    def __repr__(self) -> str:
        # rendered as "<node_1> --<weight>-- <node_2>"
        return "{} --{}-- {}".format(self.node_1, self.weight, self.node_2)

A Graph

class Graph:
    """Undirected graph kept as a set of vertices and a set of edges
    """
    def __init__(self) -> None:
        self.vertices, self.edges = set(), set()

    def add_edge(self, node_1: Vertex, node_2: Vertex, weight: float) -> None:
        """Register both end-points and the weighted edge joining them

        Args:
         node_1: vertex at one end of the new edge
         node_2: vertex at the other end of the new edge
         weight: weight to give the edge

        Note:
         only the (node_1, node_2) edge is stored; the mirrored
         (node_2, node_1) edge is left out because the graph is
         undirected and the argument order is assumed not to matter
        """
        self.vertices.update((node_1, node_2))
        self.edges.add(Edge(node_1, node_2, weight))

Disjoint Sets

class Disjoint:
    """Disjoint-set (union-find) operations on vertices

    The forest lives on the vertices themselves, in their ``parent``
    and ``rank`` attributes; this class only groups the operations and
    keeps no state of its own.
    """
    @classmethod
    def make_sets(cls, vertices: set) -> None:
        """Initializes the vertices as trees in a forest

        Args:
         vertices: collection of nodes to set up
        """
        for vertex in vertices:
            cls.make_set(vertex)

    @classmethod
    def make_set(cls, vertex: Vertex) -> None:
        """Turn the vertex into a single-node tree

        The vertex becomes its own parent and its rank is reset to zero.

        Args:
         vertex: node to set up
        """
        vertex.parent = vertex
        vertex.rank = 0

    @classmethod
    def find_set(cls, vertex: Vertex) -> Vertex:
        """Find the root of the set that vertex belongs to

        Every vertex visited on the way up is re-pointed straight at the
        root (path compression), so the call changes ``parent`` values.

        Args:
         vertex: member of set to find

        Returns:
         root of set that vertex belongs to
        """
        if not vertex.is_root:
            # compress: hang this vertex directly under the root
            vertex.parent = cls.find_set(vertex.parent)
        return vertex.parent

    @classmethod
    def union(cls, vertex_1: Vertex, vertex_2: Vertex) -> None:
        """merge the trees that the vertices belong to

        Does nothing beyond path compression when both vertices are
        already in the same tree.

        Args:
         vertex_1: member of first tree
         vertex_2: member of second tree
        """
        cls.link(cls.find_set(vertex_1), cls.find_set(vertex_2))

    @classmethod
    def link(cls, root_1: Vertex, root_2: Vertex) -> None:
        """make lower-ranked tree root a child of higher-ranked

        When the ranks are equal root_1 goes under root_2 and the rank
        of root_2 grows by one.

        Args:
         root_1: root of a tree
         root_2: root of another tree (linking a root to itself is a no-op)
        """
        if root_1 is root_2:
            # same tree: nothing to merge, and the equal-rank branch below
            # would otherwise raise the rank without the tree growing
            return
        if root_1.rank > root_2.rank:
            root_2.parent = root_1
        else:
            root_1.parent = root_2
            if root_1.rank == root_2.rank:
                root_2.rank += 1

Kruskal's Algorithm

def kruskal(graph: Graph) -> set:
    """Create a Minimum Spanning Tree with Kruskal's algorithm

    Edges are visited in order of increasing weight and an edge is kept
    only when its two ends are still in different disjoint-set trees.

    Note:
     the ``parent`` and ``rank`` of every vertex in the graph are reset
     and then rewritten by the disjoint-set operations

    Args:
     graph: the graph from which we create the MST

    Returns:
     set of edges making up the minimum spanning tree
    """
    Disjoint.make_sets(graph.vertices)
    spanning_tree = set()
    for edge in sorted(graph.edges, key=lambda edge: edge.weight):
        root_1 = Disjoint.find_set(edge.node_1)
        root_2 = Disjoint.find_set(edge.node_2)
        if root_1 is root_2:
            # both ends are already connected: the edge would close a cycle
            continue
        spanning_tree.add(edge)
        # both roots are already known, so link them directly instead of
        # letting union() look each of them up a second time
        Disjoint.link(root_1, root_2)
    return spanning_tree

Try It Out

# one vertex per letter, looked up by its identifier
nodes = {letter: Vertex(letter) for letter in "abcdefghi"}

# the example graph: (one end, other end, weight) for each edge
graph = Graph()
for start, end, weight in (
        ("a", "b", 4),
        ("a", "h", 8),
        ("b", "h", 11),
        ("b", "c", 8),
        ("c", "d", 7),
        ("c", "i", 2),
        ("c", "f", 4),
        ("d", "e", 9),
        ("d", "f", 14),
        ("e", "f", 10),
        ("f", "g", 2),
        ("g", "h", 1),
        ("g", "i", 6),
        ("h", "i", 7),
):
    graph.add_edge(nodes[start], nodes[end], weight)

# put every vertex in its own single-node tree
Disjoint.make_sets(graph.vertices)

expect(graph.vertices).to(contain_only(*nodes.values()))

# a freshly made set is its own parent and has rank zero
for vertex in nodes.values():
    expect(vertex.parent).to(be(vertex))
    expect(vertex.rank).to(equal(0))

tree = kruskal(graph)
pprint(tree)
{a --4-- b,
 a --8-- h,
 c --7-- d,
 c --2-- i,
 c --4-- f,
 d --9-- e,
 f --2-- g,
 g --1-- h}

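Note: I originally thought I could check the resulting tree structure by following the vertices' parent attributes, but the disjoint-set methods use path compression, so the parent pointers describe the disjoint-set forest rather than the edges of the spanning tree: all the nodes end up having the same parent (in this case node "h").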