Mark Needham

Thoughts on Software Development

Kaggle Digit Recognizer: Mahout Random Forest attempt


I’ve written previously about the K-means approach that Jen and I took when trying to solve Kaggle’s Digit Recognizer. Having stalled at around 80% accuracy, we decided to try one of the algorithms suggested in the tutorials section – the random forest!

We initially used a Clojure random forests library but struggled to build the random forest from the training set data in a reasonable amount of time, so we switched to Mahout’s version, which is based on Leo Breiman’s random forests paper.

There’s a really good explanation of how ensembles work on the Factual blog, which we found quite useful in helping us understand how random forests are supposed to work:

One of the most powerful Machine Learning techniques we turn to is ensembling. Ensemble methods build surprisingly strong models out of a collection of weak models called base learners, and typically require far less tuning when compared to models like Support Vector Machines.

Most ensemble methods use decision trees as base learners and many ensembling techniques, like Random Forests and Adaboost, are specific to tree ensembles.
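
The gist, as we understood it: train lots of mediocre models and let them vote. Here’s a minimal sketch of majority voting (our own illustration – the WeakLearner interface is made up for the example):

import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class MajorityVoteEnsemble {
  // A hypothetical base learner: anything that can guess a label for a sample.
  public interface WeakLearner {
    int predict(double[] features);
  }

  private final List<WeakLearner> baseLearners;

  public MajorityVoteEnsemble(List<WeakLearner> baseLearners) {
    this.baseLearners = baseLearners;
  }

  // The ensemble predicts whichever label the most base learners voted for.
  public int predict(double[] features) {
    Map<Integer, Integer> votes = new HashMap<Integer, Integer>();
    for (WeakLearner learner : baseLearners) {
      int label = learner.predict(features);
      Integer count = votes.get(label);
      votes.put(label, count == null ? 1 : count + 1);
    }
    int bestLabel = -1;
    int bestCount = -1;
    for (Map.Entry<Integer, Integer> entry : votes.entrySet()) {
      if (entry.getValue() > bestCount) {
        bestCount = entry.getValue();
        bestLabel = entry.getKey();
      }
    }
    return bestLabel;
  }
}

A random forest is this pattern with decision trees as the base learners, where each tree is trained on a bootstrap sample of the rows using a random subset of the attributes.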

We were able to adapt the BreimanExample included in the examples section of the Mahout repository to do what we wanted.

To start with we wrote the following code to build the random forest:

import java.io.BufferedReader;
import java.io.DataInputStream;
import java.io.FileInputStream;
import java.io.InputStreamReader;
import java.util.ArrayList;

import org.apache.mahout.classifier.df.DecisionForest;
import org.apache.mahout.classifier.df.builder.DefaultTreeBuilder;
import org.apache.mahout.classifier.df.data.Data;
import org.apache.mahout.classifier.df.data.DataLoader;
import org.apache.mahout.classifier.df.ref.SequentialBuilder;
import org.apache.mahout.common.RandomUtils;
import org.uncommons.maths.Maths; // Uncommons Maths, one of Mahout's dependencies

public class MahoutKaggleDigitRecognizer {
  public static void main(String[] args) throws Exception {
    // "L" marks the label column; each "N" marks one of the 784 numeric pixel columns
    String descriptor = "L N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N N ";
    String[] trainDataValues = fileAsStringArray("data/train.csv");
 
    // false => classification over discrete labels rather than regression
    Data data = DataLoader.loadData(DataLoader.generateDataset(descriptor, false, trainDataValues), trainDataValues);
 
    int numberOfTrees = 100;
    DecisionForest forest = buildForest(numberOfTrees, data);
  }
 
  private static DecisionForest buildForest(int numberOfTrees, Data data) {
    // Breiman's heuristic: consider log2(#attributes) + 1 randomly chosen attributes at each split
    int m = (int) Math.floor(Maths.log(2, data.getDataset().nbAttributes()) + 1);
 
    DefaultTreeBuilder treeBuilder = new DefaultTreeBuilder();
    treeBuilder.setM(m);
 
    return new SequentialBuilder(RandomUtils.getRandom(), treeBuilder, data.clone()).build(numberOfTrees);
  }
 
  private static String[] fileAsStringArray(String file) throws Exception {
    ArrayList<String> list = new ArrayList<String>();
 
    DataInputStream in = new DataInputStream(new FileInputStream(file));
    BufferedReader br = new BufferedReader(new InputStreamReader(in));
 
    String strLine;
    br.readLine(); // discard top one (header)
    while ((strLine = br.readLine()) != null) {
      list.add(strLine);
    }
 
    in.close();
    return list.toArray(new String[list.size()]);
  }
}

The training data file looks a bit like this:

label,pixel0,pixel1,pixel2,pixel3,pixel4,pixel5,pixel6,pixel7,pixel8...,pixel783
1,0,0,0,0,0,0,...,0
0,0,0,0,0,0,0,...,0

So in this case the label is in the first column, which is represented as an L in the descriptor, and the next 784 columns are the numerical values of the pixels in the image (hence the 784 Ns in the descriptor).
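
As an aside, the descriptor doesn’t have to be typed out by hand – a little helper like this (our own convenience method, not part of Mahout) generates it:

// Builds the "L N N ... N" descriptor string: one label column followed
// by numberOfPixels numeric columns. Hypothetical helper, not a Mahout API.
private static String buildDescriptor(int numberOfPixels) {
  StringBuilder descriptor = new StringBuilder("L");
  for (int i = 0; i < numberOfPixels; i++) {
    descriptor.append(" N");
  }
  return descriptor.toString();
}

Calling buildDescriptor(784) yields a descriptor equivalent to the literal above.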

We’re telling it to create a random forest containing 100 trees, and since we have a finite number of categories that an entry can be classified as, we pass false as the 2nd argument (regression) of DataLoader.generateDataset.

The m value determines how many randomly selected attributes (pixel values in this case) are considered at each split when constructing a tree – Breiman suggests log2(number_of_attributes) + 1 as a good default for classification problems.
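
Plugging in this dataset’s 784 pixel attributes gives m = 10 (whether or not the label column is counted, the floor comes out the same) – a quick sanity check using only the standard library:

int nbAttributes = 784;                               // pixel columns in train.csv
double log2 = Math.log(nbAttributes) / Math.log(2);   // ≈ 9.61
int m = (int) Math.floor(log2 + 1);                   // m == 10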

We then wrote the following code to predict the labels of the test data set:

// (additional imports needed: java.util.Random and org.apache.mahout.classifier.df.data.Instance)
public class MahoutKaggleDigitRecognizer {
  public static void main(String[] args) throws Exception {
    ...
    String[] testDataValues = testFileAsStringArray("data/test.csv");
    Data test = DataLoader.loadData(data.getDataset(), testDataValues);
    Random rng = RandomUtils.getRandom();

    for (int i = 0; i < test.size(); i++) {
      Instance oneSample = test.get(i);

      // classify() returns the index of the predicted label rather than the label itself...
      double classify = forest.classify(test.getDataset(), rng, oneSample);
      // ...so we look it up against the label attribute (column 0) to recover the digit
      int label = data.getDataset().valueOf(0, String.valueOf((int) classify));

      System.out.println("Label: " + label);
    }
  }

  private static String[] testFileAsStringArray(String file) throws Exception {
    ArrayList<String> list = new ArrayList<String>();

    DataInputStream in = new DataInputStream(new FileInputStream(file));
    BufferedReader br = new BufferedReader(new InputStreamReader(in));

    String strLine;
    br.readLine(); // discard top one (header)
    while ((strLine = br.readLine()) != null) {
      // prefix a '-' placeholder where the (unknown) label would go
      list.add("-," + strLine);
    }

    in.close();
    return list.toArray(new String[list.size()]);
  }
}

There were a couple of things that we found confusing when working out how to do this:

  1. The format of the test data needs to be identical to that of the training data, which consisted of a label followed by 784 numerical values. Obviously with the test data we don’t have a label, so Mahout expects us to pass a ‘-’ where the label would otherwise go, or it will throw an exception – which explains the ‘-’ on the list.add line.
  2. We initially thought the value returned by forest.classify was the prediction itself, but in actual fact it’s an index which we then need to look up on the data set, as the made-up example below illustrates.
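
To see why that lookup matters, here’s a toy illustration (made-up values, not the Mahout API):

// Hypothetical example: suppose the labels encountered while loading the
// training data were encoded in the order "7", "3", "9". classify()
// returning 2.0 then means the digit 9, not the digit 2.
String[] labelValues = {"7", "3", "9"};
double classify = 2.0;
String predictedDigit = labelValues[(int) classify]; // "9"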

When we ran this algorithm against the test data set we got an accuracy of 83.8% with 10 trees, 84.4% with 50 trees, 96.28% with 100 trees and 96.33% with 200 trees, which is where we’ve currently peaked.

The amount of time it takes to build the forest as we increase the number of trees is also starting to become a problem, so our next step is either to look at a way to parallelise the creation of the forest (sketched below) or to do some sort of feature extraction to try to improve the accuracy.
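
One cheap way to parallelise it would be to build several smaller forests concurrently, each via its own SequentialBuilder, and combine their votes at prediction time. The sketch below is our own workaround, not Mahout’s approach (Mahout also ships a Hadoop-based partial implementation for building forests in a distributed fashion, which would be the more scalable route); it assumes the buildForest method from earlier has been made package-visible, and the thread and forest counts are arbitrary:

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

import org.apache.mahout.classifier.df.DecisionForest;
import org.apache.mahout.classifier.df.data.Data;
import org.apache.mahout.classifier.df.data.Dataset;
import org.apache.mahout.classifier.df.data.Instance;

public class ParallelForests {
  // Build several smaller forests concurrently, one per task. buildForest
  // already clones the training data, so the threads only read shared state.
  public static List<DecisionForest> buildInParallel(final Data data, int numberOfForests,
                                                     final int treesPerForest,
                                                     int threads) throws Exception {
    ExecutorService pool = Executors.newFixedThreadPool(threads);
    List<Future<DecisionForest>> futures = new ArrayList<Future<DecisionForest>>();
    for (int i = 0; i < numberOfForests; i++) {
      futures.add(pool.submit(new Callable<DecisionForest>() {
        public DecisionForest call() throws Exception {
          // the buildForest method from earlier in the post
          return MahoutKaggleDigitRecognizer.buildForest(treesPerForest, data);
        }
      }));
    }
    List<DecisionForest> forests = new ArrayList<DecisionForest>();
    for (Future<DecisionForest> future : futures) {
      forests.add(future.get());
    }
    pool.shutdown();
    return forests;
  }

  // Classify a sample with each forest and take a majority vote over their answers.
  public static double classify(List<DecisionForest> forests, Dataset dataset,
                                Random rng, Instance sample) {
    Map<Double, Integer> votes = new HashMap<Double, Integer>();
    for (DecisionForest forest : forests) {
      double prediction = forest.classify(dataset, rng, sample);
      Integer count = votes.get(prediction);
      votes.put(prediction, count == null ? 1 : count + 1);
    }
    double bestPrediction = -1;
    int bestCount = -1;
    for (Map.Entry<Double, Integer> entry : votes.entrySet()) {
      if (entry.getValue() > bestCount) {
        bestCount = entry.getValue();
        bestPrediction = entry.getKey();
      }
    }
    return bestPrediction;
  }
}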

The code is on GitHub if you’re interested in playing with it or have any suggestions on how to improve it.

Written by Mark Needham

October 27th, 2012 at 8:24 pm

Posted in Machine Learning


  • venkiram

    Is there a way to extract the structure of the trees built by Mahout? One possible way of finding the important variables is to look at how often each one gets selected to split the data high up in a tree. Appreciate your response.

  • Shankusu

    Hi!

    I don’t understand why we have to generate the dataset before loading it into Mahout. Do you have an explanation for this?

  • Mark Needham (http://www.markhneedham.com/blog)

    @Shankusu I think that’s just the name of the function in Mahout. We are using our own data set rather than generating one – the data all comes from the CSV file.

  • Shankusu

    @markneedham:disqus: I’m trying to implement a program that builds a model from train.csv and evaluates the model with test.csv. However, I can’t load the data from test.csv after building the model. Can you help me fix this problem?

  • Andrew H

    Thanks for the code!

    I’ve tried to run the code 3 times and it only ends up assigning a label to a subset of the test set (first run assigned 110 labels, second run assigned 611 labels, third run assigned 222 labels).

    I’m getting the following errors – any idea why this is?

    [WARNING]
    java.lang.reflect.InvocationTargetException
        at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
        at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:39)
        at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:25)
        at java.lang.reflect.Method.invoke(Method.java:597)
        at org.codehaus.mojo.exec.ExecJavaMojo$1.run(ExecJavaMojo.java:297)
        at java.lang.Thread.run(Thread.java:662)
    Caused by: java.lang.ArrayIndexOutOfBoundsException: -1
        at org.apache.mahout.classifier.df.DecisionForest.classify(DecisionForest.java:112)
        at capstone7.project.MahoutKaggleDigitRecognizer.main(MahoutKaggleDigitRecognizer.java:56)