The purpose of this document is to demonstrate how to use dials, together with tune, to manage and tune the parameters of parsnip models. See “Dials, Tune, and Parsnip: Tidymodels’ Way to Create and Tune Model Parameters” for more information.

Import

library(tidymodels)
library(corrr)
library(moments)
data("penguins")
head(penguins)
## # A tibble: 6 x 7
##   species island bill_length_mm bill_depth_mm flipper_length_… body_mass_g sex  
##   <fct>   <fct>           <dbl>         <dbl>            <int>       <int> <fct>
## 1 Adelie  Torge…           39.1          18.7              181        3750 male 
## 2 Adelie  Torge…           39.5          17.4              186        3800 fema…
## 3 Adelie  Torge…           40.3          18                195        3250 fema…
## 4 Adelie  Torge…           NA            NA                 NA          NA <NA> 
## 5 Adelie  Torge…           36.7          19.3              193        3450 fema…
## 6 Adelie  Torge…           39.3          20.6              190        3650 male
penguins %>% glimpse()
## Rows: 344
## Columns: 7
## $ species           <fct> Adelie, Adelie, Adelie, Adelie, Adelie, Adelie, Ade…
## $ island            <fct> Torgersen, Torgersen, Torgersen, Torgersen, Torgers…
## $ bill_length_mm    <dbl> 39.1, 39.5, 40.3, NA, 36.7, 39.3, 38.9, 39.2, 34.1,…
## $ bill_depth_mm     <dbl> 18.7, 17.4, 18.0, NA, 19.3, 20.6, 17.8, 19.6, 18.1,…
## $ flipper_length_mm <int> 181, 186, 195, NA, 193, 190, 181, 195, 193, 190, 18…
## $ body_mass_g       <int> 3750, 3800, 3250, NA, 3450, 3650, 3625, 4675, 3475,…
## $ sex               <fct> male, female, female, NA, female, male, female, mal…
penguins %>% skimr::skim()
Data summary
Name Piped data
Number of rows 344
Number of columns 7
_______________________
Column type frequency:
factor 3
numeric 4
________________________
Group variables None

Variable type: factor

skim_variable n_missing complete_rate ordered n_unique top_counts
species 0 1.00 FALSE 3 Ade: 152, Gen: 124, Chi: 68
island 0 1.00 FALSE 3 Bis: 168, Dre: 124, Tor: 52
sex 11 0.97 FALSE 2 mal: 168, fem: 165

Variable type: numeric

skim_variable n_missing complete_rate mean sd p0 p25 p50 p75 p100 hist
bill_length_mm 2 0.99 43.92 5.46 32.1 39.23 44.45 48.5 59.6 ▃▇▇▆▁
bill_depth_mm 2 0.99 17.15 1.97 13.1 15.60 17.30 18.7 21.5 ▅▅▇▇▂
flipper_length_mm 2 0.99 200.92 14.06 172.0 190.00 197.00 213.0 231.0 ▂▇▃▅▂
body_mass_g 2 0.99 4201.75 801.95 2700.0 3550.00 4050.00 4750.0 6300.0 ▃▇▆▃▂
# Remove the island column, then keep only the rows with no missing values
penguins <-
  penguins %>%
  select(!island) %>%
  drop_na()

Split Data

# Fix the RNG state so the split and the resamples below are reproducible
set.seed(300)

# Put 75% of the rows in the training set and the rest in the test set
split <- penguins %>% initial_split(prop = 0.75)
penguins_train <- split %>% training()
penguins_test <- split %>% testing()

# 5-fold cross-validation, repeated twice, on the training set only
folds_5 <- penguins_train %>% vfold_cv(v = 5, repeats = 2)

Show Random Forest Engine

show_engines("rand_forest")
## # A tibble: 6 x 2
##   engine       mode          
##   <chr>        <chr>         
## 1 ranger       classification
## 2 ranger       regression    
## 3 randomForest classification
## 4 randomForest regression    
## 5 spark        classification
## 6 spark        regression
show_model_info("rand_forest")
## Information for `rand_forest`
##  modes: unknown, classification, regression 
## 
##  engines: 
##    classification: randomForest, ranger, spark
##    regression:     randomForest, ranger, spark
## 
##  arguments: 
##    ranger:       
##       mtry  --> mtry
##       trees --> num.trees
##       min_n --> min.node.size
##    randomForest: 
##       mtry  --> mtry
##       trees --> ntree
##       min_n --> nodesize
##    spark:        
##       mtry  --> feature_subset_strategy
##       trees --> num_trees
##       min_n --> min_instances_per_node
## 
##  fit modules:
##          engine           mode
##          ranger classification
##          ranger     regression
##    randomForest classification
##    randomForest     regression
##           spark classification
##           spark     regression
## 
##  prediction modules:
##              mode       engine                    methods
##    classification randomForest           class, prob, raw
##    classification       ranger class, conf_int, prob, raw
##    classification        spark                class, prob
##        regression randomForest               numeric, raw
##        regression       ranger     conf_int, numeric, raw
##        regression        spark                    numeric

Tune Parameters

1. Use default parameters in parsnip

rf_spec is a random forest model specification created with parsnip. I do not specify values for any of its parameters, so the engine's default values are used. As usual, I then fit the model on the training data; printing the fitted model shows the defaults that were applied (500 trees and 1 variable tried at each split).

# Random forest regression specification. No tuning arguments (mtry, trees,
# min_n) are supplied, so the randomForest engine uses its own defaults.
rf_spec <- rand_forest(mode = "regression") %>% set_engine("randomForest")

# Fit on the training set, modelling body mass from all remaining columns
model_default <- fit(rf_spec, body_mass_g ~ ., data = penguins_train)

# Print the fitted model to see the defaults the engine applied
model_default
## parsnip model object
## 
## Fit time:  146ms 
## 
## Call:
##  randomForest(x = maybe_data_frame(x), y = y) 
##                Type of random forest: regression
##                      Number of trees: 500
## No. of variables tried at each split: 1
## 
##           Mean of squared residuals: 89926.71
##                     % Var explained: 85.51
# Evaluate the default model on the held-out test set: attach the predictions
# to the test rows and compare them with the observed body mass
predict(model_default, penguins_test) %>%
  bind_cols(penguins_test) %>%
  metrics(truth = body_mass_g, estimate = .pred)
## # A tibble: 3 x 3
##   .metric .estimator .estimate
##   <chr>   <chr>          <dbl>
## 1 rmse    standard     317.   
## 2 rsq     standard       0.877
## 3 mae     standard     242.

2. Use tune to tune parsnip model

2.1. Manually provide values

# Flag mtry and trees as parameters to be tuned
rf_spec <- update(rf_spec, mtry = tune(), trees = tune())

# Bundle the variable roles and the model specification into a workflow
rf_workflow <-
  workflow() %>%
  add_variables(outcomes = body_mass_g, predictors = everything()) %>%
  add_model(spec = rf_spec)

# Hand-picked candidates: every combination of 3 mtry and 3 trees values
manual_grid <- expand.grid(
  mtry = c(1, 3, 5),
  trees = c(500, 1000, 2000)
)

# Evaluate every candidate on the cross-validation resamples
set.seed(300)
manual_tune <- tune_grid(rf_workflow, resamples = folds_5, grid = manual_grid)

# Metrics for all candidates, summarised across the resamples
manual_tune %>% collect_metrics()
## # A tibble: 18 x 8
##     mtry trees .metric .estimator    mean     n std_err .config             
##    <dbl> <dbl> <chr>   <chr>        <dbl> <int>   <dbl> <chr>               
##  1     1   500 rmse    standard   306.       10  9.84   Preprocessor1_Model1
##  2     1   500 rsq     standard     0.858    10  0.0146 Preprocessor1_Model1
##  3     3   500 rmse    standard   301.       10 14.3    Preprocessor1_Model2
##  4     3   500 rsq     standard     0.854    10  0.0178 Preprocessor1_Model2
##  5     5   500 rmse    standard   303.       10 14.5    Preprocessor1_Model3
##  6     5   500 rsq     standard     0.852    10  0.0180 Preprocessor1_Model3
##  7     1  1000 rmse    standard   305.       10  9.82   Preprocessor1_Model4
##  8     1  1000 rsq     standard     0.859    10  0.0143 Preprocessor1_Model4
##  9     3  1000 rmse    standard   300.       10 14.5    Preprocessor1_Model5
## 10     3  1000 rsq     standard     0.854    10  0.0180 Preprocessor1_Model5
## 11     5  1000 rmse    standard   304.       10 14.5    Preprocessor1_Model6
## 12     5  1000 rsq     standard     0.851    10  0.0180 Preprocessor1_Model6
## 13     1  2000 rmse    standard   306.       10 10.1    Preprocessor1_Model7
## 14     1  2000 rsq     standard     0.858    10  0.0144 Preprocessor1_Model7
## 15     3  2000 rmse    standard   300.       10 14.5    Preprocessor1_Model8
## 16     3  2000 rsq     standard     0.854    10  0.0179 Preprocessor1_Model8
## 17     5  2000 rmse    standard   304.       10 14.7    Preprocessor1_Model9
## 18     5  2000 rsq     standard     0.851    10  0.0181 Preprocessor1_Model9

# Show the best candidate. The ranking metric is named explicitly rather than
# left to show_best()'s implicit choice (the first metric, rmse), which tune
# warns about.
show_best(manual_tune, metric = "rmse", n = 1)
## # A tibble: 1 x 8
##    mtry trees .metric .estimator  mean     n std_err .config             
##   <dbl> <dbl> <chr>   <chr>      <dbl> <int>   <dbl> <chr>               
## 1     3  2000 rmse    standard    300.    10    14.5 Preprocessor1_Model8

# Finalise: plug the lowest-RMSE parameter values into the workflow and refit
# on the whole training set. The metric is passed explicitly so the selection
# does not depend on select_best()'s implicit default.
manual_final <-
  rf_workflow %>%
  finalize_workflow(select_best(manual_tune, metric = "rmse")) %>%
  fit(data = penguins_train)

# Predict on testing data and score against the observed body mass
manual_final %>%
  predict(penguins_test) %>%
  bind_cols(penguins_test) %>%
  metrics(truth = body_mass_g, estimate = .pred)
## # A tibble: 3 x 3
##   .metric .estimator .estimate
##   <chr>   <chr>          <dbl>
## 1 rmse    standard     296.   
## 2 rsq     standard       0.881
## 3 mae     standard     238.

2.2. Specify grid size for automatic generation

# Let tune generate the grid itself: grid = 5 requests 5 candidate
# combinations of the parameters marked with tune()
set.seed(300)
random_tune <- tune_grid(rf_workflow, resamples = folds_5, grid = 5)

# Metrics for all candidates, summarised across the resamples
random_tune %>% collect_metrics()
## # A tibble: 10 x 8
##     mtry trees .metric .estimator    mean     n std_err .config             
##    <int> <int> <chr>   <chr>        <dbl> <int>   <dbl> <chr>               
##  1     5  1879 rmse    standard   304.       10 14.5    Preprocessor1_Model1
##  2     5  1879 rsq     standard     0.851    10  0.0181 Preprocessor1_Model1
##  3     2   799 rmse    standard   298.       10 13.6    Preprocessor1_Model2
##  4     2   799 rsq     standard     0.857    10  0.0171 Preprocessor1_Model2
##  5     3  1263 rmse    standard   300.       10 14.5    Preprocessor1_Model3
##  6     3  1263 rsq     standard     0.854    10  0.0179 Preprocessor1_Model3
##  7     2   812 rmse    standard   297.       10 13.7    Preprocessor1_Model4
##  8     2   812 rsq     standard     0.858    10  0.0171 Preprocessor1_Model4
##  9     4   193 rmse    standard   302.       10 14.9    Preprocessor1_Model5
## 10     4   193 rsq     standard     0.852    10  0.0182 Preprocessor1_Model5

# Best candidate by RMSE; the metric is named explicitly instead of relying on
# show_best()'s implicit choice of the first metric
show_best(random_tune, metric = "rmse", n = 1)
## # A tibble: 1 x 8
##    mtry trees .metric .estimator  mean     n std_err .config             
##   <int> <int> <chr>   <chr>      <dbl> <int>   <dbl> <chr>               
## 1     2   812 rmse    standard    297.    10    13.7 Preprocessor1_Model4

# Finalise the workflow with the lowest-RMSE candidate and refit on the whole
# training set. The metric is passed explicitly so the selection does not
# depend on select_best()'s implicit default.
random_final <-
  rf_workflow %>%
  finalize_workflow(select_best(random_tune, metric = "rmse")) %>%
  fit(data = penguins_train)

# Score the finalised model on the test set
random_final %>%
  predict(penguins_test) %>%
  bind_cols(penguins_test) %>%
  metrics(truth = body_mass_g, estimate = .pred)
## # A tibble: 3 x 3
##   .metric .estimator .estimate
##   <chr>   <chr>          <dbl>
## 1 rmse    standard     295.   
## 2 rsq     standard       0.883
## 3 mae     standard     233.

3. Create parameter values with dials

mtry()
## # Randomly Selected Predictors (quantitative)
## Range: [1, ?]

mtry() %>% range_get()
## $lower
## [1] 1
## 
## $upper
## unknown()

mtry() %>% range_set(c(1, ncol(penguins_train) - 1))
## # Randomly Selected Predictors (quantitative)
## Range: [1, 5]

mtry(c(1, ncol(penguins_train) - 1))
## # Randomly Selected Predictors (quantitative)
## Range: [1, 5]

cost_complexity()
## Cost-Complexity Parameter (quantitative)
## Transformer:  log-10 
## Range (transformed scale): [-10, -1]

trees()
## # Trees (quantitative)
## Range: [1, 2000]

set.seed(300)
trees() %>% value_seq(n = 4)
## [1]    1  667 1333 2000

trees() %>% value_seq(n = 5)
## [1]    1  500 1000 1500 2000

trees() %>% value_seq(n = 10)
##  [1]    1  223  445  667  889 1111 1333 1555 1777 2000

set.seed(300)
trees() %>% value_sample(n = 4)
## [1]  590  874 1602  985

trees() %>% value_sample(n = 5)
## [1] 1692  789  553 1980 1875

trees() %>% value_sample(n = 10)
##  [1] 1705  272  461  780 1383 1868 1107  812  460  901

# Regular grid: 3 evenly spaced levels per parameter, fully crossed (9 rows).
# mtry's upper bound is the number of columns in the training set minus one
# (the outcome), i.e. the number of predictors.
set.seed(300)
dials_regular <-
  grid_regular(
    mtry(range = c(1, ncol(penguins_train) - 1)),
    trees(),
    levels = 3
  )
dials_regular
## # A tibble: 9 x 2
##    mtry trees
##   <int> <int>
## 1     1     1
## 2     3     1
## 3     5     1
## 4     1  1000
## 5     3  1000
## 6     5  1000
## 7     1  2000
## 8     3  2000
## 9     5  2000

# Random grid: 6 candidate combinations sampled from the parameter ranges.
# The seed makes the sampled values reproducible.
set.seed(300)
dials_random <-
  grid_random(
    mtry(range = c(1, ncol(penguins_train) - 1)),
    trees(),
    size = 6
  )
dials_random
## # A tibble: 6 x 2
##    mtry trees
##   <int> <int>
## 1     2  1980
## 2     2  1875
## 3     1  1705
## 4     4   272
## 5     5   461
## 6     1   780

# Tune over the regular dials grid. The seed is set first, as in the manual
# and automatic-grid tuning runs, so the random forest fits are reproducible
# even when this chunk is run on its own.
set.seed(300)
dials_regular_tune <-
  rf_workflow %>%
  tune_grid(resamples = folds_5, grid = dials_regular)

# Metrics for all candidates, summarised across the resamples
collect_metrics(dials_regular_tune)
## # A tibble: 18 x 8
##     mtry trees .metric .estimator    mean     n std_err .config             
##    <int> <int> <chr>   <chr>        <dbl> <int>   <dbl> <chr>               
##  1     1     1 rmse    standard   410.       10 27.8    Preprocessor1_Model1
##  2     1     1 rsq     standard     0.736    10  0.0300 Preprocessor1_Model1
##  3     3     1 rmse    standard   408.       10 16.9    Preprocessor1_Model2
##  4     3     1 rsq     standard     0.744    10  0.0271 Preprocessor1_Model2
##  5     5     1 rmse    standard   404.       10 19.6    Preprocessor1_Model3
##  6     5     1 rsq     standard     0.764    10  0.0237 Preprocessor1_Model3
##  7     1  1000 rmse    standard   305.       10 10.0    Preprocessor1_Model4
##  8     1  1000 rsq     standard     0.859    10  0.0143 Preprocessor1_Model4
##  9     3  1000 rmse    standard   300.       10 14.4    Preprocessor1_Model5
## 10     3  1000 rsq     standard     0.854    10  0.0178 Preprocessor1_Model5
## 11     5  1000 rmse    standard   304.       10 14.5    Preprocessor1_Model6
## 12     5  1000 rsq     standard     0.851    10  0.0181 Preprocessor1_Model6
## 13     1  2000 rmse    standard   305.       10 10.1    Preprocessor1_Model7
## 14     1  2000 rsq     standard     0.858    10  0.0147 Preprocessor1_Model7
## 15     3  2000 rmse    standard   300.       10 14.5    Preprocessor1_Model8
## 16     3  2000 rsq     standard     0.854    10  0.0180 Preprocessor1_Model8
## 17     5  2000 rmse    standard   304.       10 14.7    Preprocessor1_Model9
## 18     5  2000 rsq     standard     0.851    10  0.0182 Preprocessor1_Model9
# Best candidate by RMSE; the metric is named explicitly instead of relying on
# show_best()'s implicit choice of the first metric
show_best(dials_regular_tune, metric = "rmse", n = 1)
## # A tibble: 1 x 8
##    mtry trees .metric .estimator  mean     n std_err .config             
##   <int> <int> <chr>   <chr>      <dbl> <int>   <dbl> <chr>               
## 1     3  1000 rmse    standard    300.    10    14.4 Preprocessor1_Model5

# Finalise the workflow with the lowest-RMSE candidate from the regular grid
# and refit on the whole training set. The metric is passed explicitly so the
# selection does not depend on select_best()'s implicit default.
dials_regular_final <-
  rf_workflow %>%
  finalize_workflow(select_best(dials_regular_tune, metric = "rmse")) %>%
  fit(data = penguins_train)

# Score the finalised model on the test set
dials_regular_final %>%
  predict(penguins_test) %>%
  bind_cols(penguins_test) %>%
  metrics(truth = body_mass_g, estimate = .pred)
## # A tibble: 3 x 3
##   .metric .estimator .estimate
##   <chr>   <chr>          <dbl>
## 1 rmse    standard     296.   
## 2 rsq     standard       0.881
## 3 mae     standard     237.

# Tune over the random dials grid. The seed is set first, as in the manual
# and automatic-grid tuning runs, so the random forest fits are reproducible
# even when this chunk is run on its own.
set.seed(300)
dials_random_tune <-
  rf_workflow %>%
  tune_grid(resamples = folds_5, grid = dials_random)

# Metrics for all candidates, summarised across the resamples
collect_metrics(dials_random_tune)
## # A tibble: 12 x 8
##     mtry trees .metric .estimator    mean     n std_err .config             
##    <int> <int> <chr>   <chr>        <dbl> <int>   <dbl> <chr>               
##  1     2  1980 rmse    standard   298.       10 13.7    Preprocessor1_Model1
##  2     2  1980 rsq     standard     0.857    10  0.0174 Preprocessor1_Model1
##  3     2  1875 rmse    standard   297.       10 13.6    Preprocessor1_Model2
##  4     2  1875 rsq     standard     0.858    10  0.0171 Preprocessor1_Model2
##  5     1  1705 rmse    standard   307.       10 10.3    Preprocessor1_Model3
##  6     1  1705 rsq     standard     0.857    10  0.0151 Preprocessor1_Model3
##  7     4   272 rmse    standard   303.       10 14.5    Preprocessor1_Model4
##  8     4   272 rsq     standard     0.851    10  0.0181 Preprocessor1_Model4
##  9     5   461 rmse    standard   304.       10 14.8    Preprocessor1_Model5
## 10     5   461 rsq     standard     0.851    10  0.0181 Preprocessor1_Model5
## 11     1   780 rmse    standard   306.       10 10.1    Preprocessor1_Model6
## 12     1   780 rsq     standard     0.858    10  0.0146 Preprocessor1_Model6
# Best candidate by RMSE; the metric is named explicitly instead of relying on
# show_best()'s implicit choice of the first metric
show_best(dials_random_tune, metric = "rmse", n = 1)
## # A tibble: 1 x 8
##    mtry trees .metric .estimator  mean     n std_err .config             
##   <int> <int> <chr>   <chr>      <dbl> <int>   <dbl> <chr>               
## 1     2  1875 rmse    standard    297.    10    13.6 Preprocessor1_Model2

# Finalise the workflow with the lowest-RMSE candidate from the random grid
# and refit on the whole training set. The metric is passed explicitly so the
# selection does not depend on select_best()'s implicit default.
dials_random_final <-
  rf_workflow %>%
  finalize_workflow(select_best(dials_random_tune, metric = "rmse")) %>%
  fit(data = penguins_train)

# Score the finalised model on the test set
dials_random_final %>%
  predict(penguins_test) %>%
  bind_cols(penguins_test) %>%
  metrics(truth = body_mass_g, estimate = .pred)
## # A tibble: 3 x 3
##   .metric .estimator .estimate
##   <chr>   <chr>          <dbl>
## 1 rmse    standard     296.   
## 2 rsq     standard       0.882
## 3 mae     standard     234.

Session Info

sessionInfo()
## R version 4.0.2 (2020-06-22)
## Platform: x86_64-apple-darwin17.0 (64-bit)
## Running under: macOS  10.16
## 
## Matrix products: default
## BLAS:   /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRblas.dylib
## LAPACK: /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRlapack.dylib
## 
## locale:
## [1] en_AU.UTF-8/en_AU.UTF-8/en_AU.UTF-8/C/en_AU.UTF-8/en_AU.UTF-8
## 
## attached base packages:
## [1] stats     graphics  grDevices utils     datasets  methods   base     
## 
## other attached packages:
##  [1] randomForest_4.6-14 vctrs_0.3.6         rlang_0.4.10       
##  [4] moments_0.14        corrr_0.4.3         yardstick_0.0.7    
##  [7] workflows_0.2.1     tune_0.1.2          tidyr_1.1.2        
## [10] tibble_3.0.4        rsample_0.0.8       recipes_0.1.15     
## [13] purrr_0.3.4         parsnip_0.1.4       modeldata_0.1.0    
## [16] infer_0.5.3         ggplot2_3.3.3       dplyr_1.0.2        
## [19] dials_0.0.9         scales_1.1.1        broom_0.7.3        
## [22] tidymodels_0.1.2   
## 
## loaded via a namespace (and not attached):
##  [1] Rcpp_1.0.5         lubridate_1.7.9.2  lattice_0.20-41    listenv_0.8.0     
##  [5] class_7.3-17       foreach_1.5.1      assertthat_0.2.1   digest_0.6.27     
##  [9] ipred_0.9-9        parallelly_1.23.0  plyr_1.8.6         R6_2.5.0          
## [13] backports_1.2.1    evaluate_0.14      pillar_1.4.7       rstudioapi_0.13   
## [17] DiceDesign_1.8-1   furrr_0.2.1        rpart_4.1-15       Matrix_1.3-2      
## [21] rmarkdown_2.6      splines_4.0.2      gower_0.2.2        stringr_1.4.0     
## [25] munsell_0.5.0      compiler_4.0.2     xfun_0.20          pkgconfig_2.0.3   
## [29] globals_0.14.0     htmltools_0.5.0    nnet_7.3-14        tidyselect_1.1.0  
## [33] prodlim_2019.11.13 codetools_0.2-18   GPfit_1.0-8        fansi_0.4.1       
## [37] future_1.21.0      crayon_1.3.4       withr_2.3.0        MASS_7.3-53       
## [41] grid_4.0.2         gtable_0.3.0       lifecycle_0.2.0    magrittr_2.0.1    
## [45] pROC_1.16.2        cli_2.2.0          stringi_1.5.3      timeDate_3043.102 
## [49] ellipsis_0.3.1     lhs_1.1.1          generics_0.1.0     lava_1.6.8.1      
## [53] iterators_1.0.13   tools_4.0.2        glue_1.4.2         parallel_4.0.2    
## [57] survival_3.2-7     yaml_2.2.1         colorspace_2.0-0   knitr_1.30