A regression tree is a type of tree model for regression tasks; it is the regression version of CART (Classification And Regression Tree). It is trained by recursively splitting the feature space (and, with it, the sample), which is equivalent to fitting the data with a step function, so it makes no assumptions about the functional form or distribution of the data. Regression trees are also used as weak learners in random forest regressors and in GBDT (for both classification and regression).
1. Introduction
Decision trees are a family of algorithms widely used in machine learning. A decision tree is not only an algorithm in its own right but also the basic building block of ensemble algorithms such as random forests and boosting.
Decision trees are easy to interpret because their structure resembles the way people make decisions, and if you have some background in data structures you will recognize that the structure is exactly a binary tree. For example, when a hiring manager reviews resumes for a data scientist position, the procedure might look like this: Does the applicant have related work experience? Next, do the applicant's skills match the job description? Next, does the applicant have a degree in mathematics, statistics or computer science? And so on.
There are many decision tree algorithms, including:
- ID3 (Iterative Dichotomiser 3)
- C4.5 (successor of ID3)
- CART (Classification And Regression Tree)
They differ mainly in their learning algorithms (for example, the splitting criterion or loss function). CART can be used for both classification and regression problems, and decision trees are non-parametric methods, since no assumptions are made about the form of the data. In this post I will briefly go through the mechanism and implementation of the regression tree.
2. Mathematical Basis
2.1 Basic Elements of Decision Tree
We first define several concepts for decision trees. A tree consists of nodes, leaves and the links between them (node to node, node to leaf). The top node is called the root, and the terminal nodes, which are not split any further, are the leaves. Each internal node makes a decision based on one rule of the form \(\{X|X_{j}<s\}\): the data set is split into two parts (the left child and the right child, from the tree's point of view), where observations whose \(j\)th feature is less than \(s\) go to the left child node and all other observations go to the right child node. The same procedure is then applied to the left and right child nodes.
Let \(R_{1}, \cdots, R_{M}\) denote the \(M\) regions (leaves) into which the tree partitions the feature space, and let \(c_{m}\) be the constant predicted in region \(R_{m}\). The model predicts
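The split rule \(\{X|X_{j}<s\}\) can be illustrated with a short Python sketch. It assumes NumPy, and the helper name apply_split is hypothetical. Features are indexed from 0 in the code. Every observation whose \(j\)th feature is below \(s\) goes to the left child, and everything else goes to the right child.

```python
import numpy as np

def apply_split(X, y, j, s):
    """Split (X, y) by the rule {X | X_j < s} (hypothetical helper).

    Rows whose j-th feature is less than s go to the left child;
    all remaining rows go to the right child.
    """
    goes_left = X[:, j] < s
    return (X[goes_left], y[goes_left]), (X[~goes_left], y[~goes_left])

# Toy data: 4 observations, 2 features
X = np.array([[1.0, 5.0], [2.0, 3.0], [3.0, 8.0], [4.0, 1.0]])
y = np.array([1.0, 1.5, 3.0, 3.5])

(X_left, y_left), (X_right, y_right) = apply_split(X, y, j=0, s=2.5)
# y_left holds [1.0, 1.5], y_right holds [3.0, 3.5]
```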
\[
\hat{f}(X) = \sum^{M}_{m=1}c_{m}\mathbb{I}\left(X\in R_{m}\right).
\]
If we minimize squared loss, the optimal \(c_{m}\) is the mean of the target values of the training observations in region \(R_{m}\).
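Here is a one-dimensional sketch of this prediction rule. It assumes NumPy, and the helper names fit_leaf_values and predict are hypothetical. The regions are taken to be intervals \(R_{m} = [e_{m}, e_{m+1})\) given by a sorted array of edges, and predict evaluates the indicator sum literally. The last lines check numerically that, under squared loss, the mean beats other constants within a region.

```python
import numpy as np

def fit_leaf_values(x, y, edges):
    """c_m = mean of y over R_m = [edges[m], edges[m+1]) (optimal under squared loss)."""
    return np.array([y[(x >= lo) & (x < hi)].mean()
                     for lo, hi in zip(edges[:-1], edges[1:])])

def predict(x, edges, c):
    """Literal f_hat(x) = sum_m c_m * I(x in R_m).

    The regions do not overlap, so for each x exactly one indicator is 1.
    """
    return sum(c[m] * ((x >= edges[m]) & (x < edges[m + 1]))
               for m in range(len(c)))

x = np.array([1.0, 2.0, 3.0, 4.0])
y = np.array([1.0, 1.5, 3.0, 3.5])
edges = np.array([-np.inf, 2.5, np.inf])  # M = 2 regions, split at s = 2.5

c = fit_leaf_values(x, y, edges)                           # [1.25, 3.25]
y_hat = predict(np.array([0.0, 2.4, 2.5, 9.0]), edges, c)  # [1.25, 1.25, 3.25, 3.25]

# Squared loss in R_1 for a few constants: the mean (1.25) gives the smallest
y_r1 = y[x < 2.5]
losses = {const: float(np.sum((y_r1 - const) ** 2)) for const in (1.0, 1.25, 1.5)}
```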
2.2 Split Finding
The key step in learning a decision tree is the split: choosing a feature and a value of that feature that divide the feature space into smaller blocks. The goal is straightforward: we want the splitting process (that is, the tree-building process) to minimize prediction error, or equivalently, we look for a partition into blocks \(R_{1}, R_{2},\cdots, R_{n}\) that minimizes the error.
Since this is a regression problem, we use the mean of all observations in a block as the prediction for observations that fall into that block, and we can choose squared loss, absolute loss, etc. (with absolute loss, the median would replace the mean). Here I will use squared loss as an example. As with many other machine learning algorithms, this is an optimization problem: reduce the squared loss by splitting,
\[
R_{1}, R_{2},\cdots, R_{n} = \arg\min_{R_{1}, R_{2},\cdots, R_{n}}\sum^{n}_{m=1}\sum_{i: x_{i}\in R_{m}}\left(y_{i}-\bar{y}_{R_{m}}\right)^{2}.
\]
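To make the objective concrete, here is a small sketch that evaluates it for one given partition, encoded as a block label per observation. It assumes NumPy, and partition_sse is a hypothetical name.

```python
import numpy as np

def partition_sse(y, labels):
    """Sum over blocks R_m of sum_{i: x_i in R_m} (y_i - mean of y in R_m)^2.

    labels[i] is the index of the block that observation i falls into.
    """
    total = 0.0
    for m in np.unique(labels):
        y_m = y[labels == m]
        total += float(np.sum((y_m - y_m.mean()) ** 2))
    return total

y = np.array([1.0, 1.5, 3.0, 3.5])
one_block = partition_sse(y, np.array([0, 0, 0, 0]))   # 4.25
two_blocks = partition_sse(y, np.array([0, 0, 1, 1]))  # 0.25
singletons = partition_sse(y, np.arange(4))            # 0.0
```

The last line shows why the number of blocks has to be controlled. If every observation is placed in its own block, the objective drops to zero, which is the overfitting extreme discussed under pruning.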
Unfortunately, it is computationally infeasible to consider every possible partition of the feature space into \(n\) blocks.
But those awesome scientists found an alternative: top-down, greedy recursive binary splitting. It begins at the top of the tree, where all observations are still in one block, and then successively splits the feature space, with each split generating a left child node and a right child node. At each step the best split for that step is selected, which is why the method is called “greedy”: instead of looking for a global optimum, we settle for a sequence of local optima (the best split at each stage). Each step solves the following problem: find a split \(\{X |X_{j} < s\}\), which generates \(R_{1} = \{X|X_{j} < s\}\) and \(R_{2} = \{X|X_{j} \geq s\}\), such that
\[
(j, s) = \arg\min_{(j,s)}\left[\sum_{i:x_{i}\in R_{1}}\left(y_{i}-\bar{y}_{R_{1}}\right)^{2} + \sum_{i:x_{i}\in R_{2}}\left(y_{i}-\bar{y}_{R_{2}}\right)^{2}\right].
\]
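Here is a brute-force sketch of one greedy step. It assumes NumPy, and sse and best_split are hypothetical names. For each feature \(j\), it tries every threshold \(s\) halfway between two consecutive distinct values. Any \(s\) in that gap yields the same split, so midpoints are a convenient assumption. It keeps the pair with the smallest summed squared loss.

```python
import numpy as np

def sse(y):
    """Squared loss of predicting every element of y by the mean of y."""
    return float(np.sum((y - y.mean()) ** 2))

def best_split(X, y):
    """Return (j, s, loss) minimizing SSE(R_1) + SSE(R_2), with
    R_1 = {X | X_j < s} and R_2 = {X | X_j >= s}.
    Returns (None, None, inf) if no feature has two distinct values."""
    best_j, best_s, best_loss = None, None, np.inf
    for j in range(X.shape[1]):
        values = np.unique(X[:, j])                # sorted distinct values
        for s in (values[:-1] + values[1:]) / 2:   # midpoints: both sides non-empty
            in_r1 = X[:, j] < s
            loss = sse(y[in_r1]) + sse(y[~in_r1])
            if loss < best_loss:
                best_j, best_s, best_loss = j, float(s), loss
    return best_j, best_s, best_loss

X = np.array([[1.0, 5.0], [2.0, 3.0], [3.0, 8.0], [4.0, 1.0]])
y = np.array([1.0, 1.5, 3.0, 3.5])
j, s, loss = best_split(X, y)  # j = 0, s = 2.5, loss = 0.25
```

This search recomputes each loss from scratch, so it costs \(O(pN^{2})\) for \(p\) features and \(N\) observations at the node. Practical implementations usually sort each feature once and update running sums of \(y\) and \(y^{2}\) while sweeping the thresholds. That brings the cost per feature down to \(O(N\log N)\).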
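We then repeat this process on the data in \(R_{1}\) and \(R_{2}\) and move on to the next level of the tree. Growing the tree one level at a time (depth-wise, or level-wise, growth) resembles BFS (Breadth First Search). A recursive implementation visits the nodes depth-first instead, but as long as the stopping rules depend only on each node's own depth and data, both orders produce the same tree.
The stopping condition for growing the tree can be set by specifying the maximum depth max_depth (the number of steps on the longest path from the root to a leaf) and the minimum number of observations needed to split a node, min_split_sample. A closely related option is the minimum number of observations required in each leaf.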
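The growth procedure and both stopping conditions can be sketched as a breadth-first loop over a queue of nodes. This assumes NumPy. The node layout (plain dicts) and the function names are hypothetical, while max_depth and min_split_sample mirror the parameters described above. The sketch records the depth and size of each node in the order nodes are processed, which makes the level-by-level order visible.

```python
from collections import deque

import numpy as np

def sse(y):
    return float(np.sum((y - y.mean()) ** 2))

def best_split(X, y):
    """Exhaustive greedy search over features j and midpoint thresholds s."""
    best = (None, None, np.inf)
    for j in range(X.shape[1]):
        values = np.unique(X[:, j])
        for s in (values[:-1] + values[1:]) / 2:
            left = X[:, j] < s
            loss = sse(y[left]) + sse(y[~left])
            if loss < best[2]:
                best = (j, float(s), loss)
    return best

def grow_level_wise(X, y, max_depth, min_split_sample):
    """Grow a regression tree breadth-first. Each node is a dict holding the
    row indices it owns; internal nodes also get feature/threshold/children."""
    root = {"depth": 0, "rows": np.arange(len(y))}
    queue = deque([root])
    visit_order = []  # (depth, number of observations) per processed node
    while queue:
        node = queue.popleft()  # FIFO: finish one level before the next
        rows = node["rows"]
        node["value"] = float(y[rows].mean())
        visit_order.append((node["depth"], len(rows)))
        # Stopping conditions: depth limit reached or too few observations
        if node["depth"] >= max_depth or len(rows) < min_split_sample:
            continue
        j, s, _ = best_split(X[rows], y[rows])
        if j is None:  # all feature values identical: nothing to split on
            continue
        goes_left = X[rows, j] < s
        node["feature"], node["threshold"] = j, s
        node["left"] = {"depth": node["depth"] + 1, "rows": rows[goes_left]}
        node["right"] = {"depth": node["depth"] + 1, "rows": rows[~goes_left]}
        queue.extend([node["left"], node["right"]])
    return root, visit_order

X = np.array([[1.0], [2.0], [3.0], [4.0]])
y = np.array([1.0, 1.5, 3.0, 3.5])
tree, order = grow_level_wise(X, y, max_depth=2, min_split_sample=2)
# order: [(0, 4), (1, 2), (1, 2), (2, 1), (2, 1), (2, 1), (2, 1)]
```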
3. Tree Pruning
It is easy to grow a complex tree that performs poorly on test data; in the extreme case, each leaf ends up containing a single observation. Following the bias-variance trade-off principle, we would like to accept a little more bias in exchange for lower variance. One strategy is pruning: grow a large tree first, then prune it back to a smaller subtree.
Pruning can be done either while the tree is being built (pre-pruning) or after the tree is built (post-pruning):
- Pre-pruning sets a positive threshold for splitting: a node is split only when the loss reduction exceeds the threshold.
- Post-pruning is commonly done by cost complexity pruning (also called weakest link pruning): after the full tree is grown, a sequence of trees \(T_{0}, T_{1}, \cdots, T_{m}\) is generated, where \(T_{0}\) is the full tree and \(T_{m}\) is the root alone. At step \(i\), the tree is created by removing a subtree from tree \(i-1\) and replacing it with a leaf node.
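The pre-pruning rule from the first bullet can be sketched as follows. This assumes NumPy, and split_gain, should_split and the threshold values are illustrative rather than taken from any particular library. The loss reduction is the parent's squared loss minus the summed squared loss of the two children.

```python
import numpy as np

def sse(y):
    return float(np.sum((y - y.mean()) ** 2))

def split_gain(y, goes_left):
    """Loss reduction of a split: SSE(parent) - SSE(left) - SSE(right)."""
    return sse(y) - sse(y[goes_left]) - sse(y[~goes_left])

def should_split(y, goes_left, threshold):
    """Pre-pruning: accept the split only if the loss reduction exceeds threshold."""
    return split_gain(y, goes_left) > threshold

y = np.array([1.0, 1.5, 3.0, 3.5])
good = np.array([True, True, False, False])   # gain = 4.25 - 0.25 = 4.0
weak = np.array([True, False, True, False])   # gain = 4.25 - 4.0 = 0.25
decisions = [should_split(y, good, 1.0), should_split(y, weak, 1.0)]  # [True, False]
```

ISLR points out a caveat: this rule is short-sighted. A split with a small loss reduction may enable a very good split further down, and pre-pruning never gets to see it. This is one reason to grow a large tree first and prune it afterwards.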
In addition, we can limit max_depth and min_split_sample, or add the tree size \(|T|\) (the number of leaves) as a regularization term in the loss function.
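With the number of leaves \(|T|\) as the penalty, the regularized loss takes the cost complexity form used in ISLR:
\[
C_{\alpha}(T) = \sum^{|T|}_{m=1}\sum_{i: x_{i}\in R_{m}}\left(y_{i}-\bar{y}_{R_{m}}\right)^{2} + \alpha|T|, \quad \alpha \geq 0.
\]
The post-pruning sequence of trees is commonly generated by weakest link pruning. At each step, we collapse the internal node \(t\) with the smallest \(g(t) = \frac{R(t) - R(T_{t})}{|T_{t}| - 1}\). Here \(R(t)\) is the squared loss if \(t\) were a leaf, \(R(T_{t})\) is the loss of the subtree below \(t\), and \(g(t)\) is the value of \(\alpha\) at which collapsing \(t\) leaves \(C_{\alpha}\) unchanged. The sketch below assumes a hypothetical nested-dict tree in which every node already stores \(R(t)\) under the key "sse". All names are illustrative.

```python
import copy

# Hypothetical tree layout: every node stores "sse" = R(t), the squared loss
# of its training observations around their own mean (its loss as a leaf).
# Internal nodes additionally have "left" and "right" children.

def n_leaves(node):
    """|T_t|: number of leaves in the subtree rooted at node."""
    if "left" not in node:
        return 1
    return n_leaves(node["left"]) + n_leaves(node["right"])

def subtree_sse(node):
    """R(T_t): total squared loss over the leaves of the subtree."""
    if "left" not in node:
        return node["sse"]
    return subtree_sse(node["left"]) + subtree_sse(node["right"])

def internal_nodes(node):
    if "left" in node:
        yield node
        yield from internal_nodes(node["left"])
        yield from internal_nodes(node["right"])

def link_strength(node):
    """g(t) = (R(t) - R(T_t)) / (|T_t| - 1)."""
    return (node["sse"] - subtree_sse(node)) / (n_leaves(node) - 1)

def weakest_link_sequence(full_tree):
    """Return [T_0, ..., T_m] and the alpha at which each collapse happens."""
    tree = copy.deepcopy(full_tree)  # leave the caller's tree untouched
    trees, alphas = [copy.deepcopy(tree)], []
    while "left" in tree:  # stop once only the root is left
        weakest = min(internal_nodes(tree), key=link_strength)
        alphas.append(link_strength(weakest))
        del weakest["left"], weakest["right"]  # replace subtree by a leaf
        trees.append(copy.deepcopy(tree))
    return trees, alphas

# Toy tree with made-up losses: 3 leaves (1.0 + 1.5 + 2.0)
full = {"sse": 10.0,
        "left": {"sse": 4.0, "left": {"sse": 1.0}, "right": {"sse": 1.5}},
        "right": {"sse": 2.0}}
trees, alphas = weakest_link_sequence(full)  # leaves 3 -> 2 -> 1, alphas [1.5, 4.0]
```

Each tree in the sequence is the best subtree for a range of \(\alpha\) values, and \(\alpha\) itself is usually chosen by cross-validation. This sketch collapses one node per step, even when several nodes share the minimal \(g(t)\).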
4. Implementation
Implementing a decision tree can be tricky: you do not know in advance exactly how many splits are needed, and the process is dynamic, since each new split is applied to a block produced by an earlier split.
First I used R and quickly found that my R coding was not yet good enough for this job: I could only perform each split manually. Here is my solution, which you can safely ignore :-)
Then I turned to Python for help. I defined two classes, Node and RegressionTree, and a recursive method to perform the splits, which is a better and more convenient solution in Python. Here is an even better solution I found that does not use OOP; I made small revisions to make it work in my situation, and it outperformed my own solution. This exercise made it clear to me that I need more programming practice.
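The following is a minimal sketch in the spirit of that object-oriented design, not the code referred to above. It uses classes named Node and RegressionTree and a recursive growing method. The method names (fit, predict, _grow, _best_split), the attribute layout and the default parameter values are assumptions, and NumPy is assumed.

```python
import numpy as np

class Node:
    """Internal node with rule {X | X_j < s}, or a leaf if it has no children."""
    def __init__(self, value):
        self.value = value        # mean of y in this node (the prediction if leaf)
        self.feature = None       # j
        self.threshold = None     # s
        self.left = None
        self.right = None

    def is_leaf(self):
        return self.left is None

class RegressionTree:
    """Hypothetical interface: fit(X, y) / predict(X)."""
    def __init__(self, max_depth=3, min_split_sample=2):
        self.max_depth = max_depth
        self.min_split_sample = min_split_sample
        self.root = None

    @staticmethod
    def _sse(y):
        return float(np.sum((y - y.mean()) ** 2))

    def _best_split(self, X, y):
        """Exhaustive greedy search over features and midpoint thresholds."""
        best_j, best_s, best_loss = None, None, np.inf
        for j in range(X.shape[1]):
            values = np.unique(X[:, j])
            for s in (values[:-1] + values[1:]) / 2:
                left = X[:, j] < s
                loss = self._sse(y[left]) + self._sse(y[~left])
                if loss < best_loss:
                    best_j, best_s, best_loss = j, float(s), loss
        return best_j, best_s

    def _grow(self, X, y, depth):
        """Recursively split until a stopping condition holds."""
        node = Node(value=float(y.mean()))
        if depth >= self.max_depth or len(y) < self.min_split_sample:
            return node
        j, s = self._best_split(X, y)
        if j is None:  # all feature values identical: no split possible
            return node
        left = X[:, j] < s
        node.feature, node.threshold = j, s
        node.left = self._grow(X[left], y[left], depth + 1)
        node.right = self._grow(X[~left], y[~left], depth + 1)
        return node

    def fit(self, X, y):
        self.root = self._grow(np.asarray(X, float), np.asarray(y, float), depth=0)
        return self

    def _predict_one(self, x):
        node = self.root
        while not node.is_leaf():
            node = node.left if x[node.feature] < node.threshold else node.right
        return node.value

    def predict(self, X):
        return np.array([self._predict_one(x) for x in np.asarray(X, float)])

X = np.array([[1.0], [2.0], [3.0], [4.0]])
y = np.array([1.0, 1.5, 3.0, 3.5])
tree = RegressionTree(max_depth=1).fit(X, y)
preds = tree.predict([[1.2], [3.8]])  # [1.25, 3.25]
```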
5. References
- Gareth James, Daniela Witten, Trevor Hastie, Robert Tibshirani, An Introduction to Statistical Learning.
- Decision Tree, https://en.wikipedia.org/wiki/Decision_tree
- Decision Tree Learning, https://en.wikipedia.org/wiki/Decision_tree_learning
- Decision Tree Pruning, https://en.wikipedia.org/wiki/Decision_tree_pruning