January 5, 2022 | 27

Implementation of a decision tree analysis

In this second installment of our three-part series on decision tree modelling, we’ll jump right into some analysis using R software. If you aren’t familiar with decision trees, part 1 of our series provides an introduction to decision tree theory including many of the concepts used in our analysis.

A number of R packages exist that provide functions useful for decision tree analysis such as tree, rpart, and C5.0. However, we’ll be employing the tidymodels ensemble of packages.

Steps to building a decision tree model

  1. Inspect the data and do some cleaning
  2. Split the data into training and test sets
  3. Define a model
  4. Fit the model to the training data
  5. Tune some hyperparameters
  6. Select the best hyperparameter values and finalize the model
  7. Test and review the final model
  8. Examine model performance

Software requirements

R and RStudio

We are using R version 4.1.2 with RStudio 1.4.1.

R packages

Note that tidymodels imports a number of packages including dplyr and ggplot2 which will also be used in our analyses.

library(stringr)
library(tidymodels)
library(probably)
library(rpart.plot)
library(doParallel)
library(vip)

Decision tree analysis of stroke prediction data

Step 1. Inspect and clean data

For this analysis we’ll be using a stroke dataset downloadable from Kaggle with data on 5110 patients and 12 variables.

A call to str() will show us the structure of the dataset. The outcome, stroke, is a binary categorical variable and we can see that most of the predictors are also categorical except for age, bmi, and avg_glucose_level which are continuous. However, the column formats are not always consistent with the predictor types. Most notably, bmi, which should be numeric, is a character variable with “N/A” mixed with the values.

stroke <- read.csv("healthcare-dataset-stroke-data.csv")
str(stroke)
## 'data.frame':	5110 obs. of  12 variables:
## $ id               : int  9046 51676 31112 60182 1665 56669 53882 10434 27419 60491 ...
## $ gender           : chr  "Male" "Female" "Male" "Female" ...
## $ age              : num  67 61 80 49 79 81 74 69 59 78 ...
## $ hypertension     : int  0 0 0 0 1 0 1 0 0 0 ...
## $ heart_disease    : int  1 0 1 0 0 0 1 0 0 0 ...
## $ ever_married     : chr  "Yes" "Yes" "Yes" "Yes" ...
## $ work_type        : chr  "Private" "Self-employed" "Private" "Private" ...
## $ Residence_type   : chr  "Urban" "Rural" "Rural" "Urban" ...
## $ avg_glucose_level: num  229 202 106 171 174 ...
## $ bmi              : chr  "36.6" "N/A" "32.5" "34.4" ...
## $ smoking_status   : chr  "formerly smoked" "never smoked" "never smoked" "smokes" ...
## $ stroke           : int  1 1 1 1 1 1 1 1 1 1 ...

A temporary conversion of categorical variables to factor format and a call to summary shows us the levels of these predictors.

select(stroke, 
  gender, 
  hypertension, 
  heart_disease, 
  ever_married, 
  work_type, 
  Residence_type, 
  smoking_status
) %>% 
  mutate(across(everything(), factor)) %>% 
  summary()
##     gender     hypertension heart_disease ever_married         work_type   
##  Female:2994   0:4612       0:4834        No :1757     children     : 687  
##  Male  :2115   1: 498       1: 276        Yes:3353     Govt_job     : 657  
##  Other :   1                                           Never_worked :  22  
##                                                        Private      :2925  
##                                                        Self-employed: 819  
##  Residence_type         smoking_status
##  Rural:2514     formerly smoked: 885  
##  Urban:2596     never smoked   :1892  
##                 smokes         : 789  
##                 Unknown        :1544                                                         

Notice the “Unknown” category in smoking_status. In general, replacing missing values with an additional category is not recommended (PDF, 2.7 MB), as grouping together observations missing on some variable will lead to bias if the missingness is correlated with the outcome or any other predictor.

Roughly 30% of patients in our data set are missing a value for smoking_status. Quickly plotting age by smoking status shows a large spike in proportion of patients with unknown smoking status in the youngest age range.

pa_pal <- colorRampPalette(c("#A30664", "#E6B8D4", "#C6A1D1", "#A085CE", "#9FBBFC"))

n_smoking_levels <- length(unique(stroke$smoking_status))

ggplot(stroke, aes(x = age, colour = smoking_status)) +
  geom_density() + 
  scale_colour_manual(values = pa_pal(n_smoking_levels)) +
  theme_minimal()

A line graph describing the proportion of patients by smoking status and age. The X-axis represents continuous age beginning at zero and the Y-axis represents the proportion of patients. There are four lines; one for each of the smoking status categories 'never smoked', 'formerly smoked', 'smokes' and 'unknown'. Each of the four smoking categories shows a different age distribution with never, former and current smokers increasing in proportion of patients as age increases up to between 40 years and 60 years old, after which the proportions begin to drop. The unknown category has a very different distribution with a large spike between zero and 20 and a lower proportion in the older ages compared to the three other smoking categories.

As many modelling methods drop missing values, the desire to retain such a large number of observations is understandable. Alternatives to missing categories include forms of imputation.

Fortunately, decision trees can handle missing predictor values through the use of surrogate predictors so for the purposes of this analysis, we can convert the “Unknown” category to missing without further adjustment.
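
To see that rpart keeps rows with missing predictor values rather than dropping them, here is a minimal sketch. It uses R's built-in airquality data, which is only a stand-in for the stroke data, and fits rpart directly rather than through tidymodels.

```r
library(rpart)

# Built-in data with missing predictor values (not the stroke data)
data(airquality)
n_missing <- colSums(is.na(airquality[, c("Ozone", "Solar.R")]))

# rpart's default na.action keeps rows that are missing some predictors
# (it only drops rows missing the outcome or all of the predictors)
fit <- rpart(Temp ~ Ozone + Solar.R + Wind, data = airquality)

# Prediction also works for rows with missing predictors: surrogate
# splits (or the majority direction) decide which branch they follow
pred <- predict(fit, newdata = airquality)

n_missing          # missing values per predictor
fit$frame$n[1]     # rows used at the root node
sum(is.na(pred))   # predictions that could not be made
```

If every row is retained, the root node count equals the number of rows in the data and no prediction is missing.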

For consistency in categorical values, we will remove underscores, capitalize just the first word, and convert all categorical predictors and the outcome variable to factors. Without the awkward “Unknown” category, smoking_status now makes sense as an ordered factor. We will also convert bmi to a numeric variable and shorten the name avg_glucose_level to avg_glucose.

stroke_clean <- rename_with(stroke, tolower) %>% 
  rename(avg_glucose = avg_glucose_level) %>% 
  mutate(
    across(c(work_type, smoking_status), ~ str_to_sentence(gsub("_", " ", .x))),
    across(c(gender, ever_married, work_type, residence_type), factor),
    across(c(hypertension, heart_disease), factor, labels = c("No", "Yes")),
    smoking_status = factor(
      x = na_if(smoking_status, "Unknown"),
      levels = c("Never smoked", "Formerly smoked", "Smokes"),
      ordered = TRUE
    ),
    stroke = factor(stroke, levels = c(1, 0), labels = c("Stroke", "No stroke")),
    bmi = as.numeric(ifelse(bmi == "N/A", NA, bmi))
  ) 

Taking a look at our outcome variable, there appears to be a strong class imbalance - with roughly 5% of patients having had a stroke and 95% no stroke - which can lead to poor prediction in the underrepresented class. Advanced methods exist for dealing with this type of problem, for example up- or down-sampling of the training data; however, in this analysis we are focusing on the steps to implementing a decision tree so we will leave this issue for now.

summary(stroke_clean)
##        id           gender          age        hypertension heart_disease
##  Min.   :   67   Female:2994   Min.   : 0.08   No :4612     No :4834     
##  1st Qu.:17741   Male  :2115   1st Qu.:25.00   Yes: 498     Yes: 276     
##  Median :36932   Other :   1   Median :45.00                             
##  Mean   :36518                 Mean   :43.23                             
##  3rd Qu.:54682                 3rd Qu.:61.00                             
##  Max.   :72940                 Max.   :82.00                             
##                                                                          
##  ever_married         work_type    residence_type  avg_glucose     
##  No :1757     Children     : 687   Rural:2514     Min.   : 55.12   
##  Yes:3353     Govt job     : 657   Urban:2596     1st Qu.: 77.25   
##               Never worked :  22                  Median : 91.89   
##               Private      :2925                  Mean   :106.15   
##               Self-employed: 819                  3rd Qu.:114.09   
##                                                   Max.   :271.74   
##                                                                    
##       bmi                smoking_status       stroke    
##  Min.   :10.30   Never smoked   :1892   Stroke   : 249  
##  1st Qu.:23.50   Formerly smoked: 885   No stroke:4861  
##  Median :28.10   Smokes         : 789                   
##  Mean   :28.89   NA's           :1544                   
##  3rd Qu.:33.10                                          
##  Max.   :97.60                                          
##  NA's   :201   

As a final check, we will verify the number of unique IDs. If there are fewer unique IDs than total number of observations, then we would have data with a recurrent outcome. Such data would require careful consideration of regression model type as well as interpretation, since occurrence of a first stroke does not address stroke in people who have already had one, for example. Fortunately, there are 5110 unique IDs, the same as the number of rows in our dataset.

length(unique(stroke$id))
## [1] 5110

Step 2. Split the data into training and test sets

The rsample package contains various functions for splitting and, as you might guess, resampling data. Here initial_split() will split the data, by default allotting three quarters of the data to the training set and one quarter to the testing set.

We can add the strata argument to specify a column whose proportions we would like to keep the same in the two sets. This simple technique can go a little way toward addressing class imbalance, though it is most helpful when the imbalance is not as strong as we have seen with our data.

The training() and testing() functions are used to extract the data corresponding to the observations randomized for training and testing.

# Set a seed for reproducibility
set.seed(123)

# Split the data, using the `strata` argument so that randomized allotment of 
# observations occurs within the levels of the specified column
stroke_split <- initial_split(stroke_clean, strata = stroke)

# Pull out the training and testing data
stroke_train <- training(stroke_split)
stroke_test <- testing(stroke_split)
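
To illustrate what stratification buys us, the following base R sketch samples 75% of the rows within each outcome level. It uses only the outcome counts reported in the summary above (249 strokes, 4861 without) and is an approximation of the idea, not a reproduction of how initial_split() is implemented.

```r
set.seed(123)

# Outcome counts taken from the summary above: 249 strokes, 4861 without
outcome <- factor(
  rep(c("Stroke", "No stroke"), times = c(249, 4861)),
  levels = c("Stroke", "No stroke")
)

# Stratified sampling: draw 75% of the row indices within each outcome level
rows_by_level <- split(seq_along(outcome), outcome)
train_idx <- unlist(lapply(rows_by_level, function(idx) {
  sample(idx, size = round(0.75 * length(idx)))
}))

train <- outcome[train_idx]
test  <- outcome[-train_idx]

# The stroke proportion is almost identical in both sets
round(c(train = mean(train == "Stroke"), test = mean(test == "Stroke")), 3)
```

Without stratification, the small number of strokes means the two proportions could drift apart by chance.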

Step 3. Define a model

With the tidymodels workflow, the process of fitting a model is broken up into a number of steps that can be accomplished in a few different ways depending on the complexity of the model and the amount of pre-processing required for the data.

For the purposes of this learning exercise, we will take the longest route, even though the chosen model will be relatively simple and the data will need little to no pre-processing.

The parsnip package has a dedicated function for each type of supported model. You can explore more models, associated packages, engines, and corresponding arguments. The function for a decision tree model is the aptly named decision_tree().

All parsnip model functions will contain the engine argument; this is the computational engine, i.e. the package, from which the functionality is pulled. Any given model will have a number of engines from which to choose. Other decision_tree() arguments include mode, which can be specified as either classification or regression, and three hyperparameters:

  • cost_complexity, a positive number for the cost assigned to terminal nodes
  • tree_depth, the maximum depth to which the tree can grow
  • min_n, the minimum number of data points required before a node can be split

See An introduction to decision tree theory - Pruning for more information on these hyperparameters.
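
As a rough illustration of what these three hyperparameters control, the sketch below fits a tree directly with rpart. It assumes that parsnip hands cost_complexity, tree_depth, and min_n to rpart.control() as cp, maxdepth, and minsplit; the kyphosis data set that ships with rpart is only a stand-in for the stroke data.

```r
library(rpart)

# Assumed mapping between parsnip arguments and rpart.control():
#   cost_complexity -> cp, tree_depth -> maxdepth, min_n -> minsplit
ctrl <- rpart.control(cp = 0, maxdepth = 20, minsplit = 15)

# kyphosis ships with rpart; it stands in for the stroke data here
fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis,
             method = "class", control = ctrl)

# Node depth can be recovered from rpart's node numbering:
# node k sits at depth floor(log2(k)), with the root (node 1) at depth 0
node_ids <- as.integer(rownames(fit$frame))
max_depth <- max(floor(log2(node_ids)))

# Smallest node that was actually split; it cannot be below minsplit
min_split_size <- min(fit$frame$n[fit$frame$var != "<leaf>"])

c(max_depth = max_depth, min_split_size = min_split_size)
```

Raising cp or minsplit, or lowering maxdepth, can only make the resulting tree smaller or leave it unchanged.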

Specify the model

When working with these modelling functions, the %>% pipe passes the model object as opposed to a data set as is normal with tidyverse functions. Here we can specify the engine and mode using set_engine() and set_mode() to demonstrate this piping behaviour.

We selected the rpart engine because it is the only engine provided by decision_tree() that supports cost complexity pruning. The use of classification mode reflects the binary categorical nature of the outcome, stroke.

To first allow the tree to grow without interference, we will set low barriers - i.e. lower values for cost_complexity and min_n, and higher values for tree_depth.

dc_tree_mod <-
  decision_tree(cost_complexity = 0, tree_depth = 20, min_n = 15) %>%
  set_engine("rpart") %>%
  set_mode("classification")

Create a recipe

A recipe describes a set of feature engineering steps that will be applied to the training data prior to model training and then to the test data prior to prediction (see Table 1 for examples). The recipe object can be piped to the different steps that will sequentially update the recipe. These step_* functions offer convenient pre-defined means of preparing data for modelling.

Step function    Description
step_dummy()     Creates dummy variables
step_logit()     Applies a logit transformation
step_center()    Centres a numeric variable to have a mean of zero

Table 1. Examples of feature engineering steps available with the recipes package.
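
For intuition, the arithmetic behind a few such transformations can be written in base R. This is a sketch of the underlying calculations on made-up numbers, not of how the recipes package implements its steps.

```r
# Made-up values for illustration
x <- c(2, 4, 6, 8)

# Centring: subtract the mean, so the result has a mean of zero
x_centred <- x - mean(x)

# Normalizing: centre, then divide by the standard deviation
x_normalized <- (x - mean(x)) / sd(x)

# Logit transformation of proportions strictly between 0 and 1
p <- c(0.1, 0.5, 0.9)
p_logit <- log(p / (1 - p))

x_centred
round(sd(x_normalized), 10)
round(p_logit, 3)
```

In a recipe, the quantities needed for such steps (for example, the mean) are estimated from the training data and then reused on the test data.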

The outcome and predictor variables can be specified with a formula or by supplying each variable name and its role. Our data still contain the id column which we do not want included as a predictor. We can simply update its role so that the model does not use this variable as a predictor. The two methods given below for defining our recipe are equivalent.

# Create a recipe using a formula and update the role of the `id` column
stroke_recipe <-
  recipe(stroke ~ ., data = stroke_train) %>% 
  update_role(id, new_role = "id var")

# Using `update_role()` to specify which variables are outcomes and which are 
# predictors produces the same recipe as above
stroke_recipe <-
  recipe(stroke_train) %>% 
  update_role(stroke, new_role = "outcome") %>%
  update_role(
    gender, 
    age, 
    hypertension, 
    heart_disease, 
    ever_married, 
    work_type, 
    residence_type, 
    avg_glucose, 
    bmi, 
    smoking_status, 
    new_role = "predictor"
  ) %>% 
  update_role(id, new_role = "id var")

summary(stroke_recipe)
## A tibble: 12 × 4
##    variable          type    role      source  
##    <chr>             <chr>   <chr>     <chr>   
##  1 id                numeric id var    original
##  2 gender            nominal predictor original
##  3 age               numeric predictor original
##  4 hypertension      nominal predictor original
##  5 heart_disease     nominal predictor original
##  6 ever_married      nominal predictor original
##  7 work_type         nominal predictor original
##  8 residence_type    nominal predictor original
##  9 avg_glucose       numeric predictor original
## 10 bmi               numeric predictor original
## 11 smoking_status    nominal predictor original
## 12 stroke            nominal outcome   original

Finally, for convenience, the model and recipe can be bundled into a workflow object so that they are passed together when training or testing.

# Create a workflow object
stroke_wflow <-
  workflow() %>%
  add_model(dc_tree_mod) %>%
  add_recipe(stroke_recipe)

Step 4. Fit the model to the training data

With all that work, we’re now ready to model our data! We simply pass the workflow object to fit() which will train the model using the training data set.

We can visualize the results of the trained model with the rpart.plot package. The main plotting function takes an rpart object and builds a diagram of the full tree. We first need extract_fit_engine() to obtain the engine specific fit object required by rpart.plot(). Here we also use a few extra arguments to improve the look of the plot such as tweak which multiplies the label size.

# Fit the model to the training data
stroke_fit <- 
  stroke_wflow %>% 
  fit(data = stroke_train)

# Visualize the decision tree structure
pa_pal_2tone <- colorRampPalette(c("#DFBAD3", "#E6B8D4", "#BAD8F6", "#9FBBFC"))

stroke_fit %>% 
  extract_fit_engine() %>%
  rpart.plot(
    roundint = FALSE,
    box.palette = pa_pal_2tone(6),
    yes.text = "true",
    no.text = "false",
    tweak = 1.25
  )

Diagram of decision tree describing the unpruned decision tree model for predicting the outcome of a stroke. The tree has a depth of 10 splits and 31 terminal nodes. Stroke was predicted for 11 of the terminal nodes. The most recurrent predictors involved in primary decision splits are age, BMI and average glucose level.

Our tree has grown to a depth of 10 splits with 31 terminal nodes. All nodes are labelled with the most prevalent outcome, the proportion of observations with the “Stroke” outcome, and the percent of all observations in that node. The root and each intermediate node show the predictor and cutpoint used to split the data.

Let’s take a look at the first few splits. The root node decision is \(age \ge 68\). Following the \(false\) right branch, the next node is also split by age with a cutpoint of 56 or greater. The right branch of this split leads to a terminal node and the left leads to a node split by smoking status.

With 30% of smoking status values missing, there is a good chance that some of these observations ended up in this node. From this plot, we cannot tell whether this is true and, if so, which surrogate predictor was used.

To find this information we can pipe the rpart object to summary() instead of rpart.plot(). In this way we can see the detailed decisions made at each node including a ranking of primary and surrogate decisions. The entire output is too long to display here but a chunk is provided.

stroke_fit %>% 
  extract_fit_engine() %>%
  summary()
## Node number 6: 617 observations,    complexity param=0.005405405
##   predicted class=No stroke  expected loss=0.07455429  P(node) =0.1610125
##     class counts:    46   571
##    probabilities: 0.075 0.925 
##   left son=12 (328 obs) right son=13 (289 obs)
##   Primary splits:
##       smoking_status    splits as  RLL,         improve=1.8571830, (117 missing)
##       avg_glucose       < 110.86  to the right, improve=1.6386020, (0 missing)
##       heart_disease     splits as  RL,          improve=1.5747690, (0 missing)
##       bmi               < 35.4    to the right, improve=1.1504340, (29 missing)
##       gender            splits as  RL-,         improve=0.6356169, (0 missing)
##   Surrogate splits:
##       gender            splits as  RL-,         agree=0.564, adj=0.035, (117 split)
##       age               < 57.5    to the right, agree=0.562, adj=0.031, (0 split)
##       avg_glucose       < 80.26   to the right, agree=0.556, adj=0.018, (0 split)

At node 6, we can see a primary split on smoking_status with 117 observations missing a value. The top ranked surrogate split is gender with the levels of “Female”, “Male”, and “Other” assigned right, left and neither, respectively. “Other” is assigned to neither branch because no observations with this level are present in this node.
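
The agree and adj columns in this output are related. Following the description in the rpart vignette, the adjusted agreement is the share of the errors made by a simple "send everything with the majority" rule that the surrogate corrects. The sketch below reproduces the printed adj values; the baseline of 0.548 is an assumption back-solved from the output, since summary() does not print it.

```r
# Adjusted agreement as described in the rpart vignette:
# the share of the majority rule's errors that the surrogate corrects
adjusted_agreement <- function(agree, baseline) {
  (agree - baseline) / (1 - baseline)
}

# `agree` values are copied from the node 6 output above.
# `baseline` (0.548) is an assumption back-solved from that output.
agree <- c(gender = 0.564, age = 0.562, glucose = 0.556)
adj <- round(adjusted_agreement(agree, baseline = 0.548), 3)
adj
```

An adj close to zero, as seen here, means the surrogate does little better than simply sending the observations with missing values down the more common branch.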

Interestingly, the root node’s right branch accounts for 83% of training observations and produces only five of the 31 terminal nodes, one of which predicts occurrence of a stroke. The left root node branch produces a much more sprawling sub-tree and uses more varied predictors in its subsequent splits. Given the size of the tree and the initial values we selected for our hyperparameters, we should suspect overfitting.

Step 5. Tune some hyperparameters

We will prune the model by tuning the hyperparameters to which we previously assigned values that permitted a larger tree. Higher cost complexity will make it more expensive for the model to retain additional branches in the tree and a smaller maximum tree depth will stop the tree from growing too deep.

We could also tune the minimum number of observations per node to a larger value; however, to have a reasonable computing time, we will stick with tuning just two hyperparameters for this analysis. Together these hyperparameters can help us avoid overfitting while still maximizing performance.

Create a tuning model

We will assign tune(), which is a placeholder function for argument values to be tuned, to the cost_complexity and tree_depth hyperparameters, and let min_n take on the rpart default value of 20.

Here we also show that the engine and mode arguments can be equivalently specified within the decision tree model as opposed to using the set_* functions as we did before.

We then create a new workflow object by updating the model in the workflow specified above. The recipe will not change.

# Create a new model and add tuning hyperparameters
dc_tree_mod_tune <-
  decision_tree(
    cost_complexity = tune(),
    tree_depth = tune(),
    engine = "rpart",
    mode = "classification"
  )

# Create a new workflow by updating the model in the original workflow
stroke_wflow_tune <-
  stroke_wflow %>% 
  update_model(dc_tree_mod_tune)

Prepare tuning values and cross-validation folds

We will perform ten-fold cross-validation tuning of our model to determine the set of hyperparameter values expected to result in the best performance on the test data.

First we create a data set of tuning values to test in our cross-validation using grid_regular(). We could have manually specified a data frame with pairs of values for cost complexity and tree depth but the advantage of grid_regular() is that it automatically generates an appropriate range of values for each specified hyperparameter.

The levels argument determines the number of values to be tested. If a single integer, levels will apply to all specified hyperparameters. Otherwise, a vector the same length as the number of hyperparameters is required. We will look at 5 values for both cost complexity and tree depth, thus 25 models in total.
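
For comparison, a grid like this can be approximated by hand in base R. The ranges used here (cost complexity from 1e-10 to 0.1 on a log scale, tree depth from 1 to 15) are assumptions read off the tuning results shown later in this post, not values taken from the dials documentation.

```r
# Five log-spaced cost complexity values and five tree depths.
# The ranges are assumptions based on the tuning results in this post.
cost_complexity_vals <- 10^seq(-10, -1, length.out = 5)
tree_depth_vals <- as.integer(seq(1, 15, length.out = 5))

# Every combination of the two hyperparameters
manual_grid <- expand.grid(
  cost_complexity = cost_complexity_vals,
  tree_depth = tree_depth_vals
)

signif(cost_complexity_vals, 3)
tree_depth_vals
nrow(manual_grid)  # 5 x 5 candidate models
```

Spacing cost complexity on a log scale spreads the candidates over several orders of magnitude instead of clustering them near the upper end of the range.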

The vfold_cv() function will produce the “folds” of data for our cross-validation. The v in the function name refers to the number of folds, the default for which is 10. As with initial_split() we can specify the strata argument so that random sampling is stratified by some variable.

# Create a data frame of tuning values to model
dc_tree_grid <- grid_regular(
  cost_complexity(),
  tree_depth(),
  levels = 5
)

# Set a seed and create cross validation folds
set.seed(234)
stroke_folds <- vfold_cv(stroke_train, strata = stroke)

Tune the hyperparameters using cross-validation

This next part is a lot to unpack, but we can do it!

To run the cross-validation tuning, we pass our tuning workflow object to tune_grid() which will compute performance metrics for each set of tuning hyperparameters by training the model on \(V - 1\) analysis folds and testing on the remaining assessment fold. Each fold acts as the assessment fold once, and mean performance metrics across the assessments are returned.

This translates to a lot of computations: 250 model fits in our case (25 candidate models, each fit on 10 sets of analysis folds). Consequently, we might need more computing power than a standard sequential R process can provide (sequential meaning commands are executed in sequence).
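
To make the mechanics concrete, here is a minimal base R sketch of V-fold cross-validation for a single candidate model. It fits rpart directly on the kyphosis data that ships with rpart (a stand-in for our training data), uses unstratified folds, and reports accuracy only; the hyperparameter values are arbitrary.

```r
library(rpart)

set.seed(234)
v <- 10

# kyphosis ships with rpart; it stands in for the training data here
dat <- kyphosis

# Randomly assign each row to one of v folds of near-equal size
fold_id <- sample(rep(seq_len(v), length.out = nrow(dat)))

# For each fold: train on the other v - 1 folds, assess on the held-out fold
fold_accuracy <- vapply(seq_len(v), function(k) {
  analysis   <- dat[fold_id != k, ]
  assessment <- dat[fold_id == k, ]
  fit <- rpart(Kyphosis ~ Age + Number + Start, data = analysis,
               method = "class",
               control = rpart.control(cp = 0.01, maxdepth = 4))
  pred <- predict(fit, newdata = assessment, type = "class")
  mean(pred == assessment$Kyphosis)
}, numeric(1))

# One candidate model costs v fits; a 25-row grid costs 25 * v fits
mean(fold_accuracy)
n_fits <- 25 * v
n_fits
```

Because each fold's fit is independent of the others, the loop is a natural candidate for parallel processing.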

tune_grid() can use the foreach package for parallel processing. However, we first have to make additional R processes available.

After determining the number of physical cores¹ available with detectCores(), we use makePSOCKcluster() to create additional R processes identical to the main R process. To make these additional R processes available for use by the foreach package functions, we register them to the backend using registerDoParallel().

# Number of physical cores available
n_core <- detectCores(logical = FALSE)

# Create a cluster object and register as a backend
cl <- makePSOCKcluster(n_core - 1)
registerDoParallel(cl)
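
One caveat with the code above: detectCores() can return NA when the number of cores cannot be determined, and on a single-core machine n_core - 1 would be zero. A slightly more defensive way to choose the number of workers might look like the sketch below; n_workers is a name introduced here for illustration.

```r
library(parallel)

# detectCores() can return NA when the core count cannot be determined
n_core <- detectCores(logical = FALSE)

# Fall back to a single worker if the count is unknown, and never ask
# for fewer than one worker on single-core machines
n_workers <- if (is.na(n_core)) 1L else max(1L, n_core - 1L)
n_workers
```

The result could then be passed to makePSOCKcluster() in place of n_core - 1.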

We can now run our cross-validation. This could take seconds to minutes depending on the computation and number of cores.

# Run the cross-validation tuning
stroke_cv <-
  stroke_wflow_tune %>%
  tune_grid(
    resamples = stroke_folds,
    grid = dc_tree_grid
  )

Once complete, stopCluster() tells the additional R processes to shut down; however, they are still registered. It is good practice to un-register the additional processes by calling registerDoSEQ() which will reset to the sequential backend.

Many thanks to Precision Analytics' Senior Software Developer, Hugo Barnaby, for his invaluable clarifications on parallel processing in R!

# Stop the cluster and reset with empty sequential backend
stopCluster(cl)
registerDoSEQ()

Step 6. Select the best hyperparameter values and finalize the model

Review the cross-validation tuning results

The stroke_cv data frame contains a list column, .metrics, that stores the tuning results. We can see these results using the function collect_metrics().

stroke_cv %>% 
  collect_metrics()
## # A tibble: 50 × 8
##    cost_complexity tree_depth .metric  .estimator  mean     n std_err .config              
##              <dbl>      <int> <chr>    <chr>      <dbl> <int>   <dbl> <chr>                
##  1    0.0000000001          1 accuracy binary     0.952    10 0.00306 Preprocessor1_Model01
##  2    0.0000000001          1 roc_auc  binary     0.5      10 0       Preprocessor1_Model01
##  3    0.0000000178          1 accuracy binary     0.952    10 0.00306 Preprocessor1_Model02
##  4    0.0000000178          1 roc_auc  binary     0.5      10 0       Preprocessor1_Model02
##  5    0.00000316            1 accuracy binary     0.952    10 0.00306 Preprocessor1_Model03
##  6    0.00000316            1 roc_auc  binary     0.5      10 0       Preprocessor1_Model03
##  7    0.000562              1 accuracy binary     0.952    10 0.00306 Preprocessor1_Model04
##  8    0.000562              1 roc_auc  binary     0.5      10 0       Preprocessor1_Model04
##  9    0.1                   1 accuracy binary     0.952    10 0.00306 Preprocessor1_Model05
## 10    0.1                   1 roc_auc  binary     0.5      10 0       Preprocessor1_Model05
## # … with 40 more rows

Trying to pull out the best set of values for cost complexity and tree depth according to both accuracy and ROC AUC (i.e. the area under the curve of the Receiver Operating Characteristic plot) is a bit difficult with just numbers, numbers, numbers to look at. So let’s get visual!

stroke_cv %>% 
  collect_metrics() %>%
  mutate(tree_depth = factor(tree_depth)) %>%
  ggplot(aes(x = cost_complexity, y = mean, colour = tree_depth)) +
  geom_line(size = 1.3, alpha = 0.6) +
  geom_point(size = 1.5) +
  facet_wrap(~ .metric, nrow = 1) +
  scale_x_log10(
    breaks = unique(dc_tree_grid$cost_complexity),
    labels = formatC(unique(dc_tree_grid$cost_complexity), format = "e", digits = 2)
  ) +
  scale_color_manual(values = pa_pal(5)) +
  theme_minimal() + 
  theme(legend.position = "bottom", panel.spacing = unit(1, "lines"))

A line graph with two panels describing mean performance metrics for hyperparameter tuning pairs of tree depth and cost complexity values. The left panel is titled 'accuracy' and the right panel is titled 'roc_auc'. The X-axes on both panels represent cost complexity with values ranging from 1 x 10^-10 to 1 x 10^-1. The Y-axis is shared between the panels and represents a mean value; mean accuracy in the left panel and mean ROC AUC in the right panel. There are five lines in each panel; one for each tree depth value of 1, 4, 8, 11, and 15. In the left panel, the five lines overlap with a constant mean accuracy of around 0.95 with no variation by cost complexity. In the right panel, lines for tree depth 8, 11, and 15 overlap with a constant mean ROC AUC around 0.76 up to cost complexity 5.6 x 10^-4, after which it drops to 0.5. Tree depth 4 has the same pattern with a lower mean ROC AUC of 0.7 dropping to 0.5. Tree depth of 1 has a constant mean ROC AUC of 0.5 across cost complexity values.

Two functions, show_best() and select_best(), display or select the “best” tuning values based entirely on an estimated metric such as accuracy or AUC. We, however, want to choose a set of hyperparameters with a smaller tree depth and a higher cost complexity that still maintains good performance.

Looking at the mean accuracy in the above plot shows no variation at all and will therefore be unhelpful in our decision making. Accuracy is a metric at risk of bias in the presence of a strong class imbalance so we will have to be careful how we interpret such a consistently high value.
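
A quick calculation shows why such a high accuracy is not reassuring. Using the outcome counts from the cleaned data summary, a "model" that predicts "No stroke" for every patient already reaches about 95% accuracy while detecting no strokes at all.

```r
# Outcome counts from the cleaned data summary
n_stroke    <- 249
n_no_stroke <- 4861

# A "model" that predicts "No stroke" for every patient
baseline_accuracy    <- n_no_stroke / (n_stroke + n_no_stroke)
baseline_sensitivity <- 0 / n_stroke              # no stroke is ever detected
baseline_specificity <- n_no_stroke / n_no_stroke # every non-stroke is correct

baseline <- round(c(
  accuracy = baseline_accuracy,
  sensitivity = baseline_sensitivity,
  specificity = baseline_specificity
), 3)
baseline
```

Any accuracy close to this baseline therefore tells us little about how well strokes are being predicted.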

We can see that mean AUC is maximized at tree depths of 8 and greater and for cost complexity values at or below 5.6 x 10^-4. If we View() all the results from the above collect_metrics(), the 27th row gives us the simplest of these candidates: a tree depth of 8 with a cost complexity of 5.6 x 10^-4. We can save this row and use it to update the hyperparameters in our workflow with the chosen values.

Finalize the model

Instead of using update_model() as we did last time, we will use finalize_workflow(). This function is designed to take a tibble or list of tuning values for hyperparameters and splice them into the workflow object.

# Select best model
best_stroke_tune_val <- 
  stroke_cv %>% 
  collect_metrics() %>% 
  slice(27)

# Splice in chosen hyperparameter tuning values
final_stroke_wflow <- 
  stroke_wflow_tune %>% 
  finalize_workflow(best_stroke_tune_val)
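
Selecting a row by position works, but it depends on the row order of the results. The same choice can be expressed as a rule: among the candidates that tie for the best mean AUC, prefer the shallowest tree and then the largest cost complexity. The base R sketch below applies that rule to an illustrative table whose AUC values are approximated from the plot description above; they are not the exact tuning results.

```r
# Illustrative mean ROC AUC values, approximated from the plot description
# above (0.76, 0.70, 0.50); they are not the exact tuning results
cc <- 10^seq(-10, -1, length.out = 5)
results <- expand.grid(cost_complexity = cc, tree_depth = c(1, 4, 8, 11, 15))
results$mean_auc <- with(results, ifelse(
  tree_depth == 1 | cost_complexity > 1e-3, 0.50,
  ifelse(tree_depth == 4, 0.70, 0.76)
))

# Keep candidates whose AUC ties the maximum, then prefer the shallowest
# tree and, within that, the largest cost complexity
best <- results[results$mean_auc == max(results$mean_auc), ]
best <- best[order(best$tree_depth, -best$cost_complexity), ][1, ]
best
```

Under these assumed values the rule lands on the same pair as the 27th row: a tree depth of 8 with a cost complexity of about 5.6 x 10^-4.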

Step 7. Test and review the final model

Fit the final model on the training data and evaluate on the test data

A handy function, last_fit(), both fits the final model on the training data and evaluates its performance on the test data in one shot.

We can use the metrics argument to specify which performance metrics we wish to see. Since accuracy was suspiciously high and constant during the tuning process, this time let’s select the AUC, sensitivity, and specificity.
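
As a reminder of what sensitivity and specificity measure, here is a base R sketch computed from a confusion matrix. The counts are hypothetical and made up purely for illustration. "Stroke" is the first factor level, which is also the level that yardstick treats as the event by default.

```r
# Hypothetical counts, made up for illustration only
truth <- factor(
  rep(c("Stroke", "No stroke"), times = c(60, 1200)),
  levels = c("Stroke", "No stroke")
)
predicted <- factor(
  c(rep(c("Stroke", "No stroke"), times = c(15, 45)),     # true strokes
    rep(c("Stroke", "No stroke"), times = c(100, 1100))), # true non-strokes
  levels = c("Stroke", "No stroke")
)

conf_mat <- table(predicted = predicted, truth = truth)

# Sensitivity: share of true strokes that were predicted as strokes
sensitivity <- conf_mat["Stroke", "Stroke"] / sum(conf_mat[, "Stroke"])

# Specificity: share of true non-strokes that were predicted as non-strokes
specificity <- conf_mat["No stroke", "No stroke"] / sum(conf_mat[, "No stroke"])

conf_mat
round(c(sensitivity = sensitivity, specificity = specificity), 3)
```

Unlike accuracy, these two metrics are each computed within one outcome class, so a large majority class cannot mask poor detection of the minority class.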

# Fit the final model and predict the test data set
final_stroke_fit <- 
  final_stroke_wflow %>%
  last_fit(stroke_split, metrics = metric_set(roc_auc, sens, spec))

Review the final model

Let’s take a look at the pruned tree. It now has a total of 21 terminal nodes with depth of 8 decision splits. Our pruning removed two layers of branches and ten terminal nodes from the tree.

# Visualize the final decision tree structure
final_stroke_fit %>%
  extract_fit_engine() %>%
  rpart.plot(
    roundint = FALSE,
    box.palette = pa_pal_2tone(6),
    yes.text = "true",
    no.text = "false",
    tweak = 1.25
  )

Diagram of decision tree describing the final pruned decision tree model for predicting the outcome of a stroke. The tree has a depth of 8 splits and 21 terminal nodes. Stroke was predicted for 7 of the terminal nodes. The most recurrent predictors involved in primary decision splits are age, BMI and average glucose level.

Looking at the decision splits, we can try to determine which predictors were the most important in the model. age and bmi appear most frequently in decision splits, but how do we quantify their contribution to the overall model? And how do we evaluate surrogate decision splits that aren’t displayed in the tree diagram?

The function vi() from the vip package can calculate the variable importance for each predictor. From section 3.4 of the rpart vignette (PDF, 286 KB), the variable importance of an rpart model object is based on the sum of 1) the goodness of fit at each split in which a predictor was the primary variable and 2) the goodness of fit multiplied by the adjusted agreement for splits in which it served as a surrogate variable.

vi() can return either a value (default) or a rank (set rank = TRUE) for variable importance. A variable importance value is meant to be interpreted as the relative importance of variables in the model. Setting scale = TRUE will set the value of the most important variable to 100, and scale the importance values of the other variables relative to the most important.
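
The scaling and ranking described here amount to simple arithmetic. The sketch below uses hypothetical raw importance scores, made up for illustration, to show the idea in base R.

```r
# Hypothetical raw importance scores, made up for illustration only
raw_importance <- c(age = 80, bmi = 40, gender = 0.5)

# Rescale so that the most important variable scores 100
scaled_importance <- 100 * raw_importance / max(raw_importance)
scaled_importance

# Ranks, with 1 for the most important variable
importance_rank <- rank(-raw_importance)
importance_rank
```

Scaling changes the units but not the ordering, so the ranks are the same before and after.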

final_stroke_fit %>%
  extract_fit_engine() %>%
  vi(scale = TRUE)
## # A tibble: 9 × 2
##   Variable       Importance
##   <chr>               <dbl>
## 1 age               100    
## 2 bmi                49.7  
## 3 avg_glucose        29.6  
## 4 work_type          20.5  
## 5 heart_disease      13.6  
## 6 smoking_status      6.70 
## 7 ever_married        4.49 
## 8 hypertension        3.48 
## 9 gender              0.649

We can see that age is the most important predictor in the model, which was fairly intuitive from the final tree diagram. However, now we know that bmi, as the second most important predictor, contributed roughly half as much as age to the model’s overall goodness of fit. gender, the lowest ranked contributor, appears to have had a very weak contribution to the model. And, interestingly, hypertension and smoking_status, which are both established risk factors for stroke, also had a fairly small contribution to the model relative to age.
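
To make the scale = TRUE option more concrete, here is a minimal sketch using a toy tree fit on the kyphosis data that ships with rpart (a stand-in, not our stroke model; the object names are ours). It pulls the raw importance values that rpart stores on the fitted object and rescales them by hand, which we would expect to mirror what vi(scale = TRUE) reports.

```r
library(rpart)

# Toy classification tree on the kyphosis data bundled with rpart
# (a stand-in dataset, not the stroke data used in this post)
toy_fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis)

# Raw variable importance: a named numeric vector stored on the rpart object
raw_importance <- toy_fit$variable.importance

# Rescale so that the most important variable gets a value of 100
scaled_importance <- raw_importance / max(raw_importance) * 100
round(scaled_importance, 1)
```

Because the values are only rescaled, the ranking of the predictors is unchanged; the scaled version is simply easier to read as "percent of the top predictor".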

A convenient function, vip(), produces a plot of variable importance directly from the model fit object that can allow us to quickly visualize the relative importance. However, it is also simple (and fun!) to recreate this plot with one data transformation and a few lines of ggplot2 code. We’ve added some colour and a theme change to our plot to keep our visualizations consistent.

final_stroke_fit %>%
  extract_fit_engine() %>%
  vi(scale = TRUE, decreasing = FALSE) %>% 
  mutate(Variable = factor(Variable, levels = Variable)) %>% 
  ggplot(aes(x = Importance, y = Variable)) + 
  geom_col(fill = pa_pal(1)) +
  theme_minimal() +
  theme(axis.title.y = element_blank())
A bar graph describing the variable importance of predictors in the final stroke prediction decision tree model. The X-axis represents relative importance with values ranging from zero to 100. The Y-axis represents the predictors in the final decision tree model, listed in decreasing order of importance: age, BMI, average glucose level, work type, heart disease, smoking status, ever married, hypertension and gender. Age is the most important predictor so its importance value is scaled to 100; BMI has an importance value close to 50, and average glucose level has an importance close to 30. The next most important predictor has an importance around 20 and the following predictors have decreasing importance values down to less than one for gender.

Step 8. Examine model performance

Build a Receiver Operating Characteristic curve

To build an ROC curve, we can use collect_predictions() to extract the predicted probabilities, predicted class as well as the true outcome value from our last fit object. We can pass the truth (our outcome variable) and the predicted probability of having a stroke to roc_curve() and then to ggplot() to visualize the ROC curve.

The ROC curve shows us the trade-off between sensitivity and specificity at different predicted probability thresholds. Our curve displays a fairly typical shape and sits above the dotted line that represents chance, which indicates that our model performed better than chance at predicting the outcome.
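
For intuition about what each point on an ROC curve means, here is a small base R sketch that computes sensitivity and specificity by hand at a few thresholds. The outcomes, probabilities, and the helper sens_spec() are all made up for illustration, and we assume a patient is classified as a stroke when the predicted probability is at or above the threshold.

```r
# Toy data (hypothetical): 1 = stroke, 0 = no stroke
truth <- c(1, 1, 1, 0, 0, 0, 0, 0)
prob  <- c(0.60, 0.30, 0.05, 0.40, 0.10, 0.04, 0.02, 0.01)

# Sensitivity and specificity when predicting "stroke" for prob >= threshold
sens_spec <- function(threshold) {
  predicted <- as.integer(prob >= threshold)
  c(
    threshold   = threshold,
    sensitivity = sum(predicted == 1 & truth == 1) / sum(truth == 1),
    specificity = sum(predicted == 0 & truth == 0) / sum(truth == 0)
  )
}

# One row per threshold; plotting sensitivity against 1 - specificity
# for many thresholds traces out the ROC curve
roc_points <- t(sapply(c(0, 0.05, 0.25, 0.5, 1), sens_spec))
roc_points
```

At a threshold of 0 everyone is classified as a stroke (sensitivity 1, specificity 0), and at a threshold of 1 no one is (sensitivity 0, specificity 1); the interesting part of the curve lies in between.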

# View the ROC curve
final_stroke_fit %>%
  collect_predictions() %>% 
  roc_curve(stroke, `.pred_Stroke`) %>% 
  ggplot(aes(x = 1 - specificity, y = sensitivity)) + 
  geom_line(colour = pa_pal(1), size = 1) + 
  geom_abline(linetype = "dotted", size = 0.75) + 
  coord_fixed() +
  theme_minimal()
A line graph representing the ROC (Receiver Operating Characteristic) curve. The X-axis represents 1 minus specificity and the Y-axis represents sensitivity. A dotted line at a 45 degree angle represents an AUC (area under the curve) of 0.5. The ROC curve peaks about midway between the top left corner and the dotted line with a sensitivity of around 0.75 and 1 minus specificity around 0.25.

Extract performance metrics

We can access the final performance metrics using collect_metrics() again. Our final model has an overall AUC of 75.3% but a sensitivity of only 7.8%, and a specificity of 98%. The large class imbalance in our outcome is one explanation for why we see such low sensitivity in this evaluation.

# Look at the performance metrics
final_stroke_fit %>%
  collect_metrics()
## # A tibble: 3 × 4
##   .metric .estimator .estimate .config             
##   <chr>   <chr>          <dbl> <chr>               
## 1 sens    binary        0.0781 Preprocessor1_Model1
## 2 spec    binary        0.980  Preprocessor1_Model1
## 3 roc_auc binary        0.753  Preprocessor1_Model1

Examine performance at different probability thresholds

The reported sensitivity and specificity are calculated at the default predicted probability threshold of 0.5. Since an AUC of 75.3% suggests the model separates the two classes reasonably well, we can use threshold_perf() on the predictions to examine the sensitivity and specificity values at different probability thresholds. The J-index (also known as Youden’s J statistic) is another metric returned by threshold_perf() that can give clues about the trade-off between sensitivity and specificity. It is simply calculated as \(sensitivity + specificity - 1\) and is meant to give equal weight to false positive and false negative predictions, i.e., to overall misclassification.
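
As a quick worked example of the J-index formula, the sketch below computes it for three thresholds and picks the one where it is largest. The sensitivity and specificity values are hypothetical, chosen only to keep the arithmetic easy to follow.

```r
# Hypothetical sensitivity/specificity pairs at three thresholds
perf <- data.frame(
  threshold   = c(0.05, 0.25, 0.50),
  sensitivity = c(0.90, 0.60, 0.10),
  specificity = c(0.60, 0.85, 0.99)
)

# Youden's J: sensitivity + specificity - 1
perf$j_index <- perf$sensitivity + perf$specificity - 1

# Row with the largest J-index
perf[which.max(perf$j_index), ]
```

A J-index of 0 corresponds to a classifier that does no better than chance, and 1 to perfect separation, so larger values indicate a better balance between the two metrics.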

Let’s use another plot to examine the results.

# Find the sensitivity, specificity and J-index values at different predicted
# probability thresholds
stroke_threshold <-
  final_stroke_fit %>%
  collect_predictions() %>%
  threshold_perf(stroke, .pred_Stroke, thresholds = seq(0, 0.5, by = 0.01)) %>%
  filter(.metric != "distance")

ggplot(stroke_threshold) +
  geom_rect(aes(xmin = 0.03, xmax = 0.08, ymin = -0.05, ymax = 1), fill = "grey80") +
  geom_line(aes(x = .threshold, y = .estimate, color = .metric), 
    size = 1, 
    alpha = 0.8
  ) +
  theme_minimal() + 
  scale_colour_manual(values = pa_pal(3))
A line graph describing the trade-off between sensitivity and specificity of the final decision tree model predictions at different probability thresholds. The X-axis represents the probability threshold values. The Y-axis represents the estimated metric value. There are three lines, one for each of the J-index, sensitivity, and specificity. A grey rectangle behind the lines shades the probability threshold values from 0.03 to 0.08 where the highest J-index values are found. The line for sensitivity starts at 1 when the probability threshold is zero and rapidly drops, flattening out below 0.10. The specificity line begins with a value of zero and increases rapidly, plateauing around 0.98. The sensitivity and specificity lines cross between probability thresholds 0.03 and 0.04, within the grey shaded rectangle. The J-index line begins at zero, increases to a maximum of around 0.5 within the shaded rectangle and then rapidly decreases again with a pattern similar to the sensitivity line.

The best trade-off between sensitivity and specificity appears between probability thresholds of 0.03 and 0.08. Extremely low values! While the J-index is technically maximized at probability thresholds of 0.04 and 0.05, it varies very little within the highlighted window. Lowering the threshold from 0.09 to 0.08 produces a big jump in sensitivity, from 40.6% to 70.3%, while maintaining a decent specificity of 79.1%.

stroke_threshold %>% 
  filter(between(.threshold, 0.03, 0.09)) %>% 
  pivot_wider(names_from = .metric, values_from = .estimate)
## # A tibble: 7 × 5
##   .threshold .estimator  sens  spec j_index
##        <dbl> <chr>      <dbl> <dbl>   <dbl>
## 1       0.03 binary     0.781 0.697   0.478
## 2       0.04 binary     0.719 0.782   0.500
## 3       0.05 binary     0.719 0.782   0.500
## 4       0.06 binary     0.703 0.791   0.494
## 5       0.07 binary     0.703 0.791   0.494
## 6       0.08 binary     0.703 0.791   0.494
## 7       0.09 binary     0.406 0.906   0.312

Maximizing the J-index can point to a useful threshold; however, as always, there is room for our own interpretation based on statistical and subject area knowledge. For example, in the context of assessing a patient for potential stroke, we might be willing to tolerate more false positives in an effort to ensure we correctly identify as many stroke events as possible. However, one could easily imagine another context where the consequences of false positives need to be taken into account (for example, if a false positive leads to unnecessary and invasive procedures that carry their own risks). These types of considerations underscore how important it is for data science teams to understand the implications of their decision making and to work closely with key stakeholders during model development.
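
To see how the choice of threshold changes the predictions themselves, here is a minimal base R sketch that converts predicted probabilities into classes at two thresholds and tabulates them against the truth. The probabilities, outcomes, the "No Stroke" level name, and the helper classify() are all hypothetical.

```r
# Hypothetical predicted probabilities of stroke and true outcomes
class_levels <- c("Stroke", "No Stroke")
prob  <- c(0.02, 0.04, 0.07, 0.09, 0.12, 0.30, 0.55)
truth <- factor(
  c("No Stroke", "No Stroke", "Stroke", "No Stroke", "Stroke", "Stroke", "Stroke"),
  levels = class_levels
)

# Predict "Stroke" whenever the probability is at or above the threshold
classify <- function(prob, threshold) {
  factor(ifelse(prob >= threshold, "Stroke", "No Stroke"), levels = class_levels)
}

# Confusion tables at the default threshold and at a much lower one
confusion_default <- table(predicted = classify(prob, 0.50), truth = truth)
confusion_low     <- table(predicted = classify(prob, 0.08), truth = truth)
confusion_default
confusion_low
```

In this toy example, lowering the threshold catches more of the true strokes at the cost of one additional false positive, which is the same kind of trade-off we have to weigh for the real model.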

Now we’ve successfully implemented a decision tree analysis using tidymodels! Even though the results might not have been optimal, we still learned a lot.

Advantages of decision tree models

The final decision tree is straightforward and interpretable with respect to our predictors of interest. We can easily obtain a predicted outcome for a set of predictor values (i.e., patient characteristics) from this figure, without special tools or complex formulas. We also get a clear picture of the predictors used in the model and their role in predicting the outcome.
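
The same holds in code: once a tree is fitted, getting a prediction for a new set of predictor values is a single call. A minimal sketch, using the kyphosis data bundled with rpart as a stand-in for our stroke data and a made-up patient:

```r
library(rpart)

# Toy classification tree (stand-in for the stroke model in this post)
toy_fit <- rpart(Kyphosis ~ Age + Number + Start, data = kyphosis)

# A hypothetical new patient described by the same predictors
new_patient <- data.frame(Age = 60, Number = 3, Start = 10)

# Predicted class and class probabilities from the terminal node
# that this patient falls into
predicted_class <- predict(toy_fit, newdata = new_patient, type = "class")
predicted_prob  <- predict(toy_fit, newdata = new_patient, type = "prob")
predicted_class
predicted_prob
```

The predicted probabilities are just the class proportions observed in the training data for that terminal node, which is part of what makes a single tree so easy to explain.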

Decision tree analysis is especially useful in applications where stakeholders want to know how the model works, and whether associations between predictors and the outcome make sense (e.g., from a clinical or scientific point of view).

In our work, we know that the choice of analysis should depend on our client’s goals rather than strictly on the performance of each approach. While other approaches can achieve better prediction in some contexts, we still consider decision trees indispensable due to their ease of application and interpretation.

Coming up

Decision trees suffer from some disadvantages, namely lower predictive accuracy than other machine learning models and instability to small changes in the training data. In efforts to offset these issues, methods for aggregating many decision trees together have been developed.

In our next adventure into tree modelling we’ll learn about some ensemble tree models, including random forest, as well as the steps for model comparison, all using tidymodels.

Additional references

James, G., Witten, D., Hastie, T., & Tibshirani, R. (2013). An introduction to statistical learning: With applications in R. Springer Texts in Statistics. Springer Science+Business Media New York.

Kuhn, M., & Vaughan, D. (2021). Where does probably fit in? probably. Retrieved October 2021, from https://probably.tidymodels.org/articles/where-to-use.html

Kuhn, M., & Silge, J. (2021, October 21). Tidy modeling with R. Retrieved October 2021, from https://www.tmwr.org/.

Milborrow, S. (2021, June 1). Plotting rpart trees with the rpart.plot package. Retrieved October 2021, from http://www.milbo.org/rpart-plot/prp.pdf.

RStudio. Get started. Tidymodels. Retrieved October 2021, from https://www.tidymodels.org/start/.

Therneau, T., & Atkinson, E. (2018, April 11). An introduction to recursive partitioning using the RPART routines. The Comprehensive R Archive Network. Retrieved October 2021, from https://cran.r-project.org/web//packages/rpart/vignettes/longintro.pdf.


  1. A small side note on the number of cores used for parallelization. Notice that we used logical = FALSE when detecting the number of cores. This argument allows us to detect the number of physical cores as opposed to the number of hardware threads that can run concurrently on the physical cores (which is commonly greater than the number of physical cores on modern CPUs; see the Wikipedia entry on Multithreading for more details).

    When running more processes than cores, the benefits of parallelization are lost because the processes have to share the core to which they are assigned, meaning they will be intermittently paused and resumed.

    Also note that we used one less than the total number of physical cores when creating the additional R processes to allow the remaining core for operations outside of R.

Katie Dunkley-Hickin
