CART Models

Classification and Regression Trees (CART) are two types of model used for supervised learning. Classification trees predict a discrete variable, such as life or death; regression trees predict a continuous variable, such as income. Both types of model use a very similar algorithm.

Predicting survival on the Titanic using a four-node classification tree. Women had better chances of survival than men. If you were a 1st or 2nd class woman, you had a 93% chance of survival. If you were a man aged over 9.7 years, you had an 83% chance of dying. Overall, 61% of people did not survive. Note that on the bottom right, the algorithm picks midpoints between two observed values, so the value 9.7 did not actually occur in the data set.

Performance Versus Interpretability

CART algorithms do not have the same predictive accuracy as some other supervised machine learning algorithms. However, their advantage is that they offer a result which is easy to interpret. If the objective is to build a model that facilitates understanding, then CART-based models may be a good choice.

Survival on the Titanic

There were 1309 people on the Titanic: 809 died (62%) and 500 survived. The most basic form of prediction would be to choose the value of the dominant class. If asked whether a person survived, and you always predicted death, you would be correct 62% of the time. Can we do better using a classification tree?
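As a quick sanity check, this baseline can be computed directly from the class counts; a minimal sketch, assuming the mTitanicAll data frame used in the Implementation section below:

# counts of each class in the survived column
classCounts <- table(mTitanicAll$survived)
# accuracy of always predicting the dominant class (death)
max(classCounts) / sum(classCounts)  # 809 / 1309, about 0.62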

Missing Ages Using Regression Trees

Of the 1309 Titanic records, 263 had data missing for the “age” variable. We could ignore these observations, or perhaps use the mean of the remaining (1309 − 263 =) 1046 observations to predict the 263 missing ages. Or we could use a predictive model. The MSE using the mean was 207.55; when a regression tree was used instead, the MSE fell to 125.87.
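A minimal sketch of this tree-based imputation, assuming the tree package and the mTitanicAll data frame from the Implementation section; the predictors are simply whatever other columns the data frame holds:

library(tree)
# split records into those with and without a recorded age
mKnownAge <- mTitanicAll[!is.na(mTitanicAll$age), ]
mMissingAge <- mTitanicAll[is.na(mTitanicAll$age), ]
# fit a regression tree for age on the remaining variables
treeAge <- tree(age ~ ., mKnownAge)
# predict the missing ages from the fitted tree
predictedAges <- predict(treeAge, newdata = mMissingAge)
# in-sample comparison of the two imputation strategies on the known ages
mean((mKnownAge$age - mean(mKnownAge$age))^2)          # MSE using the mean
mean((mKnownAge$age - predict(treeAge, mKnownAge))^2)  # MSE using the tree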

Performance of the Classification Tree

In order to estimate the likely prediction error, the set of observations was split into a training set and a test set of 1,100 and 209 observations respectively. A six-node model was fitted to the training set, and this model was then applied to the test set. For the 209 test observations, the model had an accuracy of 83%.

The correct predictions are shown in blue. In the holdout sample, there are 209 observations: 114 people perished and 95 survived. Given that a person perished, the model was 93% accurate; given that a person survived, the model was 72% accurate. The total accuracy of the model on the holdout sample was 83%. This is likely to be a pessimistic estimate, since the 209 holdout observations (about 15% of the total) were not used to fit the model.

Implementation

The R tree package was used to create and then test the predictive model: a model was first built on a training set and then applied to a test set. The tree package also has functions to plot the resultant tree. This is shown in the following code snippet.

library(tree)
#mTitanicAll contains 1309 observations
#create vector with a random sample of 1100 row indices
set.seed(1)  #for a reproducible train/test split
trainSet <- sample(1:nrow(mTitanicAll), 1100)
#build model on the training set. Dependent variable = "survived"
treeTrain <- tree(survived ~ ., mTitanicAll, subset = trainSet)
#the remaining 209 observations form the test set
mTitanicTest <- mTitanicAll[-trainSet, ]
#make predictions using the training-set model, applied to the test set
mTitanicTestPred <- predict(treeTrain,
                            newdata = mTitanicTest, type = "class")
#create confusion matrix
tab <- table(mTitanicTestPred, mTitanicTest$survived)
#print overall accuracy
(tab[1] + tab[4]) / sum(tab)
#plot the model
plot(treeTrain)
text(treeTrain, pretty = 0)
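The conditional accuracies quoted earlier can be read off the same confusion matrix; a brief sketch, assuming the survived factor puts death before survival (as the tab indexing above also assumes):

#accuracy given that the person perished (first column: true deaths)
tab[1, 1] / sum(tab[, 1])
#accuracy given that the person survived (second column: true survivors)
tab[2, 2] / sum(tab[, 2])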

CART Model Theory

The following summarises the basic theory underlying CART models.

Recursive binary partitions

Both classification and regression trees attempt to split the set of observations into two partitions. The split-point of a partition is determined by a specific value of a specific predictor variable. The criterion that determines a specific split-point is maximising the predictive ability of the model. The basic algorithm is sketched below:
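The sketch below is a minimal runnable R version of the greedy algorithm for the regression case, assuming numeric predictors and splits placed at midpoints between consecutive observed values (as in the Titanic age example above). The function growTree and its arguments are illustrative, not part of any package; a classification tree would predict the majority class and use entropy or the Gini index as the cost.

#greedy recursive binary partitioning (regression version, illustrative)
#data: a data frame; yName: the response column; numeric predictors only
growTree <- function(data, yName, minNodeSize = 5) {
  y <- data[[yName]]
  #stopping rule: a small or pure node becomes a leaf predicting the mean
  if (nrow(data) <= minNodeSize || length(unique(y)) == 1) {
    return(list(leaf = TRUE, prediction = mean(y)))
  }
  #search every predictor and every candidate split-point for the
  #partition with the lowest combined cost of the two child nodes
  best <- list(cost = Inf)
  for (xName in setdiff(names(data), yName)) {
    x <- data[[xName]]
    if (!is.numeric(x)) next              #categorical splits omitted
    values <- sort(unique(x))
    if (length(values) < 2) next
    midpoints <- (head(values, -1) + tail(values, -1)) / 2
    for (s in midpoints) {
      left <- y[x < s]; right <- y[x >= s]
      cost <- sum((left - mean(left))^2) + sum((right - mean(right))^2)
      if (cost < best$cost) best <- list(cost = cost, xName = xName, s = s)
    }
  }
  if (!is.finite(best$cost)) {            #no usable split was found
    return(list(leaf = TRUE, prediction = mean(y)))
  }
  #recurse on the two partitions
  list(leaf = FALSE, xName = best$xName, s = best$s,
       left  = growTree(data[data[[best$xName]] <  best$s, ], yName, minNodeSize),
       right = growTree(data[data[[best$xName]] >= best$s, ], yName, minNodeSize))
}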

Cost Calculation

The basic difference between classification and regression trees lies in how their costs are calculated.

Regression Costs

The cost criterion for regression is the Mean Squared Error (MSE). For example, if we are trying to predict house prices using house size as a predictor, we iterate through the observed house sizes and, at each candidate value, partition the set of observations into two. We calculate the average price of each set and, based on these averages, calculate the MSE. We choose the value of the predictor that results in the lowest aggregate MSE across the two sets.
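A small worked illustration of this search, with made-up house sizes and prices (the numbers below are hypothetical):

#hypothetical house sizes and prices
size  <- c(50, 60, 75, 80, 100, 120)
price <- c(150, 170, 210, 215, 260, 300)
#candidate split-points: midpoints between consecutive observed sizes
values <- sort(unique(size))
midpoints <- (head(values, -1) + tail(values, -1)) / 2
#aggregate squared error of the two sets around their respective means
costs <- sapply(midpoints, function(s) {
  left <- price[size < s]; right <- price[size >= s]
  sum((left - mean(left))^2) + sum((right - mean(right))^2)
})
midpoints[which.min(costs)]  #the chosen split-point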

Classification Costs

Two measures are commonly used: Shannon entropy and the Gini index. Both attempt to measure the homogeneity of a node. For example, the entropy of a fair coin, with a 50% chance of heads or tails, has a value of 1.00. For a weighted coin with an 80% chance of heads, the entropy is reduced to about 0.72.
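These measures are easy to reproduce directly; a short sketch:

#entropy (in bits) and Gini impurity for a vector of class probabilities
entropy <- function(p) -sum(p * log2(p))
gini    <- function(p) 1 - sum(p^2)
entropy(c(0.5, 0.5))  #fair coin: 1.00
entropy(c(0.8, 0.2))  #weighted coin: about 0.72
gini(c(0.5, 0.5))     #0.50
gini(c(0.8, 0.2))     #0.32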

Overfitting

The splitting algorithm above is applied recursively, so the number of observations in each node becomes smaller and smaller as the depth of the tree increases. In the limit, a tree will be grown with the number of terminal nodes (i.e. leaves) equal to the number of observations. This level of complexity is likely to lead to an overfitted model that fails to generalise to other data sets.

To prevent overfitting, a deep tree is grown and then pruned back to a tree with a smaller depth.  The goal is to minimise the following equation:

$$\sum_{m=1}^{|T|} \sum_{x_i \in R_m} \left( y_i - \hat{y}_{R_m} \right)^2 + \alpha |T|$$

where $|T|$ is the number of terminal nodes, $R_m$ is the region corresponding to the $m$-th terminal node, $\hat{y}_{R_m}$ is the mean response in that region, and $\alpha$ is a tuning parameter.

The equation above has conflicting components. The first term, the total squared error, decreases as the depth of the tree increases. The second term measures the number of terminal nodes $|T|$, which increases with the depth of the tree; the tuning parameter $\alpha$ controls the trade-off between the two.
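With the tree package used above, this trade-off is typically handled by cross-validation over tree sizes rather than by choosing α directly; a sketch, reusing the treeTrain model from the Implementation section:

#cross-validate over tree sizes, using misclassification error as the cost
cvTree <- cv.tree(treeTrain, FUN = prune.misclass)
#the tree size with the lowest cross-validated error
bestSize <- cvTree$size[which.min(cvTree$dev)]
#prune the deep tree back to that number of terminal nodes
treePruned <- prune.misclass(treeTrain, best = bestSize)
plot(treePruned)
text(treePruned, pretty = 0)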

Github Code and Data

The data and the code for the examples above are available from here.
