The Algorithms logo
The Algorithms
AboutDonate

Kruskal's Algorithm

using System;
using System.Collections.Generic;
using DataStructures.DisjointSet;

namespace Algorithms.Graph.MinimumSpanningTree
{
    /// <summary>
    ///     Algorithm to determine the minimum spanning forest of an undirected graph.
    /// </summary>
    /// <remarks>
    ///     Kruskal's algorithm is a greedy algorithm that can determine the
    ///     minimum spanning tree or minimum spanning forest of any undirected
    ///     graph. Unlike Prim's algorithm, Kruskal's algorithm will work on
    ///     graphs that are unconnected. This algorithm will always have a
    ///     running time of O(E log V) where E is the number of edges and V is
    ///     the number of vertices/nodes.
    ///     More information: https://en.wikipedia.org/wiki/Kruskal%27s_algorithm .
    ///     Pseudocode and analysis: https://www.personal.kent.edu/~rmuhamma/Algorithms/MyAlgorithms/GraphAlgor/kruskalAlgor.htm .
    /// </remarks>
    public static class Kruskal
    {
        /// <summary>
        ///     Determine the minimum spanning tree/forest of the given graph.
        /// </summary>
        /// <param name="adjacencyMatrix">
        ///     Square, symmetric adjacency matrix representing the graph. Entry [i, j] holds the
        ///     weight of the edge between nodes i and j; a non-finite entry (infinity or NaN)
        ///     means there is no edge. Only entries above the diagonal are read as edges.
        /// </param>
        /// <returns>
        ///     Adjacency matrix of the minimum spanning tree/forest. Every entry that is not
        ///     part of the result, including the diagonal, is <c>float.PositiveInfinity</c>.
        /// </returns>
        /// <exception cref="ArgumentNullException">
        ///     Thrown if <paramref name="adjacencyMatrix"/> is null.
        /// </exception>
        /// <exception cref="ArgumentException">
        ///     Thrown if the matrix fails validation (it must be square and symmetric).
        /// </exception>
        public static float[,] Solve(float[,] adjacencyMatrix)
        {
            // Fail with a descriptive argument error instead of a NullReferenceException
            // raised from inside the validation helper.
            if (adjacencyMatrix is null)
            {
                throw new ArgumentNullException(nameof(adjacencyMatrix));
            }

            ValidateGraph(adjacencyMatrix);

            var numNodes = adjacencyMatrix.GetLength(0);
            var set = new DisjointSet<int>();
            var nodes = new Node<int>[numNodes];
            var edgeWeightList = new List<float>();
            var nodeConnectList = new List<(int, int)>();

            // Register every graph node in the disjoint set.
            for (var i = 0; i < numNodes; i++)
            {
                nodes[i] = set.MakeSet(i);
            }

            // Collect each undirected edge once by scanning only the upper triangle.
            // The two lists are parallel: entry k of one describes entry k of the other.
            for (var i = 0; i < numNodes - 1; i++)
            {
                for (var j = i + 1; j < numNodes; j++)
                {
                    var weight = adjacencyMatrix[i, j];
                    if (float.IsFinite(weight))
                    {
                        edgeWeightList.Add(weight);
                        nodeConnectList.Add((i, j));
                    }
                }
            }

            var edges = Solve(set, nodes, edgeWeightList.ToArray(), nodeConnectList.ToArray());

            // Start from a matrix with no edges at all.
            var mst = new float[numNodes, numNodes];
            for (var i = 0; i < numNodes; i++)
            {
                for (var j = 0; j < numNodes; j++)
                {
                    mst[i, j] = float.PositiveInfinity;
                }
            }

            // Copy the selected edges. The same weight is written in both directions so the
            // result is exactly symmetric.
            foreach (var (node1, node2) in edges)
            {
                var weight = adjacencyMatrix[node1, node2];
                mst[node1, node2] = weight;
                mst[node2, node1] = weight;
            }

            return mst;
        }

        /// <summary>
        ///     Determine the minimum spanning tree/forest of the given graph.
        /// </summary>
        /// <param name="adjacencyList">
        ///     Adjacency list representing the graph: element i maps each neighbor of node i to
        ///     the weight of the connecting edge. Every edge must be listed from both endpoints.
        /// </param>
        /// <returns>Adjacency list of the minimum spanning tree/forest.</returns>
        /// <exception cref="ArgumentNullException">
        ///     Thrown if <paramref name="adjacencyList"/> is null.
        /// </exception>
        /// <exception cref="ArgumentException">
        ///     Thrown if the graph fails validation (it must be undirected).
        /// </exception>
        public static Dictionary<int, float>[] Solve(Dictionary<int, float>[] adjacencyList)
        {
            // Fail with a descriptive argument error instead of a NullReferenceException
            // raised from inside the validation helper.
            if (adjacencyList is null)
            {
                throw new ArgumentNullException(nameof(adjacencyList));
            }

            ValidateGraph(adjacencyList);

            var numNodes = adjacencyList.Length;
            var set = new DisjointSet<int>();
            var nodes = new Node<int>[numNodes];
            var edgeWeightList = new List<float>();
            var nodeConnectList = new List<(int, int)>();

            // Register every node in the disjoint set and build the parallel lists of
            // edge weights and edge endpoints.
            for (var i = 0; i < numNodes; i++)
            {
                nodes[i] = set.MakeSet(i);

                foreach (var (node, weight) in adjacencyList[i])
                {
                    // Validation guarantees that the mirrored entry exists, so keeping only the
                    // direction that points to the higher index records each undirected edge
                    // exactly once. Self-loops are dropped as well; they can never join two
                    // different components.
                    if (node > i)
                    {
                        edgeWeightList.Add(weight);
                        nodeConnectList.Add((i, node));
                    }
                }
            }

            var edges = Solve(set, nodes, edgeWeightList.ToArray(), nodeConnectList.ToArray());

            // Start from an adjacency list with no edges.
            var mst = new Dictionary<int, float>[numNodes];
            for (var i = 0; i < numNodes; i++)
            {
                mst[i] = new Dictionary<int, float>();
            }

            // Record every selected edge from both of its endpoints, using one weight for both.
            foreach (var (node1, node2) in edges)
            {
                var weight = adjacencyList[node1][node2];
                mst[node1].Add(node2, weight);
                mst[node2].Add(node1, weight);
            }

            return mst;
        }

        /// <summary>
        ///     Ensure that the given graph is undirected.
        /// </summary>
        /// <param name="adj">Adjacency matrix of graph to check.</param>
        /// <exception cref="ArgumentException">
        ///     Thrown if the matrix is not square, or if two mirrored entries disagree: one is
        ///     finite while the other is not, or they differ by more than 1e-6.
        /// </exception>
        private static void ValidateGraph(float[,] adj)
        {
            var size = adj.GetLength(0);
            if (size != adj.GetLength(1))
            {
                throw new ArgumentException("Matrix must be square!");
            }

            // Compare every entry above the diagonal with its mirror below the diagonal.
            for (var i = 0; i < size - 1; i++)
            {
                for (var j = i + 1; j < size; j++)
                {
                    var forward = adj[i, j];
                    var backward = adj[j, i];

                    // The finiteness check is required because any comparison with NaN is
                    // false: without it, NaN on one side and a real weight on the other would
                    // slip through the tolerance test. Two non-finite entries still pass when
                    // their difference is NaN (for example infinity paired with infinity).
                    if (float.IsFinite(forward) != float.IsFinite(backward)
                        || Math.Abs(forward - backward) > 1e-6)
                    {
                        throw new ArgumentException("Matrix must be symmetric!");
                    }
                }
            }
        }

        /// <summary>
        ///     Ensure that the given graph is undirected.
        /// </summary>
        /// <param name="adj">Adjacency list of graph to check.</param>
        /// <exception cref="ArgumentException">
        ///     Thrown if an entry of the list is null, if a neighbor index lies outside the
        ///     list, or if an edge is missing in the opposite direction or carries a different
        ///     weight there (one finite and one not, or differing by more than 1e-6).
        /// </exception>
        private static void ValidateGraph(Dictionary<int, float>[] adj)
        {
            // Reject null entries up front so the checks below can index freely.
            foreach (var neighbors in adj)
            {
                if (neighbors is null)
                {
                    throw new ArgumentException("Adjacency list must not contain null entries!");
                }
            }

            for (var i = 0; i < adj.Length; i++)
            {
                foreach (var edge in adj[i])
                {
                    var neighbor = edge.Key;

                    // A neighbor index outside the list would otherwise surface as an
                    // IndexOutOfRangeException rather than an argument error.
                    if (neighbor < 0 || neighbor >= adj.Length)
                    {
                        throw new ArgumentException("Adjacency list references a node that does not exist!");
                    }

                    // The edge must also be listed from the other endpoint with a matching
                    // weight. The finiteness check is required because any comparison with NaN
                    // is false, so NaN against a real weight would pass the tolerance test.
                    if (!adj[neighbor].TryGetValue(i, out var mirroredWeight)
                        || float.IsFinite(edge.Value) != float.IsFinite(mirroredWeight)
                        || Math.Abs(edge.Value - mirroredWeight) > 1e-6)
                    {
                        throw new ArgumentException("Graph must be undirected!");
                    }
                }
            }
        }

        /// <summary>
        ///     Determine the minimum spanning tree/forest.
        /// </summary>
        /// <param name="set">Disjoint set needed for set operations.</param>
        /// <param name="nodes">List of nodes in disjoint set associated with each node.</param>
        /// <param name="edgeWeights">Weights of each edge. Sorted in place by this method.</param>
        /// <param name="connections">
        ///     Nodes associated with each item in the <paramref name="edgeWeights"/> parameter.
        ///     Reordered in place together with <paramref name="edgeWeights"/>.
        /// </param>
        /// <returns>Array of edges in the minimum spanning tree/forest.</returns>
        private static (int, int)[] Solve(DisjointSet<int> set, Node<int>[] nodes, float[] edgeWeights, (int, int)[] connections)
        {
            var edges = new List<(int, int)>();

            // A spanning forest over V nodes holds at most V - 1 edges. Once that many have
            // been accepted every remaining candidate would be rejected, so the scan can stop.
            var maxEdges = nodes.Length - 1;

            // Keyed sort: weights are the keys, connections are carried along, so the
            // connections end up ordered by ascending edge weight.
            Array.Sort(edgeWeights, connections);

            foreach (var (node1, node2) in connections)
            {
                if (edges.Count >= maxEdges)
                {
                    break;
                }

                // Accept the edge only when its endpoints currently belong to different sets;
                // otherwise it would close a cycle.
                if (set.FindSet(nodes[node1]) != set.FindSet(nodes[node2]))
                {
                    set.UnionSet(nodes[node1], nodes[node2]);
                    edges.Add((node1, node2));
                }
            }

            return edges.ToArray();
        }
    }
}