Mark Needham

Thoughts on Software Development

Archive for the ‘algo-class’ tag

Prim’s algorithm using a heap/priority queue in Ruby

with 2 comments

I recently wrote a blog post describing my implementation of Prim’s Algorithm for the Algorithms 2 class. While it comes up with the right answer for the supplied data set, it takes almost 30 seconds to do so!

In one of the lectures Tim Roughgarden points out that we’re doing the same calculations multiple times to work out the next smallest edge to include in our minimal spanning tree and could use a heap to speed things up.

A heap works well in this situation because it is designed to speed up repeated minimum computations, i.e. working out the minimum weighted edge to add to our spanning tree.
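To make "repeated minimum computations" concrete, here is a minimal sketch of a binary min-heap. The `MinHeap` class and its `push`/`pop` methods are hypothetical names for illustration, not the PriorityQueue gem's API, and the sketch only handles values that can be compared with `<` and `<=`, such as numeric edge weights.

```ruby
# A minimal binary min-heap (hypothetical MinHeap class, for illustration only).
class MinHeap
  def initialize
    @items = []
  end

  def empty?
    @items.empty?
  end

  # Add a value and bubble it up until its parent is no larger.
  def push(value)
    @items << value
    child = @items.size - 1
    while child > 0
      parent = (child - 1) / 2
      break if @items[parent] <= @items[child]
      @items[parent], @items[child] = @items[child], @items[parent]
      child = parent
    end
    self
  end

  # Remove and return the smallest value, then sift the new root down.
  def pop
    return nil if @items.empty?
    smallest = @items.first
    last = @items.pop
    unless @items.empty?
      @items[0] = last
      parent = 0
      loop do
        left = 2 * parent + 1
        right = left + 1
        min = parent
        min = left if left < @items.size && @items[left] < @items[min]
        min = right if right < @items.size && @items[right] < @items[min]
        break if min == parent
        @items[parent], @items[min] = @items[min], @items[parent]
        parent = min
      end
    end
    smallest
  end
end

heap = MinHeap.new
[7, 3, 9, 1, 4].each { |weight| heap.push(weight) }

# Each pop costs O(log n) rather than a full scan of the remaining values.
extracted = []
extracted << heap.pop until heap.empty?
```

Every `push` and `pop` touches one path between the root and a leaf, which is why each minimum costs O(log n) instead of a rescan of all the candidates.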

The pseudocode for the version of Prim’s algorithm which uses a heap reads like this:

  • Let X = nodes covered so far, V = all the nodes in the graph, E = all the edges in the graph
  • Pick an arbitrary initial node s and put that into X
  • for v ∈ V − X
    • key[v] = cheapest edge (u,v) with u ∈ X (or +∞ if there is no such edge), and insert v into the heap with that key
  • while X ≠ V:
    • let v = extract-min(heap) (i.e. v is the node which has the minimal edge cost into X)
    • Add v to X
    • for each edge (v, w) ∈ E
      • if w ∈ V − X (i.e. w is a node which hasn’t yet been covered)
        • Delete w from heap
        • recompute key[w] = min(key[w], weight(v, w)) (key[w] would only change if the weight of the edge (v,w) is less than the current weight for that key).
        • reinsert w into the heap

We store the uncovered nodes in the heap and set their priority to be the cheapest edge from that node into the set of nodes which we’ve already covered.

I came across the PriorityQueue gem, which actually seems to be a better fit than a plain heap because we can use the node as the key and set the priority of that key to be the edge weight. When you extract the minimum value from the priority queue it uses this priority to decide which key to return.

The outline of my solution to this problem looks like this:

require 'priority_queue'

MAX_VALUE = (2**(0.size * 8 - 2) - 1)

# create_adjacency_matrix, get_edges and nodes_left_to_cover are helpers from the full solution
adjacency_matrix = create_adjacency_matrix
@nodes_spanned_so_far, spanning_tree_cost = [1], 0

# Seed the queue: each uncovered node's priority is its cheapest edge into the covered set
heap = PriorityQueue.new
nodes_left_to_cover.each do |node|
  cheapest_nodes = get_edges(adjacency_matrix, node-1).
                   select { |_, other_node_index| @nodes_spanned_so_far.include?(other_node_index + 1) } || []

  cheapest = cheapest_nodes.inject([]) do |all_edges, (weight, index)|
    all_edges << { :start => node, :end => index + 1, :weight => weight }
  end.sort { |x,y| x[:weight] <=> y[:weight] }.first

  weight = !cheapest.nil? ? cheapest[:weight] : MAX_VALUE
  heap[node] = weight
end

while !nodes_left_to_cover.empty?
  cheapest_node, weight = heap.delete_min
  spanning_tree_cost += weight
  @nodes_spanned_so_far << cheapest_node

  # Only nodes still outside the tree can have their priority lowered
  edges_with_potential_change = get_edges(adjacency_matrix, cheapest_node-1).
                                reject { |_, node_index| @nodes_spanned_so_far.include?(node_index + 1) }
  edges_with_potential_change.each do |_, node_index|
    heap.change_priority(node_index+1,
                         [heap.priority(node_index+1), adjacency_matrix[cheapest_node-1][node_index]].min)
  end
end

puts "total spanning tree cost #{spanning_tree_cost}"

I couldn’t see a way to keep track of the edges that comprise the minimal spanning tree so in this version I’ve created a variable which keeps track of the total edge weight as we go rather than computing it at the end.

We start off by initialising the priority queue to contain entries for each of the nodes in the graph.

We do this by finding the edges that go from each node to the nodes that we’ve already covered. In this case the only node we’ve covered is node 1 so the priorities for most nodes will be MAX_VALUE and for nodes which have an edge to node 1 it’ll be the weight of that edge.

While we still have nodes left to cover we take the next node with the cheapest weight from the priority queue and add it to the collection of nodes that we’ve covered. We then iterate through the nodes which have an edge to the node we just removed and update the priority queue if necessary.

The time taken for this version of the algorithm to run against the data set was 0.3 seconds as compared to the 29 seconds of the naive implementation.

As usual the code is on github – I need to figure out how to keep track of the edges so if anyone has any suggestions that’d be cool.
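One possible way to keep track of the edges, offered as a sketch rather than a fix for the code above: store, next to each node's cheapest weight, the covered node at the other end of that edge, and record the pair when the node is extracted. Everything here is hypothetical (`minimum_spanning_tree`, `key`, `tree_edges`, the sample graph), the graph is assumed to be connected and undirected, and a Hash scanned with `min_by` stands in for the priority queue, so this version is O(V²) rather than heap-based.

```ruby
# Hypothetical sketch: Prim's algorithm that remembers which edge brought each
# node into the tree. A Hash scanned with min_by stands in for delete_min.
def minimum_spanning_tree(edges, start)
  # Build an undirected adjacency list: node => { neighbour => weight }.
  graph = Hash.new { |hash, node| hash[node] = {} }
  edges.each do |u, v, weight|
    graph[u][v] = weight
    graph[v][u] = weight
  end

  # key[node] = [cheapest weight into the tree, tree node at the other end].
  key = {}
  graph.keys.each { |node| key[node] = [Float::INFINITY, nil] unless node == start }
  graph[start].each { |node, weight| key[node] = [weight, start] }

  tree_edges = []
  until key.empty?
    node, (weight, parent) = key.min_by { |_node, (w, _parent)| w }
    key.delete(node)
    tree_edges << [parent, node, weight]

    graph[node].each do |neighbour, edge_weight|
      # Only nodes still outside the tree have an entry in key.
      next unless key.key?(neighbour)
      key[neighbour] = [edge_weight, node] if edge_weight < key[neighbour].first
    end
  end
  tree_edges
end

edges = [[1, 2, 1], [2, 3, 2], [1, 3, 4], [3, 4, 3], [2, 4, 5]]
tree = minimum_spanning_tree(edges, 1)
total = tree.sum { |edge| edge.last }
puts "edges: #{tree.inspect}, total cost: #{total}"
```

With a keyed priority queue the same idea would mean keeping a separate hash from node to "the covered node whose edge set my current priority" and updating it whenever a priority is lowered.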

Written by Mark Needham

December 15th, 2012 at 4:31 pm

Posted in Algorithms


Algo Class: Start simple and build up

without comments

Over the last six weeks I’ve been working through Stanford’s Design and Analysis of Algorithms I class and each week there’s been a programming assignment on a specific algorithm for which a huge data set is provided.

For the first couple of assignments I tried writing the code for the algorithm and then running it directly against the provided data set.

As you might imagine it never worked first time and this approach led to me becoming very frustrated because there was no way of telling what went wrong.

By the third week I’d adapted and instead tested my code against a much smaller data set to check that the design of the algorithm was roughly correct.

I thought that I would be able to jump straight from the small data set to the huge one but realised that this sometimes didn’t work for the following reasons:

  1. An inefficient algorithm will work fine on a small data set but grind to a halt on a larger data set. e.g. my implementation of the strongly connected components (SCC) graph algorithm did a scan of a 5 million element list millions of times.
  2. An incorrect algorithm may still work on a small data set. e.g. my implementation of SCC didn’t consider that some vertices wouldn’t have forward edges and excluded them.

My colleague Seema showed me a better approach where we still use a small data set but think through all the intricacies of the algorithm and make sure our data set covers all of them.

For the SCC algorithm this meant creating a graph with 15 vertices where some vertices weren’t strongly connected while others were connected to many others.

In my initial data set all of the vertices were strongly connected which meant I had missed some edge cases in the design of the algorithm.

Taking this approach was very effective for ensuring the correctness of the algorithm but it could still be inefficient.

I used VisualVM to identify where the performance problems were.

In one case I ended up running out of memory because I had 3 different representations of the graph and was inadvertently using all of them.

I made the stupid mistake of not writing any automated tests for the smaller data sets. They would have been very helpful for ensuring I didn’t break the algorithm when performance tuning it.

I should really have learnt that lesson by now given that I’ve been writing code in a test driven style for nearly six years but apparently I’d turned off that part of my brain.
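As an illustration of what such a test could look like, here is a sketch that pins down the expected components for a tiny hand-built graph. It assumes the two-pass (Kosaraju) approach to SCC; the function names and the seven-vertex graph are made up for this example and are not the code from the assignment. The graph deliberately includes a vertex with no outgoing edges (7) and a vertex with no edges at all (6).

```ruby
# Depth-first search that appends each vertex to `finished` once all of its
# reachable, unvisited neighbours have been explored.
def depth_first(graph, vertex, visited, finished)
  visited[vertex] = true
  graph[vertex].each do |neighbour|
    depth_first(graph, neighbour, visited, finished) unless visited[neighbour]
  end
  finished << vertex
end

# Hypothetical sketch of the two-pass SCC algorithm (recursive, so only
# suitable for small graphs like the one below).
def strongly_connected_components(vertices, edges)
  forward = vertices.each_with_object({}) { |v, graph| graph[v] = [] }
  reverse = vertices.each_with_object({}) { |v, graph| graph[v] = [] }
  edges.each do |from, to|
    forward[from] << to
    reverse[to] << from
  end

  # Pass 1: record finishing order on the reversed graph.
  visited = {}
  finish_order = []
  vertices.each { |v| depth_first(reverse, v, visited, finish_order) unless visited[v] }

  # Pass 2: explore the forward graph in decreasing finishing order;
  # each exploration collects exactly one component.
  visited = {}
  finish_order.reverse.each_with_object([]) do |v, components|
    next if visited[v]
    component = []
    depth_first(forward, v, visited, component)
    components << component.sort
  end
end

vertices = [1, 2, 3, 4, 5, 6, 7]
edges = [[1, 2], [2, 3], [3, 1], [3, 4], [4, 5], [5, 4], [5, 7]]
components = strongly_connected_components(vertices, edges).sort

expected = [[1, 2, 3], [4, 5], [6], [7]]
raise "expected #{expected.inspect}, got #{components.inspect}" unless components == expected
```

A check like this runs in milliseconds, so it can be re-run after every performance tweak before going back to the huge data set.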

Looking forward to Design and Analysis of Algorithms II!

Written by Mark Needham

April 24th, 2012 at 7:17 am

Scala: Counting number of inversions (via merge sort) for an unsorted collection

with 2 comments

The first programming question of algo-class requires you to calculate the number of inversions in an unsorted collection, i.e. the number of pairs of elements which are out of order with respect to ascending order, using an adapted merge sort.

I found quite a nice explanation here too:

Finding “similarity” between two rankings. Given a sequence of n numbers 1..n (assume all numbers are distinct). Define a measure that tells us how far this list is from being in ascending order. The value should be 0 if a_1 < a_2 < ... < a_n and should be higher as the list is more "out of order".

e.g.

2 4 1 3 5
1 2 3 4 5

The sequence 2, 4, 1, 3, 5 has three inversions (2,1), (4,1), (4,3).

The simple/naive way of solving this problem is to iterate through the collection in two loops and compare each value and its current index with the others, looking for ones which are not in the right order.

I wrote the following Ruby code which seems to do the job:

def number_of_inversions(arr)
  count = 0
  arr.each_with_index do |item_a, index_a|
    arr.each_with_index do |item_b, index_b|
      # An inversion is a later element which is smaller than an earlier one
      if index_b >= index_a && item_a > item_b
        count += 1
      end
    end
  end
  count
end

>> number_of_inversions [6,5,4,3,2,1]
=> 15

That works fine for small collections but since we’re effectively doing nested for loops it actually takes O(n²) time which means that it pretty much kills my machine on the sample data set of 100,000 numbers.

An O(n log(n)) solution is explained in the lectures. It involves recursively splitting the collection in half (like in merge sort) and then, whenever we select an item from the right collection during the merge, counting how many elements remain in the left collection, since this tells us how many inversions that element is part of.

e.g. if we’re trying to work out how many inversions in the collection [1,3,5,2,4,6] we eventually end up merging [1,3,5] and [2,4,6]

  • Our first iteration puts ‘1’ from the left collection into the new array.
  • The second iteration puts ‘2’ from the right collection into the new array and there are two elements left in the left collection (‘3’ and ‘5’) so we have 2 inversions (3,2 and 5,2).
  • Our third iteration puts ‘3’ from the left collection into the new array.
  • Our fourth iteration puts ‘4’ from the right collection into the new array and there is one element left in the left collection (‘5’) so we have 1 inversion (5,4)
  • Our fifth iteration puts ‘5’ from the left collection and the sixth puts ‘6’ from the right collection in.

The number of inversions for that example is therefore 3.
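The merge step walked through above can be sketched in Ruby roughly as follows. `merge_and_count` is a hypothetical name, and the sketch assumes both inputs are already sorted, as they would be when returning from the recursive calls.

```ruby
# Merge two already-sorted arrays and count the inversions that span them.
# Returns [merged_array, inversion_count].
def merge_and_count(left, right)
  merged = []
  inversions = 0
  i = j = 0
  while i < left.size && j < right.size
    if left[i] <= right[j]
      merged << left[i]
      i += 1
    else
      # right[j] jumps ahead of everything still waiting in left,
      # so it forms one inversion with each of those elements.
      inversions += left.size - i
      merged << right[j]
      j += 1
    end
  end
  # One side is exhausted; whatever remains on the other is already in order.
  [merged + left[i..-1] + right[j..-1], inversions]
end

merged, inversions = merge_and_count([1, 3, 5], [2, 4, 6])
puts "merged: #{merged.inspect}, inversions: #{inversions}"
```

Using `<=` rather than `<` means equal elements are not counted as inversions, which makes no difference when all the numbers are distinct.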

My colleague Amir and I decided to try and write some Scala code to do this and ended up with the following adapted merge sort:

import io.Source

object MSort {
  def main(args: Array[String]): Unit = {
    val fileWithNumbers = "/Users/mneedham/Documents/algo-class/IntegerArray.txt"
    val inversions: BigInt = numberOfInversions(Source.fromFile(new java.io.File(fileWithNumbers)).getLines().map(_.toInt).toList)
    println(inversions)
  }

  def numberOfInversions(collection: List[Int]): BigInt = {
    var count: BigInt = 0
    def inversionsInner(innerCollection: List[Int]): List[Int] = {
      def merge(left: List[Int], right: List[Int]): Stream[Int] = (left, right) match {
        case (x :: xs, y :: ys) if x < y => { Stream.cons(x, merge(xs, right)) }
        // Taking from the right: every element still in left forms an inversion with y
        case (x :: xs, y :: ys) => { count = count + left.length; Stream.cons(y, merge(left, ys)) }
        case _ => if (left.isEmpty) right.toStream else left.toStream
      }
      val n = innerCollection.length / 2
      if (n == 0) innerCollection
      else {
        val (lowerHalf, upperHalf) = innerCollection splitAt n
        merge(inversionsInner(lowerHalf), inversionsInner(upperHalf)).toList
      }
    }
    inversionsInner(collection)
    count
  }
}

The interesting line is the second case in merge, where we are going to select a value from the right collection and therefore increment our count by the number of items left in the left collection.

It works but it’s a bit annoying that we had to use a ‘var’ to keep track of the count since that’s not really idiomatic Scala.

I’ve been trying to find a way of passing the count around as an accumulator and returning it at the end but really struggled to get the code to compile when I started doing that and seemed to make the code a lot more complicated than it is now.

I’m sure there is a way to do it but I can’t see it at the moment!

Since the mutation is reasonably well encapsulated I’m not sure whether it really matters that much but it’s always an interesting exercise to try and write code which doesn’t mutate state so if anyone can see a good way to do it let me know.
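One shape that avoids the shared counter is to have every call return the sorted collection together with its inversion count and add the counts up on the way back. The sketch below shows that shape in Ruby rather than Scala; `sort_and_count` is a hypothetical name, and the merge loop still mutates arrays that are local to each call, so it removes the shared `var` but is not mutation-free.

```ruby
# Returns [sorted_array, inversion_count]. The count travels in the return
# value instead of living in a variable shared across the recursion.
def sort_and_count(items)
  return [items, 0] if items.size < 2

  # take/drop copy the halves, so the caller's array is never modified.
  left, left_count = sort_and_count(items.take(items.size / 2))
  right, right_count = sort_and_count(items.drop(items.size / 2))

  merged = []
  split_count = 0
  until left.empty? || right.empty?
    if left.first <= right.first
      merged << left.shift
    else
      # Everything still in left is larger than the element taken from right.
      split_count += left.size
      merged << right.shift
    end
  end

  [merged + left + right, left_count + right_count + split_count]
end

sorted, count = sort_and_count([6, 5, 4, 3, 2, 1])
puts "sorted: #{sorted.inspect}, inversions: #{count}"
```

The equivalent in Scala would be an inner function returning a tuple such as `(List[Int], BigInt)`, with the merge returning its own count alongside the merged list.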

Written by Mark Needham

March 20th, 2012 at 6:53 am

Posted in Scala
