Classification and Regression Trees (CART) are two types of model used for supervised learning. Classification trees predict a discrete variable, such as survival or death; regression trees predict a continuous variable, such as income. Both types of model use a very similar algorithm.
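The similarity between the two tree types shows up at the leaves. The following minimal Python sketch assumes the usual CART convention: a terminal node predicts the most common class for a classification tree and the mean response for a regression tree. The helper names `classification_leaf` and `regression_leaf` and the data are hypothetical.

```python
from collections import Counter
from statistics import mean

def classification_leaf(labels):
    """Predict the most common class among the observations in a leaf."""
    return Counter(labels).most_common(1)[0][0]

def regression_leaf(values):
    """Predict the mean of the continuous response in a leaf."""
    return mean(values)

# Hypothetical observations that ended up in the same leaf
print(classification_leaf(["died", "died", "survived"]))  # died
print(regression_leaf([20000, 30000, 40000]))
```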
Performance Versus Interpretability
CART algorithms generally do not match the predictive accuracy of other supervised machine learning algorithms. Their advantage, however, is that they produce results that are easy to interpret. If the objective is to build a model that aids understanding, CART-based models may be a good choice.
Survival on the Titanic
The Titanic dataset used here contains records for 1309 people: 809 died (62%) and 500 survived. The most basic form of prediction is to always choose the dominant class: if asked whether a person survived, predicting death would be correct 62% of the time. Can we do better using a classification tree?
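The baseline accuracy follows directly from the counts above. This short Python check simply reproduces the arithmetic.

```python
# Counts quoted in the text for the 1309 Titanic records
died, survived = 809, 500
total = died + survived

# Majority-class baseline: always predict the dominant class ("died")
baseline_accuracy = max(died, survived) / total
print(f"{baseline_accuracy:.1%}")  # 61.8%
```

The exact figure is 61.8%, which rounds to the 62% quoted above. Any useful classifier has to beat this number.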
Missing Ages Using Regression Trees
Of the 1309 Titanic records, 263 had missing values for the “age” variable. We could ignore these observations, or use the mean age of the remaining 1046 (1309 − 263) observations to fill in the 263 missing ages. Alternatively, we could use a predictive model. Predicting with the mean gave an MSE of 207.55, whereas a regression tree reduced the MSE to 125.87.
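To make the baseline concrete, here is a minimal Python sketch on a handful of hypothetical ages (not the Titanic data). Predicting every value with the sample mean gives an MSE equal to the population variance of the values, which is the benchmark a regression tree has to beat.

```python
from statistics import mean, pvariance

# Hypothetical known ages (not the Titanic data)
ages = [22.0, 38.0, 26.0, 35.0, 54.0, 2.0, 27.0, 14.0]

# Baseline model: predict every age with the mean of the known ages
mean_age = mean(ages)
mse_mean = mean((a - mean_age) ** 2 for a in ages)

# The MSE of the mean predictor equals the population variance
print(round(mse_mean, 2), round(pvariance(ages), 2))
```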
Performance of the Classification Tree
To estimate the likely prediction error, the observations were split into a training set and a test set of 1,100 and 209 observations respectively. A six-node tree was fitted to the training set and then applied to the test set. On the 209 test observations, the model achieved an accuracy of 83%, compared with 62% for always predicting death.
The R tree package was used to build the model on the training set and then test it on the test set. The package also has functions to plot the resulting tree. This functionality is shown by the following code snippet.
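```r
library(tree)

# mTitanicAll contains 1309 observations
# survived must be a factor so that tree() fits a classification tree
mTitanicAll$survived <- factor(mTitanicAll$survived)

# create a vector with a random sample of 1100 row indices (training set)
set.seed(1)  # optional: makes the random split reproducible; the seed value is arbitrary
trainSet <- sample(1:nrow(mTitanicAll), 1100)

# build the model on the training set; dependent variable = "survived"
treeTrain <- tree(survived ~ ., mTitanicAll, subset = trainSet)

# the test set is every row not in the training set (209 rows)
mTitanicTest <- mTitanicAll[-trainSet, ]

# apply the training-set model to the test set
mTitanicTestPred <- predict(treeTrain, newdata = mTitanicTest, type = "class")

# create the confusion matrix (predicted vs actual)
tab <- table(mTitanicTestPred, mTitanicTest$survived)

# accuracy = correctly classified (diagonal of the confusion matrix) / total
sum(diag(tab)) / sum(tab)

# plot the model
plot(treeTrain)
text(treeTrain, pretty = 0)
```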
CART Model Theory
The following summarises the basic theory underlying CART models.
Recursive binary partitions
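Both classification and regression trees split the set of observations into two partitions. The split-point is defined by a specific value of a specific predictor variable, and it is chosen to maximise the predictive ability of the model, that is, to minimise a cost. The basic algorithm is: for every predictor and every candidate split value, partition the observations into two sets and calculate the cost; choose the predictor and split value with the lowest cost; then repeat the process separately on each of the two resulting partitions.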
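The steps above can be sketched in Python as follows. This is a minimal illustration, not the tree package's implementation: it assumes numeric predictors and a continuous response, uses the total squared error as the cost, sends observations with a predictor value below the split value to the left, and stops splitting when a node has fewer than a hypothetical `min_size` observations. The function names are also hypothetical.

```python
def sse(ys):
    """Total squared error of predicting every value in ys by their mean."""
    if not ys:
        return 0.0
    m = sum(ys) / len(ys)
    return sum((y - m) ** 2 for y in ys)

def best_split(X, y):
    """Search every predictor j and split value s; return the lowest-cost split."""
    best = None
    for j in range(len(X[0])):
        for s in sorted({row[j] for row in X}):
            left = [yi for row, yi in zip(X, y) if row[j] < s]
            right = [yi for row, yi in zip(X, y) if row[j] >= s]
            if not left or not right:
                continue  # a split must produce two non-empty partitions
            cost = sse(left) + sse(right)
            if best is None or cost < best[0]:
                best = (cost, j, s)
    return best

def grow(X, y, min_size=2):
    """Recursively partition the observations; leaves predict the mean response."""
    split = best_split(X, y) if len(y) >= min_size else None
    if split is None:
        return sum(y) / len(y)  # terminal node (leaf)
    _, j, s = split
    left = [(row, yi) for row, yi in zip(X, y) if row[j] < s]
    right = [(row, yi) for row, yi in zip(X, y) if row[j] >= s]
    return {
        "var": j,
        "split": s,
        "left": grow([r for r, _ in left], [v for _, v in left], min_size),
        "right": grow([r for r, _ in right], [v for _, v in right], min_size),
    }

def predict(node, row):
    """Follow the splits from the root down to a leaf."""
    while isinstance(node, dict):
        node = node["left"] if row[node["var"]] < node["split"] else node["right"]
    return node

# Hypothetical data: one predictor, two clearly separated groups
X = [[1], [2], [3], [10], [11], [12]]
y = [5.0, 6.0, 7.0, 20.0, 21.0, 22.0]
tree = grow(X, y, min_size=4)
print(tree)  # {'var': 0, 'split': 10, 'left': 6.0, 'right': 21.0}
```

Because each call to `grow` works only on the observations passed down to it, the recursion mirrors the repeated partitioning described above. A classification version would swap `sse` for an impurity measure and predict the most common class at each leaf.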
The main difference between classification and regression trees lies in how the cost of a split is calculated.
For regression, the cost criterion is the Mean Squared Error (MSE). For example, if we are trying to predict house prices using house size as a predictor, we iterate through the observed house sizes and, at each value, partition the set of observations into two. For each of the two sets we calculate the mean price, use it as the prediction for every house in that set, and calculate the resulting squared errors. We choose the value of the predictor that gives the lowest total squared error across the two sets (equivalently, the lowest MSE weighted by the size of each set).
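A minimal Python sketch of this search for a single predictor, using hypothetical house sizes and prices. Each candidate size splits the observations into two sets, each set is predicted by its own mean, and the split with the lowest total squared error is kept. The function name `split_cost` is hypothetical.

```python
def split_cost(sizes, prices, threshold):
    """Total squared error when each side of the split is predicted by its mean."""
    cost = 0.0
    for side in (
        [p for s, p in zip(sizes, prices) if s < threshold],
        [p for s, p in zip(sizes, prices) if s >= threshold],
    ):
        if side:
            m = sum(side) / len(side)
            cost += sum((p - m) ** 2 for p in side)
    return cost

# Hypothetical data: house size (square metres) and price (thousands)
sizes = [50, 60, 70, 120, 130, 150]
prices = [150, 160, 170, 300, 320, 340]

# Try every observed size except the smallest (which would leave one side empty)
candidates = sorted(set(sizes))[1:]
best = min(candidates, key=lambda t: split_cost(sizes, prices, t))
print(best)  # 120
```

Splitting at 120 separates the three small houses from the three large ones, giving a total squared error of 1000, far below any other candidate.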
For classification, one of two measures is commonly used: Shannon’s entropy or the Gini index. Both measure how homogeneous a set of observations is: they are zero when every observation belongs to the same class and largest when the classes are evenly mixed. For example, the entropy of a fair coin, with a 50% chance of heads or tails, is 1.00. For a weighted coin with an 80% chance of heads, the entropy falls to 0.72.
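Both measures can be written in a couple of lines. In this Python sketch, `probs` is assumed to hold the class proportions within a node; the entropy is H = −Σ p log₂ p (in bits) and the Gini index is G = 1 − Σ p².

```python
from math import log2

def entropy(probs):
    """Shannon entropy in bits, written as sum(p * log2(1/p)) to avoid -0.0."""
    return sum(p * log2(1 / p) for p in probs if p > 0)

def gini(probs):
    """Gini index: probability that two random draws fall in different classes."""
    return 1 - sum(p * p for p in probs)

# Fair coin, weighted 80/20 coin, and a pure node
for probs in ([0.5, 0.5], [0.8, 0.2], [1.0, 0.0]):
    print(probs, round(entropy(probs), 2), round(gini(probs), 2))
```

For the fair coin this gives an entropy of 1.00 and a Gini index of 0.50; for the 80/20 coin, 0.72 and 0.32; a pure node scores 0 on both.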
The algorithm above is applied recursively, so the number of observations in each partition becomes smaller as the depth of the tree increases. In the limit, a tree can be grown until the number of terminal nodes (i.e. leaves) equals the number of observations. A model this complex is likely to overfit: it fits the training data closely but fails to generalise to other data sets.
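A minimal Python sketch of this limiting case, assuming a single numeric predictor with distinct values and hypothetical noisy data. With no stopping rule, greedy splitting continues until every leaf holds one observation, so the tree has as many leaves as observations and its training error is zero however noisy the data are. The function names are hypothetical.

```python
def sse(ys):
    """Total squared error of predicting every value in ys by their mean."""
    m = sum(ys) / len(ys)
    return sum((y - m) ** 2 for y in ys)

def grow_leaves(pairs):
    """Greedily split (x, y) pairs with no stopping rule; return the leaves."""
    if len(pairs) == 1:
        return [pairs]  # a terminal node holding a single observation
    pairs = sorted(pairs)
    # choose the cut between positions k-1 and k with the lowest total squared error
    k = min(
        range(1, len(pairs)),
        key=lambda k: sse([y for _, y in pairs[:k]]) + sse([y for _, y in pairs[k:]]),
    )
    return grow_leaves(pairs[:k]) + grow_leaves(pairs[k:])

# Hypothetical noisy data with distinct x values
data = [(1, 3.1), (2, 2.4), (3, 4.0), (4, 3.3), (5, 5.2), (6, 4.1), (7, 6.3), (8, 5.0)]
leaves = grow_leaves(data)
training_sse = sum(sse([y for _, y in leaf]) for leaf in leaves)
print(len(leaves), training_sse)  # 8 0.0
```

A perfect fit to the training data says nothing about new data: each leaf has memorised one noisy observation.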
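To prevent overfitting, a deep tree is grown first and then pruned back to a tree of smaller depth. The goal is to find the subtree $T$ that minimises the following cost-complexity criterion:

$$C_\alpha(T) = \frac{1}{n}\sum_{m=1}^{|T|} \sum_{i:\, x_i \in R_m} \left(y_i - \hat{y}_{R_m}\right)^2 + \alpha\,|T|$$

where $n$ is the number of training observations, $|T|$ is the number of terminal nodes of $T$, $R_m$ is the region of predictor space covered by the $m$-th terminal node, $\hat{y}_{R_m}$ is the mean response of the training observations in $R_m$, and $\alpha \ge 0$ is a tuning parameter.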
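The equation above has two competing components. The first term, the Mean Squared Error, decreases as the depth of the tree increases. The second term is proportional to the number of terminal nodes, which also increases with depth. The weight on the second term sets the trade-off: with a weight of zero the full tree is kept, while larger weights favour smaller trees.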
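A minimal Python sketch of how the penalty α selects a subtree, using a hypothetical pruning sequence of (number of terminal nodes, training MSE) pairs rather than output from the tree package. As α increases, the selected subtree becomes smaller. In practice α (or equivalently the tree size) is usually chosen by cross-validation; in the R tree package, `cv.tree()` and `prune.tree()` serve this purpose.

```python
# Hypothetical pruning sequence: (number of terminal nodes |T|, training MSE)
subtrees = [(1, 210.0), (2, 160.0), (4, 130.0), (6, 122.0), (12, 118.0), (25, 115.0)]

def cost_complexity(subtree, alpha):
    """Cost-complexity criterion: training MSE + alpha * |T|."""
    leaves, mse = subtree
    return mse + alpha * leaves

def select_subtree(subtrees, alpha):
    """Return the subtree that minimises the penalised cost for a given alpha."""
    return min(subtrees, key=lambda t: cost_complexity(t, alpha))

for alpha in (0.0, 1.0, 5.0, 100.0):
    leaves, mse = select_subtree(subtrees, alpha)
    print(f"alpha={alpha:5.1f}: {leaves} terminal nodes (MSE {mse})")
```

With these numbers, α = 0 keeps all 25 terminal nodes, while α = 1, 5 and 100 select 6, 4 and 1 terminal nodes respectively.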
GitHub Code and Data
The data and the code for the examples above are available from here.