Chapter 4 BART
4.1 A BART version of our hierarchical trees model
Let’s define:
- \(P\) to be the number of trees
- \(J\) to be the total number of groups
- \(\Theta\) to be the set of node hyperparameters
- \(\mu\) and \(\mu_j\) to be the overall and group-level mean parameters, defined for each tree \(p = 1, \dots, P\)
We have a variable of interest for which we assume:
\[\begin{equation} y_{ij} = \sum_{p = 1}^{P} \overbrace{\mathbb{G}}^\text{Tree look up function}(\underbrace{X_{ij}}_\text{Covariates}, \overbrace{T_{p}}^\text{Tree structure}, \overbrace{\Theta_{p}}^\text{Terminal node parameters}) + \underbrace{\epsilon_{ij}}_\text{Noise} \end{equation}\]
for observation \(i = 1, \dots, n_j\) in group \(j = 1, \dots, J\). We also have that:
\[\begin{equation} \epsilon_{ij} \sim N(0, \tau^{-1}), \end{equation}\]
where \(\tau\) is the residual precision, so that \(\tau^{-1}\) is the residual variance. In this setting, \(\Theta_{p}\) represents the terminal node parameters together with the individual group parameters for tree \(p\).
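To make the sum-of-trees structure concrete, here is a minimal Python sketch. It assumes, purely for illustration, that every tree is a single-split stump and that the group-level parameters are ignored. The names `tree_lookup` and `sum_of_trees` and the tuple encoding of \(T_p\) and \(\Theta_p\) are hypothetical and do not come from any package.

```python
import numpy as np

def tree_lookup(X, split_var, split_val, node_means):
    """Hypothetical G(X, T, Theta) for a one-split tree (a stump).

    T is encoded by (split_var, split_val) and Theta by the two
    terminal node means node_means = (left, right).
    """
    go_left = X[:, split_var] < split_val
    return np.where(go_left, node_means[0], node_means[1])

def sum_of_trees(X, trees):
    """Add up the predictions of all P stumps: sum_p G(X, T_p, Theta_p)."""
    return sum(tree_lookup(X, *tree) for tree in trees)

rng = np.random.default_rng(0)
X = rng.uniform(size=(6, 2))                 # 6 observations, 2 covariates
trees = [(0, 0.5, (-1.0, 1.0)),              # tree 1: split on covariate 0
         (1, 0.3, (0.2, -0.2))]              # tree 2: split on covariate 1
tau = 4.0                                    # residual precision (assumed value)
f = sum_of_trees(X, trees)
# Noise has variance 1 / tau, so its standard deviation is tau ** -0.5.
y = f + rng.normal(scale=tau ** -0.5, size=len(f))
```

Each entry of `f` is the sum of one terminal node mean per tree, which is the quantity the noise term is added to in the model above.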
For a single terminal node, let:
\[\begin{equation} R_{ijp1} = Y_{ij}^{(1)} - \sum_{t \neq p} \mathbb{G}(X_{ij}^{(1)}, T_{t}, \Theta_{t}) \end{equation}\]
which represents the partial residuals for observation \(i\), in group \(j\), for tree \(p\) in terminal node 1. Now, let
\[\begin{equation} \underset{\sim}{R_j} = \{R_{ij};\ i = 1, \dots, n_j \}, \quad j = 1,\dots, J \end{equation}\]
then
\[\begin{equation} \begin{aligned} \underset{\sim}{R_j} &\sim N(\mu_j, \tau^{-1}), \\ \mu_j &\sim N(\mu, k_1\tau^{-1}/P), \quad \text{(P = number of trees)} \\ \mu &\sim N(0, k_2 \tau^{-1}/P) \end{aligned} \end{equation}\]
using the same marginalisation as for a single tree:
\[\begin{equation} \begin{aligned} \underset{\sim}{R_j} &\sim MVN(\mu \mathbf{1}, \tau^{-1} (k_1MM^{T} + \mathbb{I})), \quad \text{(M = group model matrix)} \\ &\text{using the same trick as before and } \Psi = k_1 MM^{T} + \mathbb{I}: \\ \underset{\sim}{R_j} &\sim MVN(0, \tau^{-1} (\Psi + k_2 \mathbf{1}\mathbf{1}^{T})), \end{aligned} \end{equation}\]
which is used to get the marginal distribution of a new tree. The new posterior updates will be:
\[\begin{equation} \mu | \dots \sim MVN\left( \frac{\mathbf{1}^{T} \Psi^{-1} R }{\mathbf{1}^{T} \Psi^{-1} \mathbf{1} + k_2^{-1}}, \tau^{-1} (\mathbf{1}^{T} \Psi^{-1} \mathbf{1} + k_2^{-1})^{-1} \right), \end{equation}\]
\[\begin{equation} \mu_j | \dots \sim MVN\left( \frac{\mu k_1^{-1} + \bar R_j n_j}{n_j + k_1^{-1}}, \tau^{-1} (n_j + k_1^{-1})^{-1} \right) \end{equation}\]
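The quantities used in this section can be sketched numerically as follows. This is a rough illustration under stated assumptions, not the implementation of any package. It assumes the priors \(\mu_j \sim N(\mu, k_1\tau^{-1})\) and \(\mu \sim N(0, k_2\tau^{-1})\), with any \(1/P\) scaling folded into \(k_1\) and \(k_2\) so that they match the marginal covariance \(\tau^{-1}(\Psi + k_2 \mathbf{1}\mathbf{1}^{T})\). The draws use the standard conjugate normal results, with the second argument of each distribution read as a variance. All function names are hypothetical.

```python
import numpy as np

def partial_residuals(y, tree_preds, p):
    """y minus the predictions of every tree except tree p.

    tree_preds is a (P, n) array; row t holds tree t's current predictions.
    """
    return y - (tree_preds.sum(axis=0) - tree_preds[p])

def group_model_matrix(groups, J):
    """n x J indicator matrix M with M[i, j] = 1 if observation i is in group j."""
    M = np.zeros((len(groups), J))
    M[np.arange(len(groups)), groups] = 1.0
    return M

def marginal_loglik(R, M, tau, k1, k2):
    """Log density of R ~ MVN(0, (Psi + k2 * 11^T) / tau), Psi = k1 * M M^T + I."""
    n = len(R)
    Psi = k1 * M @ M.T + np.eye(n)
    cov = (Psi + k2 * np.ones((n, n))) / tau
    _, logdet = np.linalg.slogdet(cov)
    quad = R @ np.linalg.solve(cov, R)
    return -0.5 * (n * np.log(2 * np.pi) + logdet + quad)

def sample_mu(R, M, tau, k1, k2, rng):
    """Draw mu given R, with the group means mu_j marginalised out."""
    n = len(R)
    Psi = k1 * M @ M.T + np.eye(n)
    Psi_inv_one = np.linalg.solve(Psi, np.ones(n))
    denom = Psi_inv_one.sum() + 1.0 / k2          # 1^T Psi^{-1} 1 + 1/k2
    mean = (Psi_inv_one @ R) / denom              # Psi is symmetric
    return rng.normal(mean, np.sqrt(1.0 / (tau * denom)))

def sample_mu_j(R, groups, J, mu, tau, k1, rng):
    """Draw all J group means given mu."""
    n_j = np.bincount(groups, minlength=J)
    sum_j = np.bincount(groups, weights=R, minlength=J)   # n_j * mean of R in group j
    denom = n_j + 1.0 / k1
    mean = (mu / k1 + sum_j) / denom
    return rng.normal(mean, np.sqrt(1.0 / (tau * denom)))

rng = np.random.default_rng(1)
groups = np.array([0, 0, 1, 1, 1])                # group index of each observation
y = np.array([1.0, 0.6, 2.1, 1.7, 1.9])
tree_preds = np.array([[0.2, 0.2, 0.4, 0.4, 0.4],
                       [0.5, 0.5, 0.5, 0.5, 0.5]])
R = partial_residuals(y, tree_preds, p=0)
M = group_model_matrix(groups, J=2)
ll = marginal_loglik(R, M, tau=2.0, k1=0.5, k2=1.0)
mu = sample_mu(R, M, tau=2.0, k1=0.5, k2=1.0, rng=rng)
mu_j = sample_mu_j(R, groups, 2, mu, tau=2.0, k1=0.5, rng=rng)
```

Solving linear systems with `np.linalg.solve` avoids forming \(\Psi^{-1}\) explicitly. For large \(n\), the block structure of \(MM^{T}\) could be exploited instead, but that is outside the scope of this sketch.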
The update for \(\tau\) will be a little different. Let \(\hat f_{ij}\) be the overall prediction for observation \(y_{ij}\) at the current iteration:
\[\begin{equation} \tau | \dots \sim Ga( \frac{n + P + 1}{2} + \alpha, \frac{\sum_{i, j} (y_{ij} - \hat f_{ij})^2}{2} + \frac{\sum_{j, p} (\mu_{jp} - \mu_{p})^2}{2} + \frac{\sum_{j, p} \mu_{p}^2}{2} + \beta ) \end{equation}\]
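As a rough sketch, the shape and rate of this Gamma update can be computed term by term as written. Two assumptions are made here that the text does not state: \(\alpha\) and \(\beta\) are taken to be the shape and rate of a Gamma prior on \(\tau\), and the second argument of \(Ga\) is read as a rate. The function names and the array layout (`mu_jp` of shape \(J \times P\), `mu_p` of length \(P\)) are hypothetical.

```python
import numpy as np

def tau_gamma_params(y, f_hat, mu_jp, mu_p, alpha, beta):
    """Shape and rate of the Gamma update for tau, transcribed term by term."""
    n = len(y)
    P = len(mu_p)
    shape = (n + P + 1) / 2 + alpha
    ss_resid = np.sum((y - f_hat) ** 2)           # sum_{i,j} (y_ij - f_hat_ij)^2
    ss_group = np.sum((mu_jp - mu_p) ** 2)        # mu_p broadcasts over the J groups
    # Sum over both j and p, exactly as the expression is written.
    ss_mu = np.sum(np.broadcast_to(mu_p, mu_jp.shape) ** 2)
    rate = ss_resid / 2 + ss_group / 2 + ss_mu / 2 + beta
    return shape, rate

def sample_tau(shape, rate, rng):
    """NumPy parameterises the Gamma by shape and scale, so scale = 1 / rate."""
    return rng.gamma(shape, 1.0 / rate)

rng = np.random.default_rng(2)
y = np.array([1.0, 2.0])
f_hat = np.array([0.5, 1.5])
mu_jp = np.array([[1.0], [-1.0]])                 # J = 2 groups, P = 1 tree
mu_p = np.array([0.0])
shape, rate = tau_gamma_params(y, f_hat, mu_jp, mu_p, alpha=1.0, beta=1.0)
tau_draw = sample_tau(shape, rate, rng)
```

In this toy example the shape is \((2 + 1 + 1)/2 + 1 = 3\) and the rate is \(0.25 + 1 + 0 + 1 = 2.25\).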