
Tune and fit a random-forest DTR policy model with (optional) cross-validation

rfdtr() fits a classification-based dynamic treatment regime (DTR) policy model for a binary treatment: it learns a binary decision rule \(\hat d(X)\) with values in \(\{-1, +1\}\) using either ranger or randomForestSRC. The function performs a grid search over random-forest tuning parameters and, when usecv=TRUE, evaluates candidate policies via K-fold cross-validation using:

- CCR: classification correctness rate (accuracy of the predicted treatment labels),
- OOB: out-of-bag prediction error (model-reported), and
- Score: a user-supplied policy-value-like performance measure computed by my_score.Surv() from observed and matched pseudo-outcomes.

The best tuning parameters are selected according to metric, after which a final model is fit on the full observed dataset obs. The output includes the final fitted model, the estimated treatment rule estA.obs, and the full tuning results.
Usage
rfdtr(
  modeltype = "ranger",
  usecv = TRUE,
  sl.seed = 123,
  obs,
  W,
  gridpar,
  metric = "ccr",
  A.obs,
  Q.obs,
  Q.match,
  score_agg = c("mean", "sum")
)

Arguments
- modeltype
  Character. Random-forest engine: "ranger" or "rfsrc".

- usecv
  Logical. If TRUE, uses 5-fold CV to tune hyperparameters; otherwise tunes on the full data.

- sl.seed
  Integer. Seed for reproducibility. (Note: parts of the function currently hard-code set.seed(123).)

- obs
  data.frame. Training dataset containing a column A and the covariates used for predicting A.

- W
  Numeric vector. Case weights of length nrow(obs).

- gridpar
  data.frame. Grid of tuning parameters. Expected columns include ntree, mtry, and nodesize. Extra columns are ignored by the model fits.

- metric
  Character. Criterion used to choose the "best" tuning row. Supported:
  - "oob": minimize OOB
  - "policyval" / "score" / "policy" / "val": maximize Score
  - otherwise: maximize CCR

- A.obs
  Numeric vector. Observed treatment labels (-1/1) aligned with obs.

- Q.obs
  Numeric vector. Observed pseudo-outcome or value component aligned with obs.

- Q.match
  Numeric vector. Matched/pair pseudo-outcome aligned with obs, used by my_score.Surv().

- score_agg
  Character. Aggregation for fold-level Score: "sum" or "mean".
Value
A list with elements:
- model: the final fitted ranger or randomForestSRC object trained on all of obs.
- estA.obs: numeric vector of predicted treatment decisions (-1/1) for each row of obs.
- tune: data.frame of tuning results for each row of gridpar, including CCR, OOB, and Score.
- best: the selected row of tune corresponding to the chosen metric.
Details
Inputs and data layout.

- obs is a data.frame that must contain a column named A and the covariates used to model A. The function coerces obs$A to a factor with levels c(-1, 1) to ensure a binary classification target.
- W is a numeric vector of case weights aligned to obs (e.g., IPC weights).
- A.obs is the numeric observed treatment indicator (typically -1/1) aligned to obs.
- Q.obs and Q.match are numeric vectors aligned to obs, used by my_score.Surv() to compute the policy score (e.g., the observed pseudo-outcome and the matched/pair pseudo-outcome).
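As a minimal sketch of the expected layout (the data-generating code here is purely hypothetical, not from the package):

```r
## Hypothetical inputs illustrating the expected layout.
set.seed(1)
n <- 200
obs <- data.frame(
  A  = sample(c(-1, 1), n, replace = TRUE),  # binary treatment, coded -1/1
  x1 = rnorm(n),
  x2 = rnorm(n)
)
W       <- rep(1, n)    # case weights (e.g., IPC weights)
A.obs   <- obs$A        # numeric observed treatment labels
Q.obs   <- rnorm(n)     # observed pseudo-outcome
Q.match <- rnorm(n)     # matched/pair pseudo-outcome
## Internally, rfdtr() coerces obs$A to factor(obs$A, levels = c(-1, 1)).
```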
Cross-validation logic.
When usecv=TRUE, the function uses 5-fold stratified folds created by caret::createFolds(obs$A, k=5).
For each candidate parameter set in gridpar, it fits the model on the training folds and evaluates on the
held-out fold:
- CCR: mean(predicted_class == obs$A[test]).
- OOB: model-reported OOB error (for ranger, mod$prediction.error; for rfsrc, the last entry of mod$err.rate[, "all"]).
- Score: my_score.Surv(pred, A.test, Q.test, Q.match.test), where pred is the numeric treatment rule (-1/1).
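The per-fold evaluation can be sketched as follows for a single gridpar row with the ranger engine (assumes caret and ranger are installed, and inputs obs, W, gridpar as documented above; my_score.Surv() is the package's scoring function):

```r
## Sketch of one fold's evaluation for one candidate parameter row.
library(caret)
library(ranger)

folds <- createFolds(obs$A, k = 5)            # stratified fold indices
test  <- folds[[1]]
train <- setdiff(seq_len(nrow(obs)), test)

fit <- ranger(
  dependent.variable.name = "A",
  data          = transform(obs[train, ], A = factor(A, levels = c(-1, 1))),
  num.trees     = gridpar$ntree[1],
  mtry          = gridpar$mtry[1],
  min.node.size = gridpar$nodesize[1],
  case.weights  = W[train],
  num.threads   = 1                           # avoid nested parallelism
)

pred_class <- predict(fit, data = obs[test, ])$predictions
pred <- as.numeric(as.character(pred_class))  # numeric -1/1 treatment rule

ccr <- mean(pred == obs$A[test])              # classification correctness rate
oob <- fit$prediction.error                   # model-reported OOB error
## score <- my_score.Surv(pred, A.obs[test], Q.obs[test], Q.match[test])
```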
Aggregation of score across folds.

The per-fold Score values are aggregated across CV folds using:

- score_agg = "sum": sum of fold scores (ignoring NA)
- score_agg = "mean": mean of fold scores (ignoring NA)

This controls the scale used for model selection when metric targets policy value.
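A sketch of this aggregation step (the variable scores is assumed to hold one Score per fold):

```r
## Aggregate per-fold scores, ignoring NA folds.
agg_score <- switch(score_agg,
  sum  = sum(scores,  na.rm = TRUE),
  mean = mean(scores, na.rm = TRUE)
)
```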
Parameter guards.
To avoid invalid random-forest hyperparameters, if gridpar includes mtry, it is clamped
to [1, p] where p = ncol(obs)-1 (i.e., all columns except A). Duplicate parameter
rows are removed via unique(gridpar).
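The guards described above amount to a clamp and a de-duplication, roughly:

```r
## Sketch of the parameter guards (assumes gridpar may have an mtry column).
p <- ncol(obs) - 1                        # number of covariates (all but A)
if ("mtry" %in% names(gridpar)) {
  gridpar$mtry <- pmin(pmax(gridpar$mtry, 1), p)   # clamp mtry to [1, p]
}
gridpar <- unique(gridpar)                # drop duplicate parameter rows
```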
Parallelization.
The function is written using foreach with %dopar%. It assumes a parallel backend has
already been registered (e.g., via doParallel). Within each fold, num.threads=1 is used
for ranger to avoid nested parallelism.
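A typical call therefore registers a backend first, e.g. with doParallel; the rfdtr() arguments below follow the Usage section, with obs, W, A.obs, Q.obs, and Q.match prepared as described under Details (the grid values are illustrative only):

```r
## Register a parallel backend before calling rfdtr().
library(doParallel)
cl <- makeCluster(2)
registerDoParallel(cl)

fit <- rfdtr(
  modeltype = "ranger",
  usecv     = TRUE,
  obs       = obs,
  W         = W,
  gridpar   = expand.grid(ntree = c(500, 1000), mtry = 1:2, nodesize = c(5, 15)),
  metric    = "ccr",
  A.obs     = A.obs,
  Q.obs     = Q.obs,
  Q.match   = Q.match
)

stopCluster(cl)

fit$best      # selected tuning row
fit$estA.obs  # estimated treatment rule on obs
```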