7  Techniques for Machine Learning Applications

Learning Objectives

  • Manipulate data through feature engineering
  • Select the most suitable model for your data
  • Learn about machine learning algorithms

Selecting the most suitable machine learning model involves understanding the goals of the analysis, the nature of the data, and the statistical and machine learning methods that best suit the tasks. In Chapter 6, we learned about what machine learning models are, provided examples for building a model framework, and selected common metrics for model performance calibration and evaluation.

In this chapter, we focus on the strategies for selecting appropriate models by leveraging the strengths of different techniques, specifically for health metrics and for infectious diseases. We will explore various considerations involved in addressing potential biases, and discuss actions to prevent them.

7.1 Goals of the Analysis and Nature of Data

The identification of the primary goal of the analysis is fundamental. Whether it involves trend analysis, investigating the relationships between response and predictor variables, or strictly forecasting to predict future outcomes, the strategy for model selection varies accordingly.

Health Metrics Data:

  • Composite Measures: Health metrics like DALYs are composite measures that include both mortality and morbidity data, often requiring sophisticated regression models capable of handling continuous variables and multiple predictors. By examining the components of DALYs (e.g., Years of Life Lost (YLLs) and Years Lived with Disability (YLDs)), we can identify the key drivers such as mortality rates, disease prevalence, and risk factors.

  • Regression Models: Regression models, including linear regression, Ridge regression, and Lasso regression, are commonly used to handle these continuous variables and address challenges like correlation and multicollinearity with appropriate techniques such as regularization.

Infectious Disease Data:

  • Categorical and Continuous Data: Infectious disease data can be categorical (e.g., disease presence or absence) or continuous (e.g., incidence rates). Classification models are suitable for categorical outcomes, while regression models are appropriate for continuous data.

  • Disease Dynamics: Understanding the dynamics of infectious diseases, such as transmission rates, incubation periods, and immunity, informs the selection of models. Common models include compartmental models (e.g., SIR, SEIR) and agent-based models.

Common considerations for both health metrics and infectious disease data:

  • Seasonality and Trends: The data may exhibit seasonality or trends, necessitating the use of time series analysis models like ARIMA or seasonal decomposition to capture these patterns.

  • Simulation Models: These models can predict the impact of interventions on DALYs and infectious diseases, estimating the effectiveness of different interventions and guiding policy decisions. Examples include SIR models and agent-based models (a minimal SIR sketch follows this list). In addition, confidence intervals and sensitivity analyses help assess the uncertainty associated with these predictions.

  • Bayesian Models: These models can estimate parameters and make predictions based on prior knowledge and observed data, incorporating uncertainty and variability.

  • Predictive Models: Methods such as decision trees, support vector machines (SVMs), and LSTM (Long Short-Term Memory) neural networks can predict disease outbreaks, identify high-risk populations, and optimize resource allocation.
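
As a minimal illustration of the compartmental models mentioned in the list above, here is a basic SIR sketch solved with the deSolve package. The transmission rate, recovery rate, and population size are arbitrary values chosen only for illustration, not estimates from any dataset.

library(deSolve)

# SIR compartmental model: susceptible (S), infectious (I), recovered (R)
sir <- function(time, state, parms) {
  with(as.list(c(state, parms)), {
    dS <- -beta * S * I / N
    dI <-  beta * S * I / N - gamma * I
    dR <-  gamma * I
    list(c(dS, dI, dR))
  })
}

out <- ode(y     = c(S = 990, I = 10, R = 0),
           times = seq(0, 160, by = 1),
           func  = sir,
           parms = c(beta = 0.3, gamma = 0.1, N = 1000))
head(out)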

7.2 Statistical and Machine Learning Methods

The choice of model depends on the type of data, the relationships between variables, and the goals of the analysis. Once these factors are well identified, we are one step closer to narrowing down the range of applicable models.

The next step involves conducting a thorough exploratory data analysis (EDA). This initial exploration helps to uncover the underlying structure of the data, the relationships between variables, and the way the response variable—which may also be referred to as the outcome variable—depends on predictors. This phase is critical as it informs the necessity of subsequent data adjustments and transformations.

Data preparation and exploratory data analysis are the building blocks for turning raw data into a form that machine learning models can digest. Feature engineering is a technique that involves creating new features from existing ones, based on domain knowledge or on transformations of the data, to improve the model’s ability to discern patterns. For example, creating features like moving averages or differences between consecutive days can reveal trends and cycles that are not immediately apparent from the raw data, as sketched below.
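
For instance, a minimal sketch of such time-based features, computed on a hypothetical daily case-count series (the data frame and column names below are illustrative and not part of the rabies data used later in this chapter):

library(tidyverse)

# Hypothetical daily case counts (illustrative only)
set.seed(1)
daily_cases <- tibble(
  date  = seq(as.Date("2024-01-01"), by = "day", length.out = 30),
  cases = rpois(30, lambda = 20)
)

daily_cases %>%
  arrange(date) %>%
  mutate(
    # day-to-day change
    diff_1day = cases - lag(cases),
    # trailing 7-day moving average
    ma_7day   = as.numeric(stats::filter(cases, rep(1 / 7, 7), sides = 1))
  )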

Another example is the standardization process, which is crucial when dealing with variables measured in different units. It involves rescaling the features so they have a mean of zero and a standard deviation of one. This process is particularly important when variables span several orders of magnitude; without standardization, a model might incorrectly interpret the scale of a feature as a proxy for importance.
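
A minimal sketch, using hypothetical values that span several orders of magnitude:

# Hypothetical feature measured on a very wide scale
income <- c(500, 2000, 35000, 120000, 900000)

# Standardization: subtract the mean and divide by the standard deviation
income_std <- (income - mean(income)) / sd(income)

# The same result can be obtained with the built-in scale() function
all.equal(income_std, as.numeric(scale(income)))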

Furthermore, the application of transformations such as logarithmic scaling or spline functions can help in managing skewed data or enhancing the model’s ability to capture non-linear relationships, which is particularly useful when modelling complex data. In addition, tailored adjustments and more sophisticated manipulations have been developed over time to allow estimation of missing values and to obtain customized, flexible, and more homogeneous data. For more information on feature engineering, covering various techniques and appropriate use cases with a focus on practical understanding and implementation, see Butcher and Smith.1
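
A brief sketch of these two transformations on hypothetical right-skewed values (the natural spline basis comes from the splines package shipped with base R):

library(splines)

# Hypothetical right-skewed values
y <- c(1, 3, 8, 40, 250, 4000)

# Logarithmic scaling compresses the long right tail
y_log <- log10(y)

# Natural cubic spline basis to capture non-linear relationships
y_ns <- ns(y, df = 3)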

7.3 Model Selection Strategies

In developing predictive models for health metrics and infectious diseases, selecting the appropriate model is critical to ensure accurate and reliable forecasts. Here we outline sample strategies employed in the model selection process, introduce the rabies dataset used for our discussion, and demonstrate the selection of a suitable model for analyzing its impact. Rabies, although nearly 100% fatal once symptoms appear, presents a unique challenge due to the relative rarity of cases and the limited availability of comprehensive data. This scarcity complicates efforts to model the disease accurately and develop effective public health strategies.

To address these challenges, we explore advanced modeling techniques that can enhance the robustness of our analyses despite data limitations. This involves evaluating multiple models based on their performance and selecting the best-fitting one to achieve the most accurate predictions.

7.4 Example 4: Rabies

The rabies dataset from the hmsidwR package contains information on death rates and disability-adjusted life years (DALYs) per 100,000 inhabitants due to rabies and to all-cause mortality in Asia and in the Global region from 1990 to 2019. Rabies2 is a fatal viral infectious disease that can infect all mammals, causing acute encephalitis. It is caused by the rabies virus, which belongs to the Lyssavirus genus, and is transmitted to humans through the bite of an infected animal such as bats, raccoons, skunks, foxes, and, above all, dogs, which are the main source of human rabies deaths.3 Rabies is classified as a neglected tropical disease (NTD) and predominantly affects marginalized, poor, and vulnerable populations. Although effective human vaccines and immunoglobulins exist for rabies, they are often not readily available or accessible to everyone.4

In this example we consider the number of DALYs per 100,000 inhabitants due to rabies in Asia and the Global region as our response variable. The dataset is made available in the hmsidwR package and is composed of 240 observations and 7 variables: measure, location, cause, year, val, upper, lower.

library(tidyverse)
hmsidwR::rabies %>%
  select(-upper, -lower) %>%
  head()
#>   measure location      cause year      val
#> 1  Deaths   Global     Rabies 1990    0.479
#> 2  Deaths   Global All causes 1990 1113.610
#> 3  Deaths     Asia     Rabies 1991    0.659
#> 4  Deaths     Asia All causes 1991 1189.045
#> 5  Deaths   Global All causes 1991 1105.832
#> 6  Deaths     Asia     Rabies 1992    0.648

Selecting only the observations with cause == "Rabies", the first thing to notice is that death rates and DALYs are expressed in different units, rates and years respectively.

library(tidyverse)
rabies <- hmsidwR::rabies %>%
  select(-upper, -lower) %>%
  pivot_wider(names_from = measure, values_from = val) %>%
  filter(cause == "Rabies") %>%
  rename(dx_rabies = Deaths, dalys_rabies = DALYs) %>%
  select(-cause)

rabies %>% head()
#> # A tibble: 6 × 4
#>   location  year dx_rabies dalys_rabies
#>   <chr>    <int>     <dbl>        <dbl>
#> 1 Global    1990     0.479         29.7
#> 2 Asia      1991     0.659         38.3
#> 3 Asia      1992     0.648         37.7
#> 4 Global    1991     0.477         29.6
#> 5 Asia      1993     0.625         36.4
#> 6 Global    1992     0.47          29.2

It can be seen that the number of deaths due to rabies is much lower than the number of DALYs. This difference in scale can affect the model’s ability to learn from the data. To address this issue, we can scale and center the numeric variables to make them more comparable.

p1 <- rabies %>%
  ggplot(aes(x = year, group = location, color = location)) +
  geom_line(aes(y = dx_rabies),
            linewidth = 1) +
  geom_line(aes(y = dalys_rabies))

p2 <- rabies %>%
  # apply a scale transformation to the numeric variables
  mutate(across(where(is.double), scale)) %>%
  ggplot(aes(x = year, group = location, color = location)) +
  geom_line(aes(y = dx_rabies),
            linewidth = 1) +
  geom_line(aes(y = dalys_rabies))
Figure 7.1: DALYs due to rabies: (a) not scaled; (b) scaled and centered.

Creating new features from existing ones can provide additional predictive power. Next, we split the data by the cause variable to obtain separate columns for death rates and DALYs due to rabies and due to all causes, and then scale and center the numeric variables to obtain homogeneous data to use in the model.

all_causes <- hmsidwR::rabies %>%
  select(-upper, -lower) %>%
  pivot_wider(names_from = measure, values_from = val) %>%
  filter(!cause == "Rabies") %>%
  rename(dx_allcauses = Deaths, dalys_allcauses = DALYs) %>%
  select(-cause)

dat <- rabies %>%
  full_join(all_causes, by = c("location", "year"))

dat %>% head()
#> # A tibble: 6 × 6
#>   location  year dx_rabies dalys_rabies dx_allcauses dalys_allcauses
#>   <chr>    <int>     <dbl>        <dbl>        <dbl>           <dbl>
#> 1 Global    1990     0.479         29.7        1114.          50060.
#> 2 Asia      1991     0.659         38.3        1189.          50022.
#> 3 Asia      1992     0.648         37.7        1171.          49099.
#> 4 Global    1991     0.477         29.6        1106.          49742.
#> 5 Asia      1993     0.625         36.4        1153.          48251.
#> 6 Global    1992     0.47          29.2        1097.          49223.

To visualize the magnitude of the difference between death rates and DALYs for both rabies and all causes, it is necessary to scale or standardize the data, as shown above.

p3 <- dat %>%
  select(-year, -location) %>%
  scale() %>%
  cbind(dat %>% select(year, location)) %>%
  ggplot(aes(x = year, group = location, color = location)) +
  geom_line(aes(y = dx_rabies),
            linewidth = 1) +
  geom_line(aes(y = dx_allcauses))

p4 <- dat %>%
  select(-year, -location) %>%
  scale() %>%
  cbind(dat %>% select(year, location)) %>%
  ggplot(aes(x = year,
             group = location,
             color = location)) +
  geom_line(aes(y = dalys_rabies),
            linewidth = 1) +
  geom_line(aes(y = dalys_allcauses))
Figure 7.2: Scaled and centered: (a) deaths due to rabies and all causes; (b) DALYs due to rabies and all causes.

For this task, we will use the tidymodels meta-package, as it provides a consistent interface for modelling and machine learning tasks. In particular, we define and execute modelling workflows that combine tailored data pre-processing steps with various model specifications, and we evaluate their performance using resampling techniques to eventually select the best model. A more detailed explanation of the tidymodels framework can be found in Tidy Modeling with R.5

7.4.1 Training Data and Resampling

Splitting the data into training and test sets allows the model to be trained on one subsection of the data and then tested on the remaining part, the test set. In this case, we will use the initial_split() function to split the data into training and test sets. The proportion assigned to training can vary but is usually set to 80%; a stratification option can also be set.

library(tidymodels)

set.seed(04042024)
split <- initial_split(dat, prop = 0.8)
training <- training(split)
test <- testing(split)

After that, it is important to create a set of folds, that is, a set of subgroups of the original data, grouped according to the type of resampling technique. Resampling techniques are used to evaluate the model’s performance and estimate its generalization error. The appropriate technique depends on the specific characteristics of your dataset and the goals of your analysis. Some of the most common resampling techniques include the following (a sketch of how they can be created with rsample follows the list):

  • k-Fold Cross-Validation for general model evaluation and hyperparameter tuning.
  • Bootstrap Resampling to estimate the variability of your model and for smaller datasets.
  • Time Series Cross-Validation for time-dependent data to preserve temporal structure.
  • Spatial Resampling for spatially correlated data to account for spatial dependencies.
  • Stratified Resampling when dealing with imbalanced datasets to ensure proper representation of all classes.
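
A sketch of how some of these schemes can be created with the rsample package (loaded with tidymodels above); only k-fold cross-validation is used later in this chapter, and the fold, repetition, and window sizes below are illustrative:

set.seed(123)
folds_cv    <- vfold_cv(training, v = 10)                     # k-fold cross-validation
boots       <- bootstraps(training, times = 100)              # bootstrap resampling
rolling     <- rolling_origin(training %>% arrange(year),     # time series cross-validation
                              initial = 30, assess = 5)
strat_folds <- vfold_cv(training, v = 10, strata = location)  # stratified folds
# Spatial resampling requires spatially indexed data and the spatialsample package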

In this case, we will use k-Fold Cross-Validation to evaluate the model’s performance. The vfold_cv() function creates a set of folds for cross-validation, which is used to train and test the model on different subsets of the data.

set.seed(04052024)
folds <- vfold_cv(training, v = 10)

7.4.2 Data Preprocessing and Feature Engineering

As already seen in the exploratory phase, preprocessing the data is a crucial step in machine learning. This process can include techniques for handling missing values, standardizing the data, encoding categorical variables, and removing highly correlated variables.

In this case, we will use the recipes package to create a recipe with a set of preprocessing steps. The recipe() function allows us to define a model formula and apply various step_*() functions for manipulating the data. We are going to set up three recipes; the first is a basic one that includes all variables and does not perform any data transformation.

rec <- recipe(dalys_rabies ~ ., data = training)

The second recipe includes some key steps that transform the data into a form specific models can understand and learn from. Models such as k-nearest neighbors or support vector machines, which rely on distance metrics, can be sensitive to differences in feature scales.

For instance, non-standardized year data can dominate the model’s decision-making process, leading to biased results. By scaling and centering the data, we ensure that all features contribute equally to the model’s predictions.

We can create more complex recipes with more steps, but for this example, we will use a step for encoding the location variable (Asia, Global) into a numeric vector, and a second step to normalize (or standardize) all predictors.

rec1 <- recipe(dalys_rabies ~ ., data = training) %>%
  # convert nominal variables to dummy variables
  step_dummy(all_nominal_predictors()) %>%
  # scale the numeric variables
  step_normalize(all_numeric_predictors())

Once the recipe is created, we can apply it to the data using the prep() function, which estimates the necessary parameters for the transformations and applies them to the training data. Then, to check the results, we can use the juice() function to extract the processed data.

rec1 %>%
  prep() %>%
  juice() %>%
  select(1, 2, 5) %>%
  head()
#> # A tibble: 6 × 3
#>     year dx_rabies dalys_rabies
#>    <dbl>     <dbl>        <dbl>
#> 1 -1.22     0.484          28.3
#> 2 -0.448    0.0219         23.9
#> 3 -0.781    1.22           32.3
#> 4  1.44    -1.40           11.1
#> 5  0.439   -0.139          20.4
#> 6  1.21    -1.19           11.7

The trained recipe can also be applied to new data; in this case, we apply it to the test set with the bake() function.

rec1 %>%
  prep() %>%
  bake(new_data = test) %>%
  select(1, 2, 5) %>%
  head()
#> # A tibble: 6 × 3
#>     year dx_rabies dalys_rabies
#>    <dbl>     <dbl>        <dbl>
#> 1 -1.22      1.67          36.4
#> 2 -0.781     0.204         25.6
#> 3 -1.00      0.323         26.8
#> 4  0.550    -0.776         16.6
#> 5  0.993    -1.20          12.9
#> 6  0.106     0.421         25.2

DALYs often aggregate various health impacts, and can have highly skewed distributions. This skewness arises due to several factors: the presence of outliers, the nature of the health condition being measured, and the distribution of the data itself. To handle the skewness of the data, we can apply:

  • Log Transformation: log(DALYs+1)
  • Sqrt Transformation: \sqrt{DALYs}
  • Yeo-Johnson Transformation, a generalization of the Box-Cox transformation that can handle both positive and negative values; for non-negative values and \lambda \neq 0 it takes the form ((DALYs+1)^\lambda - 1)/\lambda.

Let’s apply the Yeo-Johnson transformation to the response variable (dalys_rabies) and see how the density distribution changes with different values of \lambda. This is a step whose parameter can be tuned within a machine learning workflow.
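
A possible sketch of how these comparisons can be produced (assumed code, not the original; the Yeo-Johnson form below is the one valid for non-negative values and \lambda \neq 0):

yeo_johnson <- function(y, lambda) ((y + 1)^lambda - 1) / lambda

training %>%
  transmute(log10          = log10(dalys_rabies),
            yeo_johnson_m2 = yeo_johnson(dalys_rabies, -2),
            yeo_johnson_p2 = yeo_johnson(dalys_rabies, 2)) %>%
  pivot_longer(everything(),
               names_to = "transformation",
               values_to = "value") %>%
  ggplot(aes(value)) +
  geom_density() +
  facet_wrap(~transformation, scales = "free")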

Figure 7.3: Response variable transformation: (a) Log10; (b) Yeo-Johnson \lambda = -2; (c) Yeo-Johnson \lambda = 2.

Let’s now create a third recipe with the step_YeoJohnson() function.

rec2 <- rec1 %>%
  # apply Yeo-Johnson transformation to the response variable
  step_YeoJohnson(dalys_rabies)

rec2 %>%
  prep() %>%
  juice() %>%
  select(1, 2, 5) %>%
  head()
#> # A tibble: 6 × 3
#>     year dx_rabies dalys_rabies
#>    <dbl>     <dbl>        <dbl>
#> 1 -1.22     0.484         12.6 
#> 2 -0.448    0.0219        11.2 
#> 3 -0.781    1.22          13.9 
#> 4  1.44    -1.40           6.36
#> 5  0.439   -0.139          9.97
#> 6  1.21    -1.19           6.63

7.4.3 Correlation, Multicollinearity and Overfitting

Note that we haven’t applied any correlation-based selection step to this data. Filtering out highly correlated predictors, such as those with a correlation greater than 80%, to avoid multicollinearity would lead to excluding crucial variables. On the other hand, including all possible covariates in a model often yields implausible signs on covariates or unstable coefficients, as well as overfitting.6

When multiple predictors are correlated but all are crucial for the analysis (e.g., deaths due to rabies, total deaths, and total DALYs for all causes), applying a correlation step that filters out correlated variables can be problematic. One way to overcome the resulting bias is to use regularization techniques such as Ridge regression or Lasso regression, which handle multicollinearity without removing any predictors. Alternatively, principal component analysis (PCA) can reduce dimensionality while retaining most of the variance. These methods ensure all important predictors are included in the model without the adverse effects of multicollinearity.
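
For illustration, here is a sketch of two alternative recipes (neither is used in the workflows below): a correlation filter that drops predictors above a threshold, and a PCA step that keeps all predictors but projects them onto a few components.

# Option 1 (not used here): drop predictors with absolute correlation above 0.8
rec_corr <- recipe(dalys_rabies ~ ., data = training) %>%
  step_dummy(all_nominal_predictors()) %>%
  step_normalize(all_numeric_predictors()) %>%
  step_corr(all_numeric_predictors(), threshold = 0.8)

# Option 2 (not used here): retain all predictors via principal components
rec_pca <- recipe(dalys_rabies ~ ., data = training) %>%
  step_dummy(all_nominal_predictors()) %>%
  step_normalize(all_numeric_predictors()) %>%
  step_pca(all_numeric_predictors(), num_comp = 2)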

7.4.4 Model Specification

The next step is to outline the model specification. There are various types of models that can be used; we start with a random forest. This choice is typically made because of the algorithm’s features: it creates multiple bootstrap samples (random samples with replacement) from the original dataset, and each bootstrap sample is used to train a separate decision tree.

7.4.5 Model 1: Random Forest

Rabies death rates may exhibit complex relationships with predictor variables. Random forests are capable of capturing non-linear relationships between predictors and the target variable.

Random forests also handle multicollinearity and missing data, provide variable importance measures, and are an ensemble learning method: they combine the predictions of multiple individual decision trees to produce a more accurate and stable prediction.

In our simplified case, this type of model draws random samples of the data with replacement. In tidymodels we can select different engines; for a random forest we could use randomForest, ranger, and others. The difference between these engines lies in the specific implementation used for the estimation. The ranger engine is notably faster than randomForest, so let’s use it for this example.

rf_mod <- rand_forest(mtry = tune(),
                      trees = tune(),
                      min_n = tune(),
                      mode = "regression",
                      engine = "ranger")

wkf <- workflow(preprocessor = rec, 
                spec = rf_mod)

rf_res <- tune_grid(object = wkf,
                    resamples = folds,
                    grid = 5,
                    control = control_grid(save_pred = TRUE))

show_best(rf_res, metric = "rmse") %>%
  select(-n, -std_err)
#> # A tibble: 5 × 7
#>    mtry trees min_n .metric .estimator  mean .config             
#>   <int> <int> <int> <chr>   <chr>      <dbl> <chr>               
#> 1     5   840     9 rmse    standard   0.968 Preprocessor1_Model4
#> 2     3  1678    15 rmse    standard   1.38  Preprocessor1_Model2
#> 3     2  1568    18 rmse    standard   1.97  Preprocessor1_Model3
#> 4     2   610    29 rmse    standard   3.15  Preprocessor1_Model1
#> 5     4   369    38 rmse    standard   3.22  Preprocessor1_Model5
rf_res_tuned <- select_best(rf_res, metric = "rmse")

rf_res_tuned
#> # A tibble: 1 × 4
#>    mtry trees min_n .config             
#>   <int> <int> <int> <chr>               
#> 1     5   840     9 Preprocessor1_Model4
rf_fit <- wkf %>%
  finalize_workflow(select_best(rf_res,
                                metric = "rmse")) %>%
  fit(training)

rf_fit %>%
  predict(new_data = test) %>%
  bind_cols(test) %>%
  rmse(truth = dalys_rabies, estimate = .pred)
#> # A tibble: 1 × 3
#>   .metric .estimator .estimate
#>   <chr>   <chr>          <dbl>
#> 1 rmse    standard       0.479
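
Figure 7.4 compares the random forest predictions with the observed values on the test set; a possible sketch of panel (a) (assumed, not the original plotting code):

rf_fit %>%
  predict(new_data = test) %>%
  bind_cols(test) %>%
  ggplot(aes(x = dalys_rabies, y = .pred, color = location)) +
  geom_point() +
  geom_abline(lty = 2) +  # perfect-prediction reference line
  labs(x = "Observed DALYs", y = "Predicted DALYs")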
Figure 7.4: Predictions vs. truth: (a) predictions vs. truth; (b) predictions vs. truth by year.

7.4.6 Model 2: Generalized Linear Model (GLM)

Generalized Linear Models (GLMs) involve statistical estimation rather than the iterative parameter tuning common in many machine learning techniques. However, a machine learning feature can be added through parameter calibration, using techniques such as cross-validation and grid search to find the best model settings.

To introduce a machine learning feature with parameter calibration into our modelling of the rabies data, we can use a technique like cross-validation combined with a regularization method or an algorithm that supports parameter tuning. Here, we can employ a model from the glmnet package, which fits a generalized linear model via penalized maximum likelihood. The regularization path is computed for the lasso or elastic-net penalty at a grid of values for the regularization parameter lambda.

Adding Machine Learning Features with glmnet and Cross-Validation

glmnet expects matrices rather than data frames, so we create a matrix for the independent variables (predictors) and a vector for the dependent variable (response).

predictors <- model.matrix(dalys_rabies ~ .,
                           data = dat)[, -1] # Remove intercept
response <- dat$dalys_rabies

Use cross-validation to find the optimal lambda value, which controls the strength of the regularization:

library(glmnet)

# Set seed for reproducibility
set.seed(123)

# Fit the model with cross-validation
cv_model <- cv.glmnet(predictors,
                      response,
                      family = "gaussian")
cv_model
#> 
#> Call:  cv.glmnet(x = predictors, y = response, family = "gaussian") 
#> 
#> Measure: Mean-Squared Error 
#> 
#>     Lambda Index Measure       SE Nonzero
#> min 0.1651    43 0.07307 0.007813       2
#> 1se 0.1812    42 0.08032 0.008546       3
# Plot the lambda selection
plot(cv_model)
Figure 7.5: Cross-validation lambda selection.

Extracting the best model, we can see that \lambda is 0.165.

# Get the best lambda value
best_lambda <- cv_model$lambda.min
print(paste("Best Lambda:", best_lambda))
#> [1] "Best Lambda: 0.16512173333762"

Then, fitting the final model with the selected best lambda, we can predict and evaluate the model.

final_model <- glmnet(predictors,
                      response,
                      family = "gaussian",
                      lambda = best_lambda)

# Predict using the final model
predictions <- predict(final_model,
                       s = best_lambda,
                       newx = predictors
                       )

# Calculate the Root Mean Squared Error
rmse <- sqrt(mean((response - predictions)^2))
print(paste("Root Mean Squared Error:", rmse))
#> [1] "Root Mean Squared Error: 0.257014560701469"

By incorporating glmnet and using lambda selection via cross-validation, we introduce a machine learning feature—parameter calibration—into our analysis. This approach not only helps in minimizing overfitting but also enhances model performance by selecting the most effective regularization parameter. The cross-validation process used here is crucial for confirming that our model’s parameters are optimally tuned for the given data, embodying a key aspect of machine learning methodologies.
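
As with the random forest, the glmnet predictions can be compared with the observed DALYs; a possible sketch of panel (a) of Figure 7.6 (assumed, not the original plotting code):

dat %>%
  mutate(.pred = as.numeric(predictions)) %>%  # predictions is a one-column matrix
  ggplot(aes(x = dalys_rabies, y = .pred, color = location)) +
  geom_point() +
  geom_abline(lty = 2) +
  labs(x = "Observed DALYs", y = "Predicted DALYs")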

Figure 7.6: Predictions vs. truth: (a) predictions vs. truth; (b) predictions vs. truth by year.

7.4.7 Testing Multiple Models

In the example above, we used two models to predict DALYs due to rabies: a random forest with tidymodels and a generalized linear model with glmnet, with root mean squared errors of 0.479 (on the held-out test set) and 0.257 (computed on the same data used for fitting), respectively. The random forest model has a higher RMSE, which means it has a higher prediction error compared to the GLM model, although the two errors are not computed on the same data. Moreover, we haven’t applied any of the preprocessing steps, and there are many other models that could be used to predict DALYs, such as:

  1. Support Vector Machines (SVM): SVMs are a powerful machine learning algorithm that can be used for both classification and regression tasks. They work by finding the hyperplane that best separates the data into different classes or groups.

  2. Extreme Gradient Boosting (XGBoost): Known for its high performance in various prediction tasks, XGBoost can handle missing values and is effective for large datasets.

  3. K-Nearest Neighbors (KNN): KNN models are a type of instance-based learning algorithm that stores all available cases and classifies or predicts new cases based on a similarity measure.

  4. Long Short-Term Memory (LSTM) Networks: For temporal or sequential health data, LSTM networks can capture dependencies over time, making them suitable for time-series prediction of disease progression and outcomes.

Each of these models has its own strengths and weaknesses, and the choice of model will depend on the specific characteristics of the data and the goals of the analysis. By testing multiple models and comparing their performance, we can identify the best model for the given data and task.

Let’s use the parsnip package together with the workflow_set() function from workflowsets to fit a set of models to the rabies data. We will fit a penalized linear regression, a Support Vector Machine (SVM), and a K-Nearest Neighbors (KNN) model and compare their performance.

linear_reg_spec <- 
  linear_reg(penalty = tune(),
             mixture = tune()) %>%
  set_engine("glmnet")

svm_p_spec <-
  svm_poly(cost = tune(),
           degree = tune()) %>%
  set_engine("kernlab") %>%
  set_mode("regression")

knn_spec <-
  nearest_neighbor(neighbors = tune(),
                   dist_power = tune(),
                   weight_func = tune()) %>%
  set_engine("kknn") %>%
  set_mode("regression")
library(rules)
library(baguette)
# Combine workflows into a workflow set
workflow_set <- workflow_set(preproc = list(scaled = rec1,
                                            yeo_johnson = rec2),
                             models = list(linear_reg = linear_reg_spec,
                                           svm = svm_p_spec,
                                           knn = knn_spec))

grid_ctrl <- control_grid(save_pred = TRUE,
                          parallel_over = "everything",
                          save_workflow = TRUE)

# Fit and evaluate the models with hyperparameter tuning
grid_results <- workflow_set %>%
  workflow_map(seed = 1503,
               resamples = folds,
               grid = 5,
               control = grid_ctrl)
# Show the results
grid_results %>%
  collect_metrics() %>%
  arrange(mean) %>%
  select(1, 5, 7, 9) %>%
  head()
#> # A tibble: 6 × 4
#>   wflow_id               .metric  mean std_err
#>   <chr>                  <chr>   <dbl>   <dbl>
#> 1 yeo_johnson_linear_reg rmse    0.132  0.0196
#> 2 yeo_johnson_linear_reg rmse    0.135  0.0198
#> 3 yeo_johnson_linear_reg rmse    0.136  0.0202
#> 4 yeo_johnson_linear_reg rmse    0.136  0.0210
#> 5 yeo_johnson_linear_reg rmse    0.136  0.0185
#> 6 yeo_johnson_svm        rmse    0.186  0.0281
autoplot(grid_results,
         rank_metric = "rmse",
         metric = "rmse",
         select_best = TRUE) +
  geom_text(aes(y = mean - 0.1, 
                label = wflow_id),
            angle = 90, hjust = 1, size = 4) +
  lims(y = c(-1, 0.9)) +
  theme(legend.position = "none")
Figure 7.7: Model performance.
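
To conclude the comparison, the best workflow can be extracted from the set, finalized with its best hyperparameters, and refit on the training data. A sketch of this final step (assumed, not shown in the original analysis):

# The ranking above places the Yeo-Johnson + linear regression workflow first
best_id <- "yeo_johnson_linear_reg"

best_params <- grid_results %>%
  extract_workflow_set_result(best_id) %>%
  select_best(metric = "rmse")

best_fit <- grid_results %>%
  extract_workflow(best_id) %>%
  finalize_workflow(best_params) %>%
  fit(training)

# Note: rec2 transforms the outcome, so predictions from this workflow are on
# the Yeo-Johnson scale and must be back-transformed before comparing them
# with raw DALYs values.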

7.5 Summary

The integration of machine learning techniques into public health data analysis can significantly enhance the predictive power and robustness of models. By leveraging the capabilities of machine learning algorithms, we can extract valuable insights from complex health data, enabling more informed decision-making and policy formulation in public health contexts. The examples provided in this chapter illustrate the application of machine learning techniques to health metrics data, demonstrating the importance of feature engineering, model selection, and parameter calibration in enhancing the predictive accuracy and relevance of models. By following best practices in machine learning, public health researchers and practitioners can harness the power of data-driven insights to address critical health challenges and improve population health outcomes.

Best Practices for Machine Learning in Public Health:

  • Conduct exploratory data analysis to understand the underlying structure of the data and relationships between variables.
  • Apply feature engineering techniques to create new variables and enhance the model’s predictive power.
  • Select machine learning models that are contextually appropriate and robust for public health data analysis, such as random forests, generalized linear models, and others.
  • Use parameter calibration techniques such as cross-validation, regularization, Monte Carlo methods, and grid search to optimize model performance.
  • Evaluate model performance using appropriate metrics and visualization tools to assess predictive accuracy and relevance.

By following these best practices, public health researchers and practitioners can leverage the full potential of machine learning techniques to address complex health challenges and improve population health outcomes. The integration of machine learning methodologies into public health data analysis represents a significant opportunity to advance the field of public health and enhance our understanding of health metrics and disease dynamics.


  1. Brandon Butcher and Brian J. Smith, The American Statistician 74, no. 3 (July 2020): 308–9, doi:10.1080/00031305.2020.1790217.↩︎

  2. CDC, “About Rabies,” May 14, 2024, https://www.cdc.gov/rabies/about/index.html.↩︎

  3. Katie Hampson et al., “Estimating the Global Burden of Endemic Canine Rabies,” PLOS Neglected Tropical Diseases 9, no. 4 (April 2015): e0003709, doi:10.1371/journal.pntd.0003709.↩︎

  4. “Rabies,” World Health Organization, n.d., https://www.who.int/news-room/fact-sheets/detail/rabies.↩︎

  5. Max Kuhn and Julia Silge, Tidy Modeling with R, n.d., https://www.tmwr.org/.↩︎

  6. Kyle J. Foreman et al., “Modeling Causes of Death: An Integrated Approach Using CODEm,” Population Health Metrics 10, no. 1 (January 2012): 1, doi:10.1186/1478-7954-10-1.↩︎