# Set up for const tau example
set.seed(1)

tau_truth <- 1/2
n <- 1000

# randomized treatment
W <- rbinom(n = n, size = 1, prob = 1/2)

# continuous covariate
X <- rnorm(n = n)

# outcome
Y <- tau_truth * W + abs(X) + rnorm(n)

data <- data.frame(Y = Y, X = X, W = W)
Causal random forests with grf
MELODEM data workshop
Overview
Robinson’s residual-on-residual regression
Causal trees
Causal random forest
Causal random survival forest
Inference with causal random forests
Conditional average treatment effect (CATE)
Best linear projection (BLP)
Rank weighted average treatment effect (RATE)
Robinson’s residual-on-residual regression
The partially linear model
Suppose
\[ Y_i = \tau W_i + f(X_i) + \varepsilon_i \]
Assume:
- \(E[\varepsilon_i | X_i, W_i] = 0\),
- the untreated outcome is given by an unknown function \(f\),
- treatment assignment shifts the outcome by \(\tau\).
How to estimate \(\tau\)?
Suppose
\[ Y_i = \tau W_i + f(X_i) + \varepsilon_i \]
How do we estimate \(\tau\) when we do not know \(f(X_i)\)?
Define:
\[\begin{align*} e(x) &= E[W_i | X_i=x] \,\, \text{(Propensity score)} \\ m(x) &= E[Y_i | X_i = x] = f(x) + \tau e(x) \,\,\,\,\, \text{(Conditional mean of } Y\text{)} \end{align*}\]
Use propensity and conditional mean
Re-express the partial linear model in terms of \(e(x)\) and \(m(x)\):
\[\begin{align*} Y_i &= \tau W_i + f(X_i) + \varepsilon_i, \\ Y_i - \tau e(X_i) &= \tau W_i + f(X_i) - \tau e(X_i) + \varepsilon_i, \\ Y_i - f(X_i) - \tau e(X_i) &= \tau W_i - \tau e(X_i) + \varepsilon_i, \\ Y_i - m(X_i) &= \tau (W_i - e(X_i)) + \varepsilon_i. \end{align*}\]
\(\tau\) can be estimated with residual-on-residual regression (Robinson 1988).
How? Plug in flexible estimates of \(m(x)\) and \(e(x)\)
Re-write as a linear model
More formally,
\[ \hat{\tau} := \text{lm}\Biggl( Y_i - \hat m^{(-i)}(X_i) \sim W_i - \hat e^{(-i)}(X_i)\Biggr). \]
The superscript \((-i)\) denotes cross-fit estimates (Chernozhukov et al. 2018).
Cross-fitting: estimate something, e.g., \(e(x)\), using cross-validation-style sample splits, so each observation's estimate comes from a model that never saw it.
Why? It removes bias from over-fitting.
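A minimal sketch of 2-fold cross-fitting for \(m(x)\), reusing the simulated data from the setup code (the lm learner here is for illustration only; in practice you would use something flexible):

# split observations into two folds
folds <- sample(rep(1:2, length.out = n))
m_x_crossfit <- numeric(n)

for (k in 1:2) {
  # fit on the other fold, predict on the held-out fold
  fit_k <- lm(Y ~ X, data = data[folds != k, ])
  m_x_crossfit[folds == k] <- predict(fit_k, newdata = data[folds == k, ])
}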
Example
Suppose \(Y_i = \tau W_i + f(X_i) + \epsilon_i\)
- \(\tau\) is 1/2
- \(W_i\) is a randomized treatment
- \(X_i\) is a continuous covariate
- \(f(X_i) = |X_i|\)
By definition, \(E[W_i] = 1/2\) (why?).
Example done wrong
First we’ll do it the wrong way.
Fit a classical model to estimate conditional mean of \(Y\).
Compute residuals and run Robinson’s regression.
What’d we do wrong?
library(glue)
library(ggplot2)

fit_cmean <- lm(Y ~ X, data = data)

m_x <- predict(fit_cmean, newdata = data)

resid_y <- Y - m_x
resid_w <- W - 1/2

tau_fit <- lm(resid_y ~ resid_w)
glue("True tau is {tau_truth}, \\
estimated tau is {coef(tau_fit)[2]}")
True tau is 0.5, estimated tau is 0.462441735905258
Conditional mean predictions…
Example done wrong, take 2
The model for conditional mean was under-specified.
Fit a flexible model to estimate conditional mean of \(Y\).
Compute residuals and run Robinson’s regression.
What’d we do wrong?
library(aorsf)

fit_cmean <- orsf(Y ~ X, data = data)

m_x <- predict(fit_cmean, new_data = data)

resid_y <- Y - m_x
resid_w <- W - 1/2

tau_fit <- lm(resid_y ~ resid_w)
glue("True tau is {tau_truth}, \\
estimated tau is {coef(tau_fit)[2]}")
True tau is 0.5, estimated tau is 0.394643899391371
Example done right
We forgot about cross-fitting!
Fit a flexible model to estimate conditional mean of \(Y\).
Use out-of-bag predictions.
Compute residuals and run Robinson’s regression.
m_x_oobag <- predict(fit_cmean, oobag = TRUE)

resid_y <- Y - m_x_oobag
resid_w <- W - 1/2

tau_fit <- lm(resid_y ~ resid_w)
glue("True tau is {tau_truth}, \\
estimated tau is {coef(tau_fit)[2]}")
True tau is 0.5, estimated tau is 0.495509155402123
Conditional mean predictions
Intuition of Robinson’s regression
\(Y - E[Y|X]\) is the outcome, centered at its conditional mean.
In a randomized trial, \(W - E[W|X]\) is just the usual binary treatment indicator, centered.
Intuition of Robinson’s regression
In observational data, \(W - E[W|X]\) acts like a continuous exposure for the treatment.
Centered \(W\) determines how much a person with covariates \(X\) influences the \(\tau\) estimate.
Causal trees
How to grow causal trees
Causal trees are much like standard decision trees, but they maximize
\[n_L \cdot n_R \cdot (\hat{\tau}_L-\hat{\tau}_R)^2\]
where residual-on-residual regression is used to estimate \(\hat{\tau}_L\) and \(\hat{\tau}_R\)
grf estimates \(\hat \tau\) once in the parent node and uses “influence functions” to approximate how \(\hat\tau\) would change if an observation moved from one child node to the other (Wager and Athey 2018). Predictions from leaves are \(E[Y|W=1] - E[Y|W=0]\).
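As a minimal sketch, the criterion is simple to compute once each candidate child node has a \(\hat\tau\) (the numbers below are made up):

# split criterion for a candidate split of a parent node
split_criterion <- function(n_L, n_R, tau_L, tau_R) {
  n_L * n_R * (tau_L - tau_R)^2
}

# e.g., children with 100 and 120 observations and different tau estimates
split_criterion(n_L = 100, n_R = 120, tau_L = 0.8, tau_R = 0.1)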
How to grow causal trees
Causal trees use “honesty” and “subsampling” (Wager and Athey 2018).
Honesty: Each training observation is used for one of the following:
Estimate the treatment effect for leaf nodes.
Decide splitting values for non-leaf nodes.
Subsampling: While Breiman (2001)’s random forest uses bootstrap sampling with replacement, the causal random forest samples without replacement.
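A minimal sketch of how these two ideas divide up one tree's training rows (the sizes here are made up, not grf's defaults):

n_total <- 1000

# subsampling: each tree sees a sample drawn without replacement
rows_tree <- sample(n_total, size = n_total / 2, replace = FALSE)

# honesty: disjoint roles within the subsample
rows_split <- sample(rows_tree, size = length(rows_tree) / 2)
rows_estimate <- setdiff(rows_tree, rows_split)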
Intuition of causal trees
- Everyone starts at the root
Intuition of causal trees
Assess splits by running Robinson’s regression in each child node.
Remember: good splits maximize \[n_L \cdot n_R \cdot (\hat{\tau}_L-\hat{\tau}_R)^2\]
Is this a good split?
Causal random forest
Back to the partial linear model
Relaxing the assumption of a constant treatment:
\[ Y_i = \color{red}{\tau(X_i)} W_i + f(X_i) + \varepsilon_i, \, \]
where \(\color{red}{\tau(X_i)}\) is the conditional average treatment effect (CATE). If we had a neighborhood \(\mathcal{N}(x)\) where \(\tau\) was constant, then we could do residual-on-residual regression in the neighborhood:
\[ \hat\tau_i(x) := lm\Biggl( Y_i - \hat m^{(-i)}(X_i) \sim W_i - \hat e^{(-i)}(X_i), \color{red}{w = 1\{X_i \in \mathcal{N}(x) \}}\Biggr), \]
Random forest adaptive neighborhoods
Suppose we fit a random forest with \(B\) trees to a training set of size \(n\), and we compute a prediction \(p\) for a new observation \(x\):
\[\begin{equation*} p = \sum_{i=1}^{n} \frac{1}{B} \sum_{b=1}^{B} Y_i \frac{1\{X_i \in L_b(x)\}}{|L_b(x)|} \end{equation*}\]
\(L_b(x)\) indicates the leaf node that \(x\) falls into for tree \(b\)
The inner sum is the mean of outcomes in the same leaf as \(x\)
This generalizes to causal random forests (it’s easier to write with regression trees).
Random forest adaptive neighborhoods
Pull \(Y_i\) out of the sum that depends on \(b\):
\[\begin{equation*} \begin{split} p &= \sum_{i=1}^{n} \frac{1}{B} \sum_{b=1}^{B} Y_i \frac{1\{X_i \in L_b(x)\}}{|L_b(x)|} \\ &= \sum_{i=1}^{n} Y_i \sum_{b=1}^{B} \frac{1\{X_i \in L_b(x)\}}{B \cdot |L_b(x)|} \\ &= \sum_{i=1}^{n} Y_i \color{blue}{\alpha_i(x)} \end{split} \end{equation*}\]
- \(\alpha_i(x) \propto\) no. of times observation \(i\) lands in the same leaf as \(x\)
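A minimal sketch of computing \(\alpha_i(x)\) from leaf memberships; the leaf-id objects below are made-up stand-ins for what a fitted forest records:

set.seed(2)
n_train <- 10
B <- 5

# rows are training observations, columns are trees
leaf_train <- matrix(sample(1:3, n_train * B, replace = TRUE), n_train, B)
# let the new point x land in the same leaves as training observation 1
leaf_x <- leaf_train[1, ]

alpha <- rowMeans(
  sapply(seq_len(B), function(b) {
    same_leaf <- leaf_train[, b] == leaf_x[b]
    same_leaf / sum(same_leaf) # within-tree weight: 1{same leaf} / |leaf|
  })
)

sum(alpha) # the weights sum to 1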
Intuition of forest weights
Start with a “new” data point: Bill
Initialize leaf counts and no. of trees as 0.
Intuition of forest weights
- Drop Bill down the first tree
Intuition of forest weights
- Drop Bill down the first tree
- Update counts
Intuition of forest weights
- Drop Bill down the second tree
- Update counts
Intuition of forest weights
- Compute weights
- Note some are 0
Intuition of forest weights
- These are Bill’s “neighbors”
- Some are more neighbor than others
Intuition of forest weights
- Now run Robinson’s regression with these weights
- Bill gets his own \(\hat\tau\) from this.
Plug weights into lm
Instead of defining neighborhood boundaries, weight by similarity:
\[ \hat\tau_i(x) := \text{lm}\Biggl( Y_i - \hat m^{(-i)}(X_i) \sim W_i - \hat e^{(-i)}(X_i), w = \color{blue}{\alpha_i(x)} \Biggr). \] This forest-localized version of Robinson’s regression, paired with honesty and subsampling, gives asymptotic guarantees for estimation and inference (Wager and Athey 2018):
Pointwise consistency for the true treatment effect.
Asymptotically Gaussian and centered sampling distribution.
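As a minimal sketch of the mechanics (not grf's internal code), the localized regression is just lm with weights; here we reuse the residuals from the constant-\(\tau\) example and random stand-in weights for \(\alpha_i(x)\):

alpha_x <- runif(n) # stand-in similarity weights for one target point x
tau_x_fit <- lm(resid_y ~ resid_w, weights = alpha_x)
coef(tau_x_fit)[2] # localized tau estimate at x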
Three main components
The procedure to estimate \(\hat\tau_i\) has three pieces:
\[ \hat\tau_i(x) := \text{lm}\Biggl( Y_i - \color{green}{\hat m^{(-i)}(X_i)} \sim W_i - \color{red}{\hat e^{(-i)}(X_i)}, w = \color{blue}{\alpha_i(x)} \Biggr). \]
\(\color{green}{\hat m^{(-i)}(X_i)}\) is a flexible, cross-fit estimate for \(E[Y|X]\)
\(\color{red}{\hat e^{(-i)}(X_i)}\) is a flexible, cross-fit estimate for \(E[W|X]\)
\(\color{blue}{\alpha_i(x)}\) are the similarity weights from a causal random forest
Causal random survival forest
Set up
Assume the survival setting:
\[\begin{equation} Y_i = \begin{cases} T_i & \text{if } \, T_i \leq C_i \\ C_i & \text{otherwise} \end{cases} \end{equation}\]
where \(T_i\) is the time to event and \(C_i\) is the time to censoring. Define
\[\begin{equation} D_i = \begin{cases} 1 & \text{if } \, T_i \leq C_i \\ 0 & \text{otherwise.} \end{cases} \end{equation}\]
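A minimal sketch of this observed-data construction with simulated times (the rates are made up):

set.seed(3)
time_event <- rexp(n = 100, rate = 0.10) # true event times
time_censor <- rexp(n = 100, rate = 0.05) # censoring times

Y_obs <- pmin(time_event, time_censor)
D_obs <- as.numeric(time_event <= time_censor)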
Observed time versus true time
Event times are obscured by:
- censoring
- end of follow-up, i.e., \(h\)
Estimate restricted mean survival time (RMST): \(E \left[ \text{min}(T, h) \right]\). See Cui et al. (2023) for more details on adjustment for censoring.
Treatment effects for survival
Two treatment effects can be estimated for a given horizon \(h\) (a minimal sketch follows the definitions):
- RMST: \[\tau(x) = E[\min(T(1), h) - \min(T(0), h) \, | X = x],\]
- Survival probability: \[\tau(x) = P[T(1) > h \, | X = x] - P[T(0) > h \, | X = x],\]
where \(T(1)\) and \(T(0)\) are the treated and untreated event times, respectively.
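A minimal sketch contrasting the two estimands with simulated potential-outcome times (the distributions are made up):

set.seed(4)
h <- 3
T1 <- rexp(n = 10000, rate = 0.15) # treated event times
T0 <- rexp(n = 10000, rate = 0.25) # untreated event times

tau_rmst <- mean(pmin(T1, h)) - mean(pmin(T0, h)) # RMST difference
tau_surv <- mean(T1 > h) - mean(T0 > h) # survival probability difference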
Inference with causal random forests
Summaries of CATEs
You could compute the average treatment effect (ATE) as the mean of the CATEs:
\[\hat\tau = \frac{1}{n}\sum_{i=1}^n \hat\tau(X_i)\] But the augmented inverse probability weighted (AIPW) ATE is better:
\[ \hat \tau_{AIPW} = \frac{1}{n} \sum_{i=1}^{n}\left( \overbrace{\tau(X_i)}^{\text{Initial estimate}} + \overbrace{\frac{W_i - e(X_i)}{e(X_i)[1 - e(X_i)]}}^{\text{debiasing weight}} \cdot \overbrace{\left(Y_i - \mu(X_i, W_i)\right)}^{\text{residual}} \right) \]
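A minimal sketch of this estimator, reusing objects from the constant-\(\tau\) example; the nuisance estimates below are crude placeholders for grf's cross-fit versions:

e_hat <- rep(1/2, n) # known propensity in a randomized trial
tau_hat <- rep(coef(tau_fit)[2], n) # constant initial CATE estimate
mu_hat <- m_x_oobag + tau_hat * (W - e_hat) # E[Y | X, W] = m(X) + tau * (W - e(X))

gamma_i <- tau_hat + (W - e_hat) / (e_hat * (1 - e_hat)) * (Y - mu_hat)
mean(gamma_i) # AIPW estimate of the ATE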
Summaries of CATEs contd.
For simplicity, re-write the AIPW ATE as
\[\hat\tau = \frac{1}{n}\sum_{i=1}^n \hat\Gamma_i.\] With a vector of these \(\hat\Gamma_i\)'s, define:
- Average treatment effect (ATE) = mean(gamma)
- Best linear projection (BLP) = lm(gamma ~ X)
Estimating the ATE
First, we’ll prepare the data:
# loads packages and R functions
targets::tar_load_globals()

# loads a specific target
tar_load(data_melodem)
Second, coerce data to grf format:
# helper function for grf data prep
data_grf <- data_coerce_grf(data_melodem$values)
# just a view of the X matrix
head(data_grf$X, n = 2)
statin aspirin egfr sub_ckd age sub_cvd
[1,] -0.8994367 -1.0374288 0.8063765 -0.5983451 -1.4119699 -0.4897294
[2,] 1.1116507 0.9637867 0.4961700 -0.5983451 -0.8576657 -0.4897294
CHR GLUR HDL TRR UMALCR BMI
[1,] 0.4445567 0.07162673 -0.04571026 -0.6775012 -0.1185727 0.3706294
[2,] -0.2114133 -0.67907349 -0.67458295 1.0027111 -0.1936573 -1.0884217
sbp dbp fr_risk10yrs orth_hypo education
[1,] -0.40504337 0.58627338 -0.7134939 -0.2760562 -0.07626949
[2,] 0.05247806 0.05612061 -0.3414115 -0.2760562 -0.07626949
frailty_catg_Frail frailty_catg_Pre.frail frailty_catg_unknown
[1,] 0 1 0
[2,] 0 1 0
race4_HISPANIC race4_OTHER race4_WHITE race4_unknown sex_female
[1,] 0 0 0 0 1
[2,] 0 0 1 0 1
sex_unknown
[1,] 0
[2,] 0
Fitting the forest
Third, fit the causal survival forest:
fit_grf <- causal_survival_forest(
  X = data_grf$X, # covariates
  Y = data_grf$Y, # time to event
  W = data_grf$W, # treatment status
  D = data_grf$D, # event status
  horizon = 3,    # 3-year horizon
  # treatment effect will be
  # measured in terms of the
  # restricted mean survival time
  target = 'RMST'
)
fit_grf
GRF forest object of type causal_survival_forest
Number of trees: 2000
Number of training samples: 7159
Variable importance:
1 2 3 4 5 6 7 8 9 10 11 12 13
0.003 0.002 0.064 0.000 0.117 0.012 0.067 0.060 0.047 0.033 0.154 0.117 0.027
14 15 16 17 18 19 20 21 22 23 24 25 26
0.083 0.112 0.006 0.050 0.021 0.001 0.000 0.020 0.000 0.005 0.000 0.001 0.000
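As an aside, out-of-bag CATE estimates for the training data can be pulled straight from the fitted forest:

# OOB estimates of tau(X_i), one per training observation
tau_hat_oob <- predict(fit_grf)$predictions
head(tau_hat_oob)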
Get \(\Gamma\) scores
Remember the \(\Gamma_i\)’s that provide conditional estimates of \(\tau\)? Let’s get them.
# pull the augmented CATEs from the fitted grf object
gammas <- get_scores(fit_grf)
- With gammas, we can compute the ATE manually:
mean(gammas)
[1] -0.001035183
- Verify this is what the grf function gives:
average_treatment_effect(fit_grf)
estimate std.err
-0.001035183 0.002502688
Your turn
Open classwork/04-causal_forests.qmd and complete Exercise 1.
Reminder: Pink sticky note for help, blue sticky when you finish.
Note: the data_coerce_grf() function can save you lots of time.
Estimating the BLP
The BLP (Semenova and Chernozhukov 2021):
- Is estimated by regressing \(\Gamma\) on a set of covariates.
- Can be estimated for a subset of covariates (example below).
- Can be estimated for a subset of observations.
- Summarizes heterogeneous treatment effects conditional on covariates.
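For instance, a covariate-subset BLP is one line; the two columns here come from the X matrix shown earlier, and the choice is arbitrary:

# BLP using only age and sex
best_linear_projection(fit_grf, data_grf$X[, c("age", "sex_female")])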
You can estimate BLP manually:
data_blp <- bind_cols(gamma = gammas, data_grf$X)

fit_blp <- lm(gamma ~ ., data = data_blp)
What happens underneath the grf hood
Here's how you can replicate grf results:
lmtest::coeftest(fit_blp,
                 vcov = sandwich::vcovCL,
                 type = 'HC3')
t test of coefficients:
Estimate Std. Error t value Pr(>|t|)
(Intercept) 0.00439932 0.00632648 0.6954 0.4868
statin -0.00213497 0.00287177 -0.7434 0.4572
aspirin 0.00314203 0.00281955 1.1144 0.2652
egfr 0.00040709 0.00330155 0.1233 0.9019
sub_ckd 0.00292551 0.00369910 0.7909 0.4290
age 0.00393674 0.00477607 0.8243 0.4098
sub_cvd -0.00423127 0.00310080 -1.3646 0.1724
CHR 0.00312453 0.00412521 0.7574 0.4488
GLUR -0.00380898 0.00272365 -1.3985 0.1620
HDL 0.00101039 0.00350525 0.2882 0.7732
TRR 0.00173615 0.00296643 0.5853 0.5584
UMALCR -0.00324024 0.00245118 -1.3219 0.1862
BMI -0.00205123 0.00257303 -0.7972 0.4254
sbp 0.00015949 0.00378842 0.0421 0.9664
dbp 0.00324218 0.00410429 0.7899 0.4296
fr_risk10yrs -0.00351058 0.00654447 -0.5364 0.5917
orth_hypo 0.00090846 0.00265606 0.3420 0.7323
education -0.00474135 0.00331396 -1.4307 0.1526
frailty_catg_Frail -0.00949440 0.00874185 -1.0861 0.2775
frailty_catg_Pre.frail -0.00305295 0.00501740 -0.6085 0.5429
race4_HISPANIC -0.00548466 0.01196162 -0.4585 0.6466
race4_OTHER 0.02118089 0.01405513 1.5070 0.1319
race4_WHITE 0.00701133 0.00652185 1.0751 0.2824
sex_female -0.01388631 0.00873421 -1.5899 0.1119
best_linear_projection(fit_grf, data_grf$X)
Best linear projection of the conditional average treatment effect.
Confidence intervals are cluster- and heteroskedasticity-robust (HC3):
Estimate Std. Error t value Pr(>|t|)
(Intercept) 0.00439932 0.00632648 0.6954 0.4868
statin -0.00213497 0.00287177 -0.7434 0.4572
aspirin 0.00314203 0.00281955 1.1144 0.2652
egfr 0.00040709 0.00330155 0.1233 0.9019
sub_ckd 0.00292551 0.00369910 0.7909 0.4290
age 0.00393674 0.00477607 0.8243 0.4098
sub_cvd -0.00423127 0.00310080 -1.3646 0.1724
CHR 0.00312453 0.00412521 0.7574 0.4488
GLUR -0.00380898 0.00272365 -1.3985 0.1620
HDL 0.00101039 0.00350525 0.2882 0.7732
TRR 0.00173615 0.00296643 0.5853 0.5584
UMALCR -0.00324024 0.00245118 -1.3219 0.1862
BMI -0.00205123 0.00257303 -0.7972 0.4254
sbp 0.00015949 0.00378842 0.0421 0.9664
dbp 0.00324218 0.00410429 0.7899 0.4296
fr_risk10yrs -0.00351058 0.00654447 -0.5364 0.5917
orth_hypo 0.00090846 0.00265606 0.3420 0.7323
education -0.00474135 0.00331396 -1.4307 0.1526
frailty_catg_Frail -0.00949440 0.00874185 -1.0861 0.2775
frailty_catg_Pre.frail -0.00305295 0.00501740 -0.6085 0.5429
race4_HISPANIC -0.00548466 0.01196162 -0.4585 0.6466
race4_OTHER 0.02118089 0.01405513 1.5070 0.1319
race4_WHITE 0.00701133 0.00652185 1.0751 0.2824
sex_female -0.01388631 0.00873421 -1.5899 0.1119
Your turn
Complete Exercise 2
Rank-Weighted Average Treatment Effect
While ATE and BLP are helpful, they do not tell us:
- How good is a treatment prioritization rule at distinguishing sub-populations with different conditional treatment effects?
- Is there any heterogeneity present in a conditional treatment effect?
The rank-weighted average treatment effect (RATE) answers both of these.
Treatment prioritization rules
Suppose we have a treatment that benefits some (not all) adults. Who should initiate treatment? A treatment prioritization rule can help:
- high score for those likely to benefit from treatment.
- low score for those likely to have a small/negative benefit from treatment.
Risk prediction models can be a treatment prioritization rule:
- Initiate antihypertensive medication if predicted risk for cardiovascular disease is high.
How to evaluate treatment prioritization
The basic idea:
- Chop the population up into subgroups based on the prioritization rule, e.g., by decile of score.
- Estimate the ATE in each group, separately, and compare to the overall estimated ATE from treating everyone.
- Plot the difference between the group-specific ATE and the overall ATE for each of the groups.
Example: the Targeting Operator Characteristic (TOC)
Targeting Operator Characteristic (TOC)
Create groups by including the top \(q^{\text{th}}\) fraction of individuals with the largest prioritization score.
Use many values of \(q\) to make the pattern more curve-like.
Motivation: Receiver Operating Characteristic (ROC) curve, a widely used metric for assessing discrimination of predictions.
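A minimal sketch of TOC(\(q\)) built from doubly robust scores; scores stands in for the \(\Gamma_i\)'s from get_scores() and priority for any prioritization rule:

toc_q <- function(q, scores, priority) {
  # treat the top q-th fraction by priority
  top <- priority >= quantile(priority, 1 - q)
  # group-specific ATE minus the overall ATE
  mean(scores[top]) - mean(scores)
}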
RATE: area underneath TOC
RATE is estimated by taking the area underneath the TOC curve.
\[\textrm{RATE} = \int_0^1 \textrm{TOC}(q) \, dq.\]
As \(\tau(X_i)\) approaches a constant, RATE approaches 0.
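A minimal sketch of that integral as a Riemann sum over a grid of \(q\) values, using the toc_q helper sketched above with made-up inputs:

set.seed(5)
scores_demo <- rnorm(500) # stand-in for get_scores() output
priority_demo <- rnorm(500) # stand-in for a prioritization rule

q_grid <- seq(0.01, 1, by = 0.01)
mean(sapply(q_grid, toc_q, scores = scores_demo, priority = priority_demo))
# near 0: this random priority carries no signal about treatment effects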
RATE of prediction model
Let’s use RATE to see how well risk prediction works as a treatment prioritization rule.
library(aorsf)

fit_orsf <- orsf(time + status ~ .,
                 data = data_melodem$values,
                 na_action = 'impute_meanmode',
                 oobag_pred_horizon = 3)

# important to use oobag!
prd_risk <- as.numeric(fit_orsf$pred_oobag)

prd_rate <- rank_average_treatment_effect(fit_grf,
                                          priorities = prd_risk)
Not good
Estimating RATE from CATE
An intuitive way to assign treatment priority is to use the CATE: \(\hat\tau(X_i)\)
\(\hat\tau(X_i)\) should not be estimated and evaluated using the same data
Use split-sample estimation or cross-fitting (Yadlowsky et al. 2021).
As a preliminary step, we'll split our data into training and testing sets:
train_index <- sample(x = nrow(data_melodem$values),
                      size = nrow(data_melodem$values) / 2)

data_trn <- data_melodem$values[train_index, ]
data_tst <- data_melodem$values[-train_index, ]

data_trn_grf <- data_coerce_grf(data_trn)
data_tst_grf <- data_coerce_grf(data_tst)
Fitting
We fit one forest with training data to estimate CATE and fit another forest with testing data to evaluate the CATE estimates:
fit_trn_grf <- causal_survival_forest(
  X = data_trn_grf$X, Y = data_trn_grf$Y,
  W = data_trn_grf$W, D = data_trn_grf$D,
  horizon = 3, target = 'RMST'
)

fit_tst_grf <- causal_survival_forest(
  X = data_tst_grf$X, Y = data_tst_grf$Y,
  W = data_tst_grf$W, D = data_tst_grf$D,
  horizon = 3, target = 'RMST'
)
Predicting
We use the forest fitted to the training data to estimate CATEs for the testing data. For illustration, we also compute naive CATE estimates by predicting on the training data itself.
# the fitted forest hasn't
# seen the testing data
tau_hat_split <- fit_trn_grf %>%
  predict(data_tst_grf$X) %>%
  getElement("predictions")

# Illustration only (don't do this)
tau_hat_naive <- fit_trn_grf %>%
  predict(data_trn_grf$X) %>%
  getElement("predictions")
Evaluating
We use the forest fitted to the testing data to evaluate the CATE estimates for observations in the testing data:
rate_split <- rank_average_treatment_effect(
  forest = fit_tst_grf,
  priorities = tau_hat_split,
  target = "AUTOC"
)

# Illustration only (don't do this)
rate_naive <- rank_average_treatment_effect(
  forest = fit_trn_grf,
  priorities = tau_hat_naive,
  target = "AUTOC"
)
Correct versus overly optimistic
The problem with the overly optimistic (naive) approach is that it has a very high type 1 error rate.
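For reference, a minimal sketch of the heterogeneity test implied by the RATE output (estimate and std.err are the fields returned by rank_average_treatment_effect):

# z-statistic and two-sided p-value for H0: RATE = 0
z_split <- rate_split$estimate / rate_split$std.err
2 * pnorm(-abs(z_split))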
Your turn
Complete Exercise 4
To the pipeline
- Copy/paste this code into your _targets.R file.
# in _targets.R, you should see this comment:
# real data model targets (to be added as an exercise).
# Paste this code right beneath that.
grf_shareable_zzzz_tar <- tar_target(
  grf_shareable_zzzz,
  grf_summarize(fit_grf_zzzz)
)
- Modify this code, replacing zzzz with the name of your dataset.