Prim’s (MST): Special Subtree, Explained

Prim’s algorithm (Minimum Spanning Tree) explained in Java

Artemis
Geek Culture


Prim’s algorithm is a classic. You can learn about its history and how it works from its Wikipedia page. However, it is not straightforward to understand and implement, so here we implement it in Java to make it easier to follow.

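Prim’s algorithm can start from any vertex. From a chosen start vertex, it grows a tree one vertex at a time, always adding the lightest edge that connects the tree to a vertex not yet in it, until every vertex is included; the sum of those edge weights is the minimum total weight. There are two common ways to find that lightest edge: 1. expand outward from the tree, pushing the edges of each newly added vertex into a min-heap keyed by weight, or 2. keep all edges in one min-heap and repeatedly take the lightest edge that has exactly one endpoint in the tree. Either way, a min-heap is a natural fit for picking edges by weight. Here we use a hard HackerRank problem to illustrate it.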
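The first approach can be sketched as a "lazy" Prim’s: keep a min-heap of edges leaving the tree, pop the lightest one, and skip it if its endpoint has already joined the tree. This is a minimal sketch, not the solution used later in this post; the class name `LazyPrim`, the `int[][]` edge format `{u, v, weight}`, and 1-indexed vertices are assumptions chosen to mirror the HackerRank input.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.PriorityQueue;

// Approach 1 ("expand from the tree"): a lazy Prim's sketch.
// Hypothetical helper; assumes vertices 1..n and edges given as {u, v, weight}.
final class LazyPrim {

    static int mstWeight(int n, int[][] edges, int start) {
        // Undirected adjacency list: adj.get(u) holds {neighbor, weight} pairs.
        List<List<int[]>> adj = new ArrayList<>();
        for (int i = 0; i <= n; i++) {
            adj.add(new ArrayList<>());
        }
        for (int[] e : edges) {
            adj.get(e[0]).add(new int[] {e[1], e[2]});
            adj.get(e[1]).add(new int[] {e[0], e[2]});
        }

        boolean[] inTree = new boolean[n + 1];
        // Min-heap of candidate edges leaving the tree, stored as {vertex, weight}.
        PriorityQueue<int[]> heap = new PriorityQueue<>((a, b) -> Integer.compare(a[1], b[1]));
        heap.add(new int[] {start, 0}); // zero-cost entry for the start vertex

        int total = 0;
        int added = 0;
        while (!heap.isEmpty() && added < n) {
            int[] top = heap.poll();
            int v = top[0];
            if (inTree[v]) {
                continue; // stale entry: v already joined the tree via a lighter edge
            }
            inTree[v] = true;
            total += top[1];
            added++;
            // Push every edge from the new vertex to a vertex still outside the tree.
            for (int[] next : adj.get(v)) {
                if (!inTree[next[0]]) {
                    heap.add(next);
                }
            }
        }
        return total;
    }
}
```

For example, `LazyPrim.mstWeight(4, new int[][] {{1, 2, 1}, {2, 3, 2}, {3, 4, 3}, {1, 4, 10}, {1, 3, 4}}, 1)` returns 6 (edges 1-2, 2-3 and 3-4). Since each edge is pushed at most once per endpoint, the heap sees at most about 2m entries, giving O(m log m) time.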

Problem

Given a graph that consists of several edges connecting its nodes, find a subgraph of the given graph with the following properties:

  • The subgraph contains all the nodes present in the original graph.
  • The subgraph is of minimum overall weight (sum of all edges) among all such subgraphs.
  • It is also required that there is exactly one exclusive path between any two nodes of the subgraph.

One specific node S is fixed as the starting point of finding the subgraph using Prim’s Algorithm.
Find the total weight or the sum of all edges in the subgraph.
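Taken together, the three properties say the subgraph is a spanning tree of minimum total weight, and that minimum weight is the same whichever node S Prim’s algorithm starts from. For tiny inputs, the definition can be checked directly by brute force, which is handy for cross-checking any Prim’s implementation. This is a hypothetical test helper (`BruteForceMst` is not part of the HackerRank template); it assumes vertices 1..n, edges as `{u, v, weight}`, and fewer than 31 edges so that a bitmask can enumerate the subsets.

```java
// Brute-force check of the problem definition for tiny graphs (hypothetical helper).
// Tries every subset of exactly n - 1 edges; a subset with no cycle connects all
// n vertices, i.e. it is a tree with exactly one path between any two nodes.
// Exponential in the number of edges, so only suitable for testing.
final class BruteForceMst {

    // Union-find root lookup with path halving.
    static int find(int[] parent, int x) {
        while (parent[x] != x) {
            parent[x] = parent[parent[x]];
            x = parent[x];
        }
        return x;
    }

    static int minSpanningWeight(int n, int[][] edges) {
        int m = edges.length; // assumed < 31
        int best = Integer.MAX_VALUE;
        for (int mask = 0; mask < (1 << m); mask++) {
            if (Integer.bitCount(mask) != n - 1) {
                continue; // a spanning tree on n vertices has exactly n - 1 edges
            }
            int[] parent = new int[n + 1];
            for (int i = 0; i <= n; i++) {
                parent[i] = i;
            }
            int weight = 0;
            boolean acyclic = true;
            for (int i = 0; i < m && acyclic; i++) {
                if ((mask & (1 << i)) == 0) {
                    continue;
                }
                int ru = find(parent, edges[i][0]);
                int rv = find(parent, edges[i][1]);
                if (ru == rv) {
                    acyclic = false; // this edge would close a cycle
                } else {
                    parent[ru] = rv;
                    weight += edges[i][2];
                }
            }
            if (acyclic) {
                best = Math.min(best, weight);
            }
        }
        return best; // Integer.MAX_VALUE means no spanning tree (disconnected graph)
    }
}
```

For example, `BruteForceMst.minSpanningWeight(4, new int[][] {{1, 2, 1}, {2, 3, 2}, {3, 4, 3}, {1, 4, 10}, {1, 3, 4}})` returns 6.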

Please see the original question below:

Solution

As mentioned above, a min-heap can drive either approach. Here we illustrate the second one.

import java.io.*;
import java.math.*;
import java.security.*;
import java.text.*;
import java.util.*;
import java.util.concurrent.*;
import java.util.function.*;
import java.util.regex.*;
import java.util.stream.*;
import static java.util.stream.Collectors.joining;
import static java.util.stream.Collectors.toList;

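class Result {

    /*
     * Complete the 'prims' function below.
     *
     * The function is expected to return an INTEGER.
     * The function accepts following parameters:
     *  1. INTEGER n
     *  2. 2D_INTEGER_ARRAY edges
     *  3. INTEGER start
     */

    public static int prims(int n, List<List<Integer>> edges, int start) {

        // Min-heap of all edges [u, v, weight], ordered by weight.
        PriorityQueue<List<Integer>> heap = new PriorityQueue<>(
            (a, b) -> Integer.compare(a.get(2), b.get(2)));
        heap.addAll(edges);

        // Vertices already in the tree; the tree starts from 'start'.
        Set<Integer> visited = new HashSet<>();
        visited.add(start);
        int minSum = 0;

        // Each pass of the outer loop adds exactly one vertex to the tree
        // (this assumes the graph is connected).
        while (visited.size() < n) {
            // Edges with neither endpoint in the tree yet; restored after this pass.
            List<List<Integer>> temp = new LinkedList<>();
            while (!heap.isEmpty()) {
                List<Integer> min = heap.poll();
                if (visited.contains(min.get(0)) && visited.contains(min.get(1))) {
                    // Both endpoints are in the tree: the edge would form a cycle, drop it.
                    continue;
                } else if (visited.contains(min.get(0)) && !visited.contains(min.get(1))) {
                    // Lightest edge leaving the tree: add its outside endpoint.
                    visited.add(min.get(1));
                    minSum += min.get(2);
                    break;
                } else if (visited.contains(min.get(1)) && !visited.contains(min.get(0))) {
                    visited.add(min.get(0));
                    minSum += min.get(2);
                    break;
                } else {
                    // Neither endpoint is in the tree yet: keep it for a later pass.
                    temp.add(min);
                }
            }
            heap.addAll(temp);
        }

        return minSum;
    }
}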

public class Solution {
    public static void main(String[] args) throws IOException {
        BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(System.in));
        BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(System.getenv("OUTPUT_PATH")));

        String[] firstMultipleInput = bufferedReader.readLine().replaceAll("\\s+$", "").split(" ");

        int n = Integer.parseInt(firstMultipleInput[0]);

        int m = Integer.parseInt(firstMultipleInput[1]);

        List<List<Integer>> edges = new ArrayList<>();

        IntStream.range(0, m).forEach(i -> {
            try {
                edges.add(
                    Stream.of(bufferedReader.readLine().replaceAll("\\s+$", "").split(" "))
                        .map(Integer::parseInt)
                        .collect(toList())
                );
            } catch (IOException ex) {
                throw new RuntimeException(ex);
            }
        });

        int start = Integer.parseInt(bufferedReader.readLine().trim());

        int result = Result.prims(n, edges, start);

        bufferedWriter.write(String.valueOf(result));
        bufferedWriter.newLine();

        bufferedReader.close();
        bufferedWriter.close();
    }
}
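One caveat: the outer while loop in `prims` only terminates if every pass finds an edge with exactly one visited endpoint, so on a disconnected graph it would loop forever. A minimal sketch of the same edge-scanning approach that stops instead is shown here; the class name `GuardedEdgeScanPrim`, the `int[][]` edge format `{u, v, weight}`, and the -1 sentinel for "no spanning tree" are assumptions for illustration, not part of the HackerRank task.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.PriorityQueue;

// Approach 2 (scan one global edge heap), like Result.prims, but returns -1
// when a pass finds no edge crossing the cut, i.e. the graph is disconnected.
// Hypothetical helper; assumes vertices 1..n and edges given as {u, v, weight}.
final class GuardedEdgeScanPrim {

    static int mstWeight(int n, int[][] edges, int start) {
        PriorityQueue<int[]> heap = new PriorityQueue<>((a, b) -> Integer.compare(a[2], b[2]));
        for (int[] e : edges) {
            heap.add(e);
        }

        boolean[] visited = new boolean[n + 1];
        visited[start] = true;
        int count = 1;
        int total = 0;

        while (count < n) {
            List<int[]> deferred = new ArrayList<>(); // neither endpoint in the tree yet
            boolean grew = false;
            while (!heap.isEmpty()) {
                int[] e = heap.poll();
                boolean inU = visited[e[0]];
                boolean inV = visited[e[1]];
                if (inU && inV) {
                    continue; // would form a cycle; discard for good
                }
                if (inU || inV) {
                    // Lightest edge crossing the cut: add its outside endpoint.
                    visited[inU ? e[1] : e[0]] = true;
                    total += e[2];
                    count++;
                    grew = true;
                    break;
                }
                deferred.add(e);
            }
            heap.addAll(deferred);
            if (!grew) {
                return -1; // no crossing edge left: some vertex is unreachable
            }
        }
        return total;
    }
}
```

For example, `GuardedEdgeScanPrim.mstWeight(3, new int[][] {{1, 2, 1}}, 1)` returns -1 because vertex 3 is unreachable, while a connected graph gives the same total as `prims`.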

It uses Java’s PriorityQueue as a min-heap, and it passed 100% of the HackerRank test cases, as follows:

Happy Coding!

Questions, ideas? Leave comments here. Follow me to be part of the fun problem-solving journey.
