Chapter 8 BART
8.1 A BART version of our hierarchical trees model
Let’s define:
- $P$ to be the number of trees
- $J$ to be the total number of groups
- $\Theta$ to be the set of node hyperparameters: $\mu$ and $\mu_j$ for each tree $p = 1, \dots, P$
We have a variable of interest for which we assume:
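$$y_{ij} = \sum_{p=1}^{P} \overbrace{G}^{\text{Tree look up function}}\big(\underbrace{X_{ij}}_{\text{Covariates}}, \overbrace{T_p}^{\text{Tree structure}}, \overbrace{\Theta_p}^{\text{Terminal node parameters}}\big) + \underbrace{\epsilon_{ij}}_{\text{Noise}}$$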
for observation $i = 1, \dots, n_j$ in group $j = 1, \dots, J$. We also have that:
$$\epsilon_{ij} \sim N(0, \tau^{-1}),$$
where $\tau$ is the residual precision, so $\tau^{-1}$ is the residual variance. In this setting, $\Theta_p$ will represent the terminal node parameters plus the individual group parameters for tree $p$.
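To make the sum-of-trees structure concrete, here is a minimal sketch of the look-up function $G$ and the resulting prediction. The tree representation is purely hypothetical: each tree is a single-split stump stored as a dictionary with keys `var`, `cut` and `theta`, where `theta` holds one group-specific parameter per terminal node. None of these names come from the text.

```python
def tree_lookup(x, group, tree):
    """Sketch of G(x, T_p, Theta_p) for a hypothetical one-split tree (a stump)."""
    # The tree structure T_p is the split variable and cut point.
    node = "left" if x[tree["var"]] < tree["cut"] else "right"
    # Theta_p stores one parameter per terminal node and per group.
    return tree["theta"][node][group]


def predict(x, group, trees):
    """Sum of the P tree look-ups; the noise term is not included."""
    return sum(tree_lookup(x, group, t) for t in trees)


# Two illustrative trees (P = 2) and two groups (J = 2).
trees = [
    {"var": 0, "cut": 0.5, "theta": {"left": [1.0, 2.0], "right": [3.0, 4.0]}},
    {"var": 1, "cut": 1.0, "theta": {"left": [0.5, 0.25], "right": [-0.5, -0.25]}},
]
fitted = predict(x=[0.2, 3.0], group=1, trees=trees)
```

Here the observation falls in the left node of the first tree and the right node of the second, so its prediction is the sum of the two group-1 parameters in those nodes.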
For a single terminal node, let:
$$R_{ijp}^{(1)} = Y_{ij}^{(1)} - \sum_{t \neq p} G(X_{ij}^{(1)}, T_t, M_t)$$
which represents the partial residuals for observation $i$, in group $j$, for tree $p$ in terminal node 1. Now, let
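The partial residual computation can be sketched as follows. This assumes the fitted values of every tree are already available as plain lists, one per tree, aligned with the response; the function name and data layout are illustrative only.

```python
def partial_residuals(y, tree_preds, p):
    """Return y minus the fitted values of every tree except tree p.

    y          : list of responses
    tree_preds : list of P lists; tree_preds[t][i] is tree t's fit for y[i]
    p          : index of the tree currently being updated
    """
    residuals = []
    for i, y_i in enumerate(y):
        # Sum the contributions of all trees t != p for observation i.
        others = sum(pred[i] for t, pred in enumerate(tree_preds) if t != p)
        residuals.append(y_i - others)
    return residuals


y = [3.0, 1.5, 2.0]
tree_preds = [
    [1.0, 0.5, 0.5],     # tree 0
    [1.5, 0.5, 1.0],     # tree 1
    [0.25, 0.25, 0.25],  # tree 2
]
r0 = partial_residuals(y, tree_preds, p=0)
```

Tree 0 is then updated against `r0`, which is what remains of the response once trees 1 and 2 have been subtracted.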
$$\underset{\sim}{R_j} = \{R_{ij}, \dots, j = 1, \dots, J\}$$
then
$$\begin{aligned}
\underset{\sim}{R_j} &\sim N(\mu_j, \tau^{-1}), \\
\mu_{jpl} &\sim N(\mu, k_1 \tau^{-1}/P), \\
\mu_{pl} &\sim N(0, k_2 \tau^{-1}/P),
\end{aligned}$$
where $P$ = number of trees, $p$ = tree index, $j$ = group index, $l$ = terminal node index,
with $l = 1, \dots, n_p$, where $n_p$ is the number of terminal nodes in tree $p$, and $\sum_{p=1}^{P} n_p = N_p$.
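One way to see the role of the $1/P$ scaling is to simulate from the prior. The sketch below assumes, for simplicity, a single terminal node per tree and illustrative values for $k_1$, $k_2$ and $\tau$; the function name is hypothetical. Because the $P$ node means are independent with variance $k_2 \tau^{-1}/P$, their sum has variance $k_2 \tau^{-1}$ whatever the number of trees.

```python
import random


def draw_node_means(P, J, k1, k2, tau, rng):
    """Draw mu_p ~ N(0, k2/(tau P)) and mu_jp ~ N(mu_p, k1/(tau P)).

    Simplifying assumption: one terminal node per tree.
    """
    sd_node = (k2 / (tau * P)) ** 0.5   # random.gauss expects a standard deviation
    sd_group = (k1 / (tau * P)) ** 0.5
    mu = [rng.gauss(0.0, sd_node) for _ in range(P)]
    mu_j = [[rng.gauss(mu[p], sd_group) for _ in range(J)] for p in range(P)]
    return mu, mu_j


rng = random.Random(1)
P, J, k1, k2, tau = 10, 3, 2.0, 0.5, 1.0

# Monte Carlo check: the sum over trees of the node means has variance k2/tau.
sums = [sum(draw_node_means(P, J, k1, k2, tau, rng)[0]) for _ in range(20000)]
var_of_sum = sum(s * s for s in sums) / len(sums)
```

The estimate `var_of_sum` should be close to `k2 / tau`, up to Monte Carlo error.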
Using the same marginalisation as for a single tree:
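$$\underset{\sim}{R_j} \sim MVN(\mu \mathbf{1}, \tau^{-1}(k_1 M M^T + I)),$$
where $M$ is the group model matrix. Using the same trick as before and $\Psi = k_1 M M^T + I$:
$$\underset{\sim}{R_j} \sim MVN(\mathbf{0}, \tau^{-1}(\Psi + k_2 \mathbf{1}\mathbf{1}^T)),$$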
which is used to get the marginal distribution of a new tree. The new posterior updates will be:
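$$\mu \mid \dots \sim N\left(\frac{\mathbf{1}^T \Psi^{-1} R}{\mathbf{1}^T \Psi^{-1} \mathbf{1} + (k_2/P)^{-1}},\; \tau^{-1}\left(\mathbf{1}^T \Psi^{-1} \mathbf{1} + (k_2/P)^{-1}\right)^{-1}\right),$$
$$\mu_j \mid \dots \sim MVN\left(\frac{P\mu/k_1 + \bar{R}_j n_j}{n_j + P/k_1},\; \tau^{-1}\left(n_j + P/k_1\right)^{-1}\right)$$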
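A minimal sketch of these two updates for a single terminal node follows. It rests on two assumptions that are not stated in the text: $M$ is a one-hot group indicator matrix, so $\Psi = k_1 M M^T + I$ is block diagonal and the Sherman-Morrison identity gives $\mathbf{1}^T \Psi^{-1} \mathbf{1} = \sum_j n_j/(1 + k_1 n_j)$ and $\mathbf{1}^T \Psi^{-1} R = \sum_j S_j/(1 + k_1 n_j)$, with $S_j$ the sum of residuals in group $j$; and the bracketed terms $\mathbf{1}^T \Psi^{-1} \mathbf{1} + P/k_2$ and $n_j + P/k_1$ are treated as posterior precisions (in units of $\tau$), as in the usual conjugate normal result. Function names are illustrative.

```python
import random


def mu_posterior(group_resids, k1, k2, P, tau):
    """Posterior mean and variance of mu, assuming a one-hot group matrix M."""
    a = sum(len(r) / (1.0 + k1 * len(r)) for r in group_resids)  # 1' Psi^-1 1
    b = sum(sum(r) / (1.0 + k1 * len(r)) for r in group_resids)  # 1' Psi^-1 R
    prec = a + P / k2
    return b / prec, 1.0 / (tau * prec)


def mu_j_posterior(resid_j, mu, k1, P, tau):
    """Posterior mean and variance of the group-level parameter mu_j."""
    n_j = len(resid_j)
    prec = n_j + P / k1
    # sum(resid_j) equals R-bar_j * n_j in the text's notation.
    return (P * mu / k1 + sum(resid_j)) / prec, 1.0 / (tau * prec)


# Illustrative partial residuals for two groups in one terminal node.
group_resids = [[1.0, 2.0, 3.0], [4.0]]
k1, k2, P, tau = 1.0, 2.0, 1, 1.0

m_mu, v_mu = mu_posterior(group_resids, k1, k2, P, tau)
m_j = [mu_j_posterior(r, m_mu, k1, P, tau) for r in group_resids]

# One Gibbs draw of mu, then of each mu_j given that draw.
rng = random.Random(0)
mu_draw = rng.gauss(m_mu, v_mu ** 0.5)
mu_j_draws = [
    rng.gauss(*[f(x) for f, x in zip((float, lambda v: v ** 0.5),
                                     mu_j_posterior(r, mu_draw, k1, P, tau))])
    for r in group_resids
]
```

The group with a single observation is shrunk more strongly towards $\mu$ and has the larger posterior variance, which is the intended partial-pooling behaviour.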
The update for $\tau$ will be a little different. Let $\hat{f}_{ij}$ be the overall prediction for observation $R_{ij}$ at the current iteration, which is the sum of group parameters for the corresponding observation. Then:
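$$\begin{aligned}
\pi(\tau \mid \dots) &\propto \left[\prod_{i=1}^{N} \pi(y_i \mid \tau)\right] \times \left[\prod_{j,l,p} \pi(\mu_{j,l,p} \mid \tau)\right] \times \left[\prod_{l,p} \pi(\mu_{l,p} \mid \tau)\right] \times \pi(\tau) \\
&\propto \left[\tau^{N/2} \exp\left\{-\tau \frac{\sum_{i=1}^{N}(y_i - \hat{f}_i)^2}{2}\right\}\right] \times \left[\left(\frac{\tau P}{k_1}\right)^{(J N_p)/2} \exp\left\{-\left(\frac{\tau P}{k_1}\right) \frac{\sum_{j,l,p}(\mu_{j,l,p} - \mu_{l,p})^2}{2}\right\}\right] \\
&\quad \times \left[\left(\frac{\tau P}{k_2}\right)^{N_p/2} \exp\left\{-\left(\frac{\tau P}{k_2}\right) \frac{\sum_{l,p} \mu_{l,p}^2}{2}\right\}\right] \times \tau^{\alpha - 1} \exp\{-\tau \beta\}
\end{aligned}$$
$$\tau \mid \dots \sim \text{Ga}\left(\frac{N + J N_p + N_p}{2} + \alpha,\; \frac{\sum_{i=1}^{N}(y_i - \hat{f}_i)^2}{2} + \frac{P \sum_{j,l,p}(\mu_{j,l,p} - \mu_{l,p})^2}{2 k_1} + \frac{P \sum_{l,p} \mu_{l,p}^2}{2 k_2} + \beta\right)$$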
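The $\tau$ update can be sketched as below, reading $\text{Ga}(a, b)$ as shape $a$ and rate $b$, which matches the prior kernel $\tau^{\alpha-1}\exp\{-\tau\beta\}$. The input layout is an assumption made for illustration: `resid` holds $y_i - \hat{f}_i$, `group_devs` holds the $J N_p$ differences $\mu_{j,l,p} - \mu_{l,p}$, and `node_means` holds the $N_p$ values $\mu_{l,p}$. Note that Python's `random.gammavariate` takes a shape and a scale, so the rate is inverted before drawing.

```python
import random


def tau_posterior_params(resid, group_devs, node_means, P, k1, k2, alpha, beta):
    """Shape and rate of the Gamma full conditional for tau."""
    n_terms = len(resid) + len(group_devs) + len(node_means)  # N + J*N_p + N_p
    shape = n_terms / 2.0 + alpha
    rate = (
        sum(r * r for r in resid) / 2.0
        + P * sum(d * d for d in group_devs) / (2.0 * k1)
        + P * sum(m * m for m in node_means) / (2.0 * k2)
        + beta
    )
    return shape, rate


# Illustrative values: N = 2 observations, J = 2 groups, N_p = 1 terminal node.
shape, rate = tau_posterior_params(
    resid=[0.5, -0.5],
    group_devs=[1.0, -1.0],
    node_means=[2.0],
    P=2, k1=4.0, k2=8.0, alpha=1.0, beta=1.0,
)

rng = random.Random(0)
tau_draw = rng.gammavariate(shape, 1.0 / rate)  # second argument is the scale
```

With these inputs the full conditional is $\text{Ga}(3.5, 2.25)$, so draws of $\tau$ are centred around $3.5 / 2.25 \approx 1.56$.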