The purpose of this markdown document is to demonstrate how to use `dials` to manage and tune `parsnip` models. See *Dials, Tune, and Parsnip: Tidymodels’ Way to Create and Tune Model Parameters* for more information.
library(tidymodels)
# NOTE(review): corrr and moments are attached but not used by any code shown
# in this document; kept in case other material relies on them.
library(corrr)
library(moments)

# Load the Palmer penguins data explicitly from modeldata. Naming the package
# avoids silently picking up another `penguins` data set (R >= 4.5.0 ships one
# in `datasets` with different column names), which would break the
# bill_length_mm / body_mass_g column references used throughout.
data("penguins", package = "modeldata")
head(penguins)
## # A tibble: 6 x 7
## species island bill_length_mm bill_depth_mm flipper_length_… body_mass_g sex
## <fct> <fct> <dbl> <dbl> <int> <int> <fct>
## 1 Adelie Torge… 39.1 18.7 181 3750 male
## 2 Adelie Torge… 39.5 17.4 186 3800 fema…
## 3 Adelie Torge… 40.3 18 195 3250 fema…
## 4 Adelie Torge… NA NA NA NA <NA>
## 5 Adelie Torge… 36.7 19.3 193 3450 fema…
## 6 Adelie Torge… 39.3 20.6 190 3650 male
# Compact column-by-column overview: type and first few values of each column.
glimpse(penguins)
## Rows: 344
## Columns: 7
## $ species <fct> Adelie, Adelie, Adelie, Adelie, Adelie, Adelie, Ade…
## $ island <fct> Torgersen, Torgersen, Torgersen, Torgersen, Torgers…
## $ bill_length_mm <dbl> 39.1, 39.5, 40.3, NA, 36.7, 39.3, 38.9, 39.2, 34.1,…
## $ bill_depth_mm <dbl> 18.7, 17.4, 18.0, NA, 19.3, 20.6, 17.8, 19.6, 18.1,…
## $ flipper_length_mm <int> 181, 186, 195, NA, 193, 190, 181, 195, 193, 190, 18…
## $ body_mass_g <int> 3750, 3800, 3250, NA, 3450, 3650, 3625, 4675, 3475,…
## $ sex <fct> male, female, female, NA, female, male, female, mal…
# Summary of every column: missingness, completeness, and per-type statistics.
# NOTE(review): skim() appears to label its report from the input expression
# (the report shows the name "Piped data"), so rewriting this as
# skim(penguins) would likely change that label.
penguins %>% skimr::skim()
Name | Piped data |
Number of rows | 344 |
Number of columns | 7 |
_______________________ | |
Column type frequency: | |
factor | 3 |
numeric | 4 |
________________________ | |
Group variables | None |
Variable type: factor
skim_variable | n_missing | complete_rate | ordered | n_unique | top_counts |
---|---|---|---|---|---|
species | 0 | 1.00 | FALSE | 3 | Ade: 152, Gen: 124, Chi: 68 |
island | 0 | 1.00 | FALSE | 3 | Bis: 168, Dre: 124, Tor: 52 |
sex | 11 | 0.97 | FALSE | 2 | mal: 168, fem: 165 |
Variable type: numeric
skim_variable | n_missing | complete_rate | mean | sd | p0 | p25 | p50 | p75 | p100 | hist |
---|---|---|---|---|---|---|---|---|---|---|
bill_length_mm | 2 | 0.99 | 43.92 | 5.46 | 32.1 | 39.23 | 44.45 | 48.5 | 59.6 | ▃▇▇▆▁ |
bill_depth_mm | 2 | 0.99 | 17.15 | 1.97 | 13.1 | 15.60 | 17.30 | 18.7 | 21.5 | ▅▅▇▇▂ |
flipper_length_mm | 2 | 0.99 | 200.92 | 14.06 | 172.0 | 190.00 | 197.00 | 213.0 | 231.0 | ▂▇▃▅▂ |
body_mass_g | 2 | 0.99 | 4201.75 | 801.95 | 2700.0 | 3550.00 | 4050.00 | 4750.0 | 6300.0 | ▃▇▆▃▂ |
# Keep every column except island, then discard rows with any missing value.
penguins <- penguins %>%
  select(!island) %>%
  drop_na()

# 75% / 25% train-test split; the seed makes the split reproducible.
set.seed(300)
split <- initial_split(penguins, prop = 3 / 4)
penguins_train <- training(split)
penguins_test <- testing(split)

# 5-fold cross-validation repeated twice, i.e. 10 resamples in total.
folds_5 <- vfold_cv(penguins_train, v = 5, repeats = 2)
# Engines parsnip can use for a random forest, with the modes each supports.
"rand_forest" %>% show_engines()
## # A tibble: 6 x 2
## engine mode
## <chr> <chr>
## 1 ranger classification
## 2 ranger regression
## 3 randomForest classification
## 4 randomForest regression
## 5 spark classification
## 6 spark regression
# Full description of rand_forest: modes, engines, how the main arguments map
# to each engine's own argument names, and the available prediction types.
"rand_forest" %>% show_model_info()
## Information for `rand_forest`
## modes: unknown, classification, regression
##
## engines:
## classification: randomForest, ranger, spark
## regression: randomForest, ranger, spark
##
## arguments:
## ranger:
## mtry --> mtry
## trees --> num.trees
## min_n --> min.node.size
## randomForest:
## mtry --> mtry
## trees --> ntree
## min_n --> nodesize
## spark:
## mtry --> feature_subset_strategy
## trees --> num_trees
## min_n --> min_instances_per_node
##
## fit modules:
## engine mode
## ranger classification
## ranger regression
## randomForest classification
## randomForest regression
## spark classification
## spark regression
##
## prediction modules:
## mode engine methods
## classification randomForest class, prob, raw
## classification ranger class, conf_int, prob, raw
## classification spark class, prob
## regression randomForest numeric, raw
## regression ranger conf_int, numeric, raw
## regression spark numeric
parsnip
`rf_spec` is a random forest model specification created with `parsnip`. I do not specify values for any parameters, so the default values are used. As always, I then fit the model on the training data; the fitted model, including the default parameter values, is printed below.
# Random forest regression spec with no tuning arguments supplied, so the
# randomForest engine uses its default values (shown in the printed fit).
rf_spec <- set_engine(rand_forest(mode = "regression"), "randomForest")

# Fit on the training data: body_mass_g against all remaining columns.
model_default <- fit(rf_spec, body_mass_g ~ ., data = penguins_train)
print(model_default)
## parsnip model object
##
## Fit time: 146ms
##
## Call:
## randomForest(x = maybe_data_frame(x), y = y)
## Type of random forest: regression
## Number of trees: 500
## No. of variables tried at each split: 1
##
## Mean of squared residuals: 89926.71
## % Var explained: 85.51
# Score the default model on the held-out test set: predictions are joined
# back to the test rows so truth (body_mass_g) and estimate (.pred) line up.
model_default %>%
  predict(new_data = penguins_test) %>%
  bind_cols(penguins_test) %>%
  metrics(truth = body_mass_g, estimate = .pred)
## # A tibble: 3 x 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 rmse standard 317.
## 2 rsq standard 0.877
## 3 mae standard 242.
tune
Use `tune` to tune the `parsnip` model.
# Update model specification
# mtry and trees become tune() placeholders, to be filled in by tune_grid().
rf_spec <- update(rf_spec, mtry = tune(), trees = tune())

# Create workflow: outcome body_mass_g, predictors selected with everything().
rf_workflow <- workflow() %>%
  add_variables(outcomes = body_mass_g, predictors = everything()) %>%
  add_model(rf_spec)

# Put parameters in a grid: every combination of 3 mtry values and 3 tree
# counts, i.e. 9 candidates.
manual_grid <- expand.grid(
  mtry = c(1, 3, 5),
  trees = c(500, 1000, 2000)
)

# Tune: evaluate each candidate on every resample in folds_5.
set.seed(300)
manual_tune <- tune_grid(rf_workflow, resamples = folds_5, grid = manual_grid)

# Show all results
collect_metrics(manual_tune)
## # A tibble: 18 x 8
## mtry trees .metric .estimator mean n std_err .config
## <dbl> <dbl> <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 1 500 rmse standard 306. 10 9.84 Preprocessor1_Model1
## 2 1 500 rsq standard 0.858 10 0.0146 Preprocessor1_Model1
## 3 3 500 rmse standard 301. 10 14.3 Preprocessor1_Model2
## 4 3 500 rsq standard 0.854 10 0.0178 Preprocessor1_Model2
## 5 5 500 rmse standard 303. 10 14.5 Preprocessor1_Model3
## 6 5 500 rsq standard 0.852 10 0.0180 Preprocessor1_Model3
## 7 1 1000 rmse standard 305. 10 9.82 Preprocessor1_Model4
## 8 1 1000 rsq standard 0.859 10 0.0143 Preprocessor1_Model4
## 9 3 1000 rmse standard 300. 10 14.5 Preprocessor1_Model5
## 10 3 1000 rsq standard 0.854 10 0.0180 Preprocessor1_Model5
## 11 5 1000 rmse standard 304. 10 14.5 Preprocessor1_Model6
## 12 5 1000 rsq standard 0.851 10 0.0180 Preprocessor1_Model6
## 13 1 2000 rmse standard 306. 10 10.1 Preprocessor1_Model7
## 14 1 2000 rsq standard 0.858 10 0.0144 Preprocessor1_Model7
## 15 3 2000 rmse standard 300. 10 14.5 Preprocessor1_Model8
## 16 3 2000 rsq standard 0.854 10 0.0179 Preprocessor1_Model8
## 17 5 2000 rmse standard 304. 10 14.7 Preprocessor1_Model9
## 18 5 2000 rsq standard 0.851 10 0.0181 Preprocessor1_Model9
# Show the best candidate by cross-validated RMSE. Naming the metric makes the
# ranking criterion explicit instead of relying on tune's implicit
# first-metric fallback, which emits a warning.
show_best(manual_tune, metric = "rmse", n = 1)
## # A tibble: 1 x 8
## mtry trees .metric .estimator mean n std_err .config
## <dbl> <dbl> <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 3 2000 rmse standard 300. 10 14.5 Preprocessor1_Model8
# Finalise: plug the (mtry, trees) pair with the lowest cross-validated RMSE
# into the workflow and refit it on the whole training set. The metric is
# named explicitly so the selection criterion is unambiguous.
manual_final <-
  finalize_workflow(
    rf_workflow,
    select_best(manual_tune, metric = "rmse")
  ) %>%
  fit(penguins_train)

# Predict on testing data and score the held-out predictions.
manual_final %>%
  predict(penguins_test) %>%
  bind_cols(penguins_test) %>%
  metrics(body_mass_g, .pred)
## # A tibble: 3 x 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 rmse standard 296.
## 2 rsq standard 0.881
## 3 mae standard 238.
# Instead of a hand-made grid, pass an integer: tune_grid() then generates
# that many candidate (mtry, trees) combinations itself (5 here).
# NOTE(review): tune documents these generated candidates as a space-filling
# design rather than purely random draws; confirm for the installed version.
set.seed(300)
random_tune <- tune_grid(rf_workflow, resamples = folds_5, grid = 5)
collect_metrics(random_tune)
## # A tibble: 10 x 8
## mtry trees .metric .estimator mean n std_err .config
## <int> <int> <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 5 1879 rmse standard 304. 10 14.5 Preprocessor1_Model1
## 2 5 1879 rsq standard 0.851 10 0.0181 Preprocessor1_Model1
## 3 2 799 rmse standard 298. 10 13.6 Preprocessor1_Model2
## 4 2 799 rsq standard 0.857 10 0.0171 Preprocessor1_Model2
## 5 3 1263 rmse standard 300. 10 14.5 Preprocessor1_Model3
## 6 3 1263 rsq standard 0.854 10 0.0179 Preprocessor1_Model3
## 7 2 812 rmse standard 297. 10 13.7 Preprocessor1_Model4
## 8 2 812 rsq standard 0.858 10 0.0171 Preprocessor1_Model4
## 9 4 193 rmse standard 302. 10 14.9 Preprocessor1_Model5
## 10 4 193 rsq standard 0.852 10 0.0182 Preprocessor1_Model5
# Best automatically generated candidate, ranked explicitly by RMSE rather
# than by tune's implicit (warning-emitting) first-metric fallback.
show_best(random_tune, metric = "rmse", n = 1)
## # A tibble: 1 x 8
## mtry trees .metric .estimator mean n std_err .config
## <int> <int> <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 2 812 rmse standard 297. 10 13.7 Preprocessor1_Model4
# Refit the workflow on the full training set with the lowest-RMSE candidate
# from random_tune; the selection metric is named explicitly.
random_final <-
  finalize_workflow(
    rf_workflow,
    select_best(random_tune, metric = "rmse")
  ) %>%
  fit(penguins_train)

# Score the finalised model on the held-out test set.
random_final %>%
  predict(penguins_test) %>%
  bind_cols(penguins_test) %>%
  metrics(body_mass_g, .pred)
## # A tibble: 3 x 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 rmse standard 295.
## 2 rsq standard 0.883
## 3 mae standard 233.
dials
# The mtry parameter object: its upper bound is unknown ("?") until set.
print(mtry())
## # Randomly Selected Predictors (quantitative)
## Range: [1, ?]
# The range as a list; the upper element is still the unknown() placeholder.
range_get(mtry())
## $lower
## [1] 1
##
## $upper
## unknown()
# Fill in the upper bound: all training columns except the outcome.
range_set(mtry(), c(1, ncol(penguins_train) - 1))
## # Randomly Selected Predictors (quantitative)
## Range: [1, 5]
# Same result as range_set(): pass the range straight to the constructor.
mtry(range = c(1, ncol(penguins_train) - 1))
## # Randomly Selected Predictors (quantitative)
## Range: [1, 5]
# A parameter defined on a log-10 scale; its range is shown transformed.
print(cost_complexity())
## Cost-Complexity Parameter (quantitative)
## Transformer: log-10
## Range (transformed scale): [-10, -1]
# Number of trees, with its default range.
print(trees())
## # Trees (quantitative)
## Range: [1, 2000]
set.seed(300)
# value_seq(): n evenly spaced values spanning the parameter's range.
value_seq(trees(), n = 4)
## [1] 1 667 1333 2000
value_seq(trees(), n = 5)
## [1] 1 500 1000 1500 2000
value_seq(trees(), n = 10)
## [1] 1 223 445 667 889 1111 1333 1555 1777 2000
set.seed(300)
# value_sample(): n random draws from the range. Later calls continue the
# same RNG stream instead of restarting from this seed.
value_sample(trees(), n = 4)
## [1] 590 874 1602 985
value_sample(trees(), n = 5)
## [1] 1692 789 553 1980 1875
value_sample(trees(), n = 10)
## [1] 1705 272 461 780 1383 1868 1107 812 460 901
# Regular grid: `levels` values per parameter, crossed (3 x 3 = 9 rows).
# mtry's upper bound is the number of predictors (all columns but the outcome).
set.seed(300)
dials_regular <- grid_regular(
  mtry(range = c(1, ncol(penguins_train) - 1)),
  trees(),
  levels = 3
)
print(dials_regular)
## # A tibble: 9 x 2
## mtry trees
## <int> <int>
## 1 1 1
## 2 3 1
## 3 5 1
## 4 1 1000
## 5 3 1000
## 6 5 1000
## 7 1 2000
## 8 3 2000
## 9 5 2000
# Random grid: `size` independently drawn (mtry, trees) combinations, seeded
# for reproducibility.
set.seed(300)
dials_random <- grid_random(
  mtry(range = c(1, ncol(penguins_train) - 1)),
  trees(),
  size = 6
)
print(dials_random)
## # A tibble: 6 x 2
## mtry trees
## <int> <int>
## 1 2 1980
## 2 2 1875
## 3 1 1705
## 4 4 272
## 5 5 461
## 6 1 780
# Tune over the regular dials grid. Seed first, as the other tuning runs do:
# the random-forest fits draw random numbers, so without a seed the results
# depend on whatever RNG state earlier chunks left behind.
set.seed(300)
dials_regular_tune <-
  rf_workflow %>%
  tune_grid(
    resamples = folds_5, grid = dials_regular
  )
collect_metrics(dials_regular_tune)
## # A tibble: 18 x 8
## mtry trees .metric .estimator mean n std_err .config
## <int> <int> <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 1 1 rmse standard 410. 10 27.8 Preprocessor1_Model1
## 2 1 1 rsq standard 0.736 10 0.0300 Preprocessor1_Model1
## 3 3 1 rmse standard 408. 10 16.9 Preprocessor1_Model2
## 4 3 1 rsq standard 0.744 10 0.0271 Preprocessor1_Model2
## 5 5 1 rmse standard 404. 10 19.6 Preprocessor1_Model3
## 6 5 1 rsq standard 0.764 10 0.0237 Preprocessor1_Model3
## 7 1 1000 rmse standard 305. 10 10.0 Preprocessor1_Model4
## 8 1 1000 rsq standard 0.859 10 0.0143 Preprocessor1_Model4
## 9 3 1000 rmse standard 300. 10 14.4 Preprocessor1_Model5
## 10 3 1000 rsq standard 0.854 10 0.0178 Preprocessor1_Model5
## 11 5 1000 rmse standard 304. 10 14.5 Preprocessor1_Model6
## 12 5 1000 rsq standard 0.851 10 0.0181 Preprocessor1_Model6
## 13 1 2000 rmse standard 305. 10 10.1 Preprocessor1_Model7
## 14 1 2000 rsq standard 0.858 10 0.0147 Preprocessor1_Model7
## 15 3 2000 rmse standard 300. 10 14.5 Preprocessor1_Model8
## 16 3 2000 rsq standard 0.854 10 0.0180 Preprocessor1_Model8
## 17 5 2000 rmse standard 304. 10 14.7 Preprocessor1_Model9
## 18 5 2000 rsq standard 0.851 10 0.0182 Preprocessor1_Model9
# Best regular-grid candidate, ranked explicitly by RMSE rather than by
# tune's implicit (warning-emitting) first-metric fallback.
show_best(dials_regular_tune, metric = "rmse", n = 1)
## # A tibble: 1 x 8
## mtry trees .metric .estimator mean n std_err .config
## <int> <int> <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 3 1000 rmse standard 300. 10 14.4 Preprocessor1_Model5
# Refit on the full training set with the lowest-RMSE regular-grid candidate;
# the selection metric is named explicitly.
dials_regular_final <-
  finalize_workflow(
    rf_workflow,
    select_best(dials_regular_tune, metric = "rmse")
  ) %>%
  fit(penguins_train)

# Score the finalised model on the held-out test set.
dials_regular_final %>%
  predict(penguins_test) %>%
  bind_cols(penguins_test) %>%
  metrics(body_mass_g, .pred)
## # A tibble: 3 x 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 rmse standard 296.
## 2 rsq standard 0.881
## 3 mae standard 237.
# Tune over the random dials grid. Seed first, as the other tuning runs do,
# so the random-forest fits are reproducible when this chunk is re-run alone.
set.seed(300)
dials_random_tune <-
  rf_workflow %>%
  tune_grid(
    resamples = folds_5, grid = dials_random
  )
collect_metrics(dials_random_tune)
## # A tibble: 12 x 8
## mtry trees .metric .estimator mean n std_err .config
## <int> <int> <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 2 1980 rmse standard 298. 10 13.7 Preprocessor1_Model1
## 2 2 1980 rsq standard 0.857 10 0.0174 Preprocessor1_Model1
## 3 2 1875 rmse standard 297. 10 13.6 Preprocessor1_Model2
## 4 2 1875 rsq standard 0.858 10 0.0171 Preprocessor1_Model2
## 5 1 1705 rmse standard 307. 10 10.3 Preprocessor1_Model3
## 6 1 1705 rsq standard 0.857 10 0.0151 Preprocessor1_Model3
## 7 4 272 rmse standard 303. 10 14.5 Preprocessor1_Model4
## 8 4 272 rsq standard 0.851 10 0.0181 Preprocessor1_Model4
## 9 5 461 rmse standard 304. 10 14.8 Preprocessor1_Model5
## 10 5 461 rsq standard 0.851 10 0.0181 Preprocessor1_Model5
## 11 1 780 rmse standard 306. 10 10.1 Preprocessor1_Model6
## 12 1 780 rsq standard 0.858 10 0.0146 Preprocessor1_Model6
# Best random-grid candidate, ranked explicitly by RMSE rather than by
# tune's implicit (warning-emitting) first-metric fallback.
show_best(dials_random_tune, metric = "rmse", n = 1)
## # A tibble: 1 x 8
## mtry trees .metric .estimator mean n std_err .config
## <int> <int> <chr> <chr> <dbl> <int> <dbl> <chr>
## 1 2 1875 rmse standard 297. 10 13.6 Preprocessor1_Model2
# Refit on the full training set with the lowest-RMSE random-grid candidate;
# the selection metric is named explicitly.
dials_random_final <-
  finalize_workflow(
    rf_workflow,
    select_best(dials_random_tune, metric = "rmse")
  ) %>%
  fit(penguins_train)

# Score the finalised model on the held-out test set.
dials_random_final %>%
  predict(penguins_test) %>%
  bind_cols(penguins_test) %>%
  metrics(body_mass_g, .pred)
## # A tibble: 3 x 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 rmse standard 296.
## 2 rsq standard 0.882
## 3 mae standard 234.
# Record the R version, platform, and package versions behind these results.
print(sessionInfo())
## R version 4.0.2 (2020-06-22)
## Platform: x86_64-apple-darwin17.0 (64-bit)
## Running under: macOS 10.16
##
## Matrix products: default
## BLAS: /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRblas.dylib
## LAPACK: /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRlapack.dylib
##
## locale:
## [1] en_AU.UTF-8/en_AU.UTF-8/en_AU.UTF-8/C/en_AU.UTF-8/en_AU.UTF-8
##
## attached base packages:
## [1] stats graphics grDevices utils datasets methods base
##
## other attached packages:
## [1] randomForest_4.6-14 vctrs_0.3.6 rlang_0.4.10
## [4] moments_0.14 corrr_0.4.3 yardstick_0.0.7
## [7] workflows_0.2.1 tune_0.1.2 tidyr_1.1.2
## [10] tibble_3.0.4 rsample_0.0.8 recipes_0.1.15
## [13] purrr_0.3.4 parsnip_0.1.4 modeldata_0.1.0
## [16] infer_0.5.3 ggplot2_3.3.3 dplyr_1.0.2
## [19] dials_0.0.9 scales_1.1.1 broom_0.7.3
## [22] tidymodels_0.1.2
##
## loaded via a namespace (and not attached):
## [1] Rcpp_1.0.5 lubridate_1.7.9.2 lattice_0.20-41 listenv_0.8.0
## [5] class_7.3-17 foreach_1.5.1 assertthat_0.2.1 digest_0.6.27
## [9] ipred_0.9-9 parallelly_1.23.0 plyr_1.8.6 R6_2.5.0
## [13] backports_1.2.1 evaluate_0.14 pillar_1.4.7 rstudioapi_0.13
## [17] DiceDesign_1.8-1 furrr_0.2.1 rpart_4.1-15 Matrix_1.3-2
## [21] rmarkdown_2.6 splines_4.0.2 gower_0.2.2 stringr_1.4.0
## [25] munsell_0.5.0 compiler_4.0.2 xfun_0.20 pkgconfig_2.0.3
## [29] globals_0.14.0 htmltools_0.5.0 nnet_7.3-14 tidyselect_1.1.0
## [33] prodlim_2019.11.13 codetools_0.2-18 GPfit_1.0-8 fansi_0.4.1
## [37] future_1.21.0 crayon_1.3.4 withr_2.3.0 MASS_7.3-53
## [41] grid_4.0.2 gtable_0.3.0 lifecycle_0.2.0 magrittr_2.0.1
## [45] pROC_1.16.2 cli_2.2.0 stringi_1.5.3 timeDate_3043.102
## [49] ellipsis_0.3.1 lhs_1.1.1 generics_0.1.0 lava_1.6.8.1
## [53] iterators_1.0.13 tools_4.0.2 glue_1.4.2 parallel_4.0.2
## [57] survival_3.2-7 yaml_2.2.1 colorspace_2.0-0 knitr_1.30