Haskell: Closest Pairs Algorithm

As I mentioned in a post a couple of days ago, I’ve been writing the closest pairs algorithm in Haskell. The brute force version works for small numbers of points but starts to fall apart as the number of points increases:

time ./closest_pairs 100 bf
./closest_pairs 100 bf  0.01s user 0.00s system 87% cpu 0.016 total

time ./closest_pairs 1000 bf
./closest_pairs 1000 bf  3.59s user 0.01s system 99% cpu 3.597 total

time ./closest_pairs 5000 bf
./closest_pairs 5000  554.09s user 0.36s system 99% cpu 9:14.46 total
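
The brute force code itself is in the earlier post. As a rough reminder of its shape, here is a minimal sketch. The names Point, distance and bfClosest match how they are used later in this post, but the bodies below are assumptions rather than the code from that post:

```haskell
import Data.List (minimumBy, tails)
import Data.Ord (comparing)

-- Assumed representation: a point is an (x, y) tuple.
type Point a = (a, a)

-- Euclidean distance between the two points of a pair.
distance :: Floating a => (Point a, Point a) -> a
distance ((x1, y1), (x2, y2)) = sqrt ((x1 - x2) ^ 2 + (y1 - y2) ^ 2)

-- Compare every point with every later point: n * (n - 1) / 2 pairs.
-- Returns Nothing when there are fewer than two points.
bfClosest :: (Ord a, Floating a) => [Point a] -> Maybe (Point a, Point a)
bfClosest points =
  case [(p, q) | (p : rest) <- tails points, q <- rest] of
    []       -> Nothing
    allPairs -> Just (minimumBy (comparing distance) allPairs)
```

Every point is checked against every other point, which is where the quadratic running time comes from.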

Luckily there’s a divide and conquer algorithm we can use which brings the running time down from O(n²) to O(n log n), which is significantly quicker as n increases in size.

  1. Sort points along the x-coordinate

  2. Split the set of points into two equal-sized subsets by a vertical line x = xmid

  3. Solve the problem recursively in the left and right subsets. This will give the left-side and right-side minimal distances dLmin and dRmin respectively.

  4. Find the minimal distance dLRmin among the pairs of points in which one point lies to the left of the dividing vertical and the second point lies to the right (a split pair).

  5. The final answer is the minimum among dLmin, dRmin, and dLRmin.

In step 4 we look for a closest pair whose two points lie in opposite subsets. We don’t need to consider the whole list of points again, since we already know that the closest pair of points is no further apart than dist = min(dLmin, dRmin).

We can therefore filter out a lot of the points by only keeping those whose x value is within dist of the middle x value.

If a point is further away than that then we already know that the distance between it and any point on the other side will be greater than dist so there’s no point in considering it.
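
That filtering step can be sketched on its own. The helper name withinStrip is made up for illustration, and xMid and dist are passed in explicitly to keep the example standalone:

```haskell
-- Assumed representation: a point is an (x, y) tuple.
type Point a = (a, a)

-- Keep only the points whose x value is within dist of the dividing line.
-- A point outside this strip is more than dist away (in x alone) from
-- every point on the other side, so it cannot be part of a closer split pair.
withinStrip :: (Ord a, Num a) => a -> a -> [Point a] -> [Point a]
withinStrip xMid dist = filter (\(x, _) -> abs (xMid - x) <= dist)
```

For example, with xMid = 5 and dist = 1, only the points with an x value between 4 and 6 survive.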

I used the C# version from Rosetta Code as a template and this is what I ended up with:

import Data.Maybe
import Data.List
import Data.Function

-- Point, distance, bfClosest and windowed are defined elsewhere in the program.
dcClosest :: (Ord a, Floating a) => [Point a] -> (Point a, Point a)
dcClosest pairs
  | length pairs <= 3 = fromJust $ bfClosest pairs
  | otherwise =
      -- check neighbouring points in the strip against the best pair found so far
      foldl (\closest (p1:p2:_) -> if distance (p1, p2) < distance closest then (p1, p2) else closest)
            closestPair
            (windowed 2 pairsWithinMinimumDelta)
  where sortedByX = sortBy compare pairs
        -- splitAt gives exactly two halves, so no point is dropped when the length is odd
        (leftByX, rightByX) = splitAt (length sortedByX `div` 2) sortedByX
        closestPair = if distance closestLeftPair < distance closestRightPair then closestLeftPair else closestRightPair
          where closestLeftPair =  dcClosest leftByX
                closestRightPair = dcClosest rightByX
        -- points within the minimum distance of the dividing line, sorted by y
        pairsWithinMinimumDelta = sortBy (compare `on` snd) $ filter withinMinimumDelta sortedByX
          where withinMinimumDelta (x, _) = abs (xMidPoint - x) <= distance closestPair
                  where (xMidPoint, _) = last leftByX

If there are 3 or fewer points then we can just use the brute force algorithm, since splitting the list in half doesn’t work so well with a list of fewer than 4 items.

The main function is a fold which goes over all the pairs of points where we’ve determined that there might be a closest pair with one point on the left hand side of our 'sorted by x list' and the other on the right hand side.

We use the 'windowed' function to pair up the points which in this case does something like this:

> windowed 2 [(0,0), (1,1), (2,2), (1.1, 1.2)]
[[(0.0,0.0),(1.0,1.0)],[(1.0,1.0),(2.0,2.0)],[(2.0,2.0),(1.1,1.2)]]
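
The definition of 'windowed' isn’t shown in this post. A minimal sketch that behaves like the example above, and which is an assumption rather than the original definition, could be:

```haskell
import Data.List (tails)

-- Hypothetical definition: every run of n consecutive elements, in order.
-- Runs shorter than n (at the end of the list) are left out.
windowed :: Int -> [a] -> [[a]]
windowed n xs = [window | t <- tails xs, let window = take n t, length window == n]
```

With a window size of 2 this pairs each element with its immediate successor, so a list of k elements gives k - 1 windows.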

We pass along the closest pair that we’ve found from pairs of points on the left and right hand sides of the list as our seed value and check whether any of the split pairs are closer together than that one. By the time the fold is finished we will have the closest pair in the list.

The code in the where section sorts the list by x, splits it into two halves and then works out which half has the closest pair. That’s all done recursively, and then the last bit of code works out which points we need to consider for the split pairs part of the algorithm.

I learnt about the 'on' function while writing this algorithm. It makes it really easy to pick a function to sort a collection with, e.g. 'compare `on` snd' lets us sort a list in ascending order based on the second value in a tuple.
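
As a small standalone illustration of 'on' (the name sortByY is just for this example):

```haskell
import Data.Function (on)
import Data.List (sortBy)

-- compare `on` snd is the same as \a b -> compare (snd a) (snd b).
sortByY :: Ord b => [(a, b)] -> [(a, b)]
sortByY = sortBy (compare `on` snd)
```

Swapping snd for fst would sort on the first value of each tuple instead.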

One problem with the way I’ve written this algorithm is that it places a lot of focus on the split pair bit of the code when actually a big part of why it works is that we’re sorting it into an order that makes it easier to work with and then dividing the problem by 2 each time.

My current thinking is that perhaps I should have had that code in the main body of the function rather than hiding it away in the where section.

It isn’t actually necessary to execute the foldl all the way through since we’ll often know there’s not going to be a split pair before we go through all the pairs. I think it reads pretty nicely as it is though, and it runs pretty quickly anyway!
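
For comparison, here is a sketch of a split pair scan that cuts its search short. It isn’t the code used in this post: closestInStrip is a hypothetical helper, and Point and distance are assumed definitions. For each point it only looks at the following points (in y order) whose y value is within the current best distance. Unlike the 'windowed 2' version, that also covers candidates which aren’t immediate neighbours in y order.

```haskell
import Data.Function (on)
import Data.List (foldl', sortBy, tails)

-- Assumed representation: a point is an (x, y) tuple.
type Point a = (a, a)

-- Euclidean distance between the two points of a pair.
distance :: Floating a => (Point a, Point a) -> a
distance ((x1, y1), (x2, y2)) = sqrt ((x1 - x2) ^ 2 + (y1 - y2) ^ 2)

-- Hypothetical helper: scan the strip for a pair closer than the seed pair.
closestInStrip :: (Ord a, Floating a) => (Point a, Point a) -> [Point a] -> (Point a, Point a)
closestInStrip seed strip = foldl' step seed (tails sortedByY)
  where
    sortedByY = sortBy (compare `on` snd) strip
    step best []         = best
    step best (p : rest) = foldl' (closer p) best candidates
      -- stop at the first point that is too far away in y to beat best
      where candidates = takeWhile (\q -> snd q - snd p < distance best) rest
    -- keep whichever of the current best and (p, q) is closer together
    closer p best q
      | distance (p, q) < distance best = (p, q)
      | otherwise                       = best
```

The seed plays the same role as in the fold above: it is the closest pair found in the left and right halves.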

Using the same data set as before (and getting the same answers!):

> time ./closest_pairs 1000 dc
./closest_pairs 1000 dc  0.02s user 0.00s system 91% cpu 0.024 total

> time ./closest_pairs 5000 dc
./closest_pairs 5000 dc  0.11s user 0.01s system 97% cpu 0.118 total

I described the code for generating random numbers in an earlier blog post but I’ll include it again to show how the algorithm is wired up:

import Control.Monad.State (State, evalState, get, put)
import System.Random (StdGen, mkStdGen, random)
import System.Environment (getArgs)

type R a = State StdGen a
rand :: R Double
rand = do
  gen <- get
  let (r, gen') = random gen
  put gen'
  return r

randPair :: R (Double, Double)
randPair = do
  x <- rand
  y <- rand
  return (x,y)

runRandom :: R a -> Int -> a
runRandom action seed = evalState action $ mkStdGen seed

normals :: R [(Double, Double)]
normals = mapM (\_ -> randPair) $ repeat ()

main :: IO ()
main = do
  args <- getArgs
  let numberOfPairs = read (head args) :: Int
  if length args > 1 && args !! 1 == "bf" then
    print (bfClosest $ take numberOfPairs $ runRandom normals 42)
  else
    print (dcClosest $ take numberOfPairs $ runRandom normals 42)

The full code is on hpaste if anyone has any suggestions for how to improve it.
