Chapter 27 SML with Tidymodels
Hello! Today, we'll be learning about machine learning for regression tasks (i.e., when your outcome variable is continuous rather than categorical). Recall that supervised machine learning is split into two parts: classification tasks (for categorical outcomes) and regression tasks (for continuous outcomes).
Let's start by installing the tidymodels package and the tidyverse package. Like tidyverse, tidymodels is a mega-package of sorts: it contains a set of packages that work together (specifically, rsample, recipes, parsnip, and yardstick).
library(tidyverse)
library(tidymodels)
options(scipen = 999999)
#install.packages("rlang")
#if you run into any issues, you may need to update your packages using the update.packages() function.
#if any specific packages give you trouble (rlang did for me), uninstall and reinstall the package.
#you can uninstall packages using the remove.packages() function, or you can go to the Packages pane in RStudio and manually remove it from the User Library.
Now, let’s import our data. For this tutorial, we’ll use tweets_academia, which we have used previously. But instead of focusing on the text, let’s focus now on the retweetability of a message. Due to the large size of the dataset (a large dataset takes longer to model), we’ll focus on a sample of that data.
To replicate this specific analysis, you will need to use set.seed() to make sure you are sampling the same tweets that I did (a different sample is fine, but the results may differ from mine).
set.seed(1337)
academic_tweets <- read_csv("data/rtweet_academictwitter_2021.csv") #import data
academic_tweets_sample <- academic_tweets[sample(1:nrow(academic_tweets), 1000),] %>% #sample 1,000 tweets
select(user_id, favorite_count, retweet_count, followers_count, friends_count, statuses_count, verified) %>% #let's focus on a few variables
mutate(verified = as.factor(verified))
#followers = people that follow an account
#friends = people that an account follows
You'll notice that in this slimmed down dataset, the variables I've selected are:

* user_id = the id of the user
* favorite_count = the number of likes the tweet had at the time of the collection
* retweet_count = the number of retweets the tweet had at the time of the collection
* followers_count = the number of followers the user had at the time of the collection
* friends_count = the number of accounts that the user followed at the time of the collection
* statuses_count = the number of tweets posted by the user (includes retweets)
* verified = whether the account has a verified blue check or not
## [1] "user_id" "favorite_count" "retweet_count" "followers_count"
## [5] "friends_count" "statuses_count" "verified"
Let's take a look at this data using functions like str() and glimpse().
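The output below comes from calls along these lines:

str(academic_tweets_sample) #base R: compact summary of each column's type and first values
glimpse(academic_tweets_sample) #tidyverse equivalent: one row per column, values printed across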
## tibble [1,000 x 7] (S3: tbl_df/tbl/data.frame)
## $ user_id : num [1:1000] 515940853 4484050293 2225913155 4068052403 1131356596804694016 ...
## $ favorite_count : num [1:1000] 0 5 34 2 0 0 0 2 0 7 ...
## $ retweet_count : num [1:1000] 37 2 6 3 50 180 33 0 29 0 ...
## $ followers_count: num [1:1000] 228 488 2086 185 785 ...
## $ friends_count : num [1:1000] 888 280 1744 144 1877 ...
## $ statuses_count : num [1:1000] 39217 470 3522 1923 688 ...
## $ verified : Factor w/ 2 levels "FALSE","TRUE": 1 1 1 1 1 1 1 1 1 1 ...
## Rows: 1,000
## Columns: 7
## $ user_id <dbl> 515940853, 4484050293, 2225913155, 4068052403, 1131356~
## $ favorite_count <dbl> 0, 5, 34, 2, 0, 0, 0, 2, 0, 7, 0, 0, 0, 0, 0, 0, 0, 0,~
## $ retweet_count <dbl> 37, 2, 6, 3, 50, 180, 33, 0, 29, 0, 1, 4, 3, 394, 161,~
## $ followers_count <dbl> 228, 488, 2086, 185, 785, 31, 857, 1039, 1427, 3824, 1~
## $ friends_count <dbl> 888, 280, 1744, 144, 1877, 65, 725, 505, 1264, 3771, 8~
## $ statuses_count <dbl> 39217, 470, 3522, 1923, 688, 20, 3270, 5287, 25062, 33~
## $ verified <fct> FALSE, FALSE, FALSE, FALSE, FALSE, FALSE, FALSE, FALSE~
While there are a couple things we could model with this data, we will focus today on the retweetability of a message. Thus, retweet_count will be our DV (outcome) and the other variables (all but one of which are numeric) will constitute our independent variables.
To create our model, we will run this data through the tidymodels framework. We begin with the sampling strategy; for this, we will use the rsample package.
27.1 Sampling, rsample
As with any supervised machine learning approach, you will likely want to split your data into a test set and a training set. Recall that we will use the training set to actually construct the model. Then, we will use the test set to evaluate the model. In tidymodels, we can split the data using rsample, a package you may already be familiar with.
The key function we will use to do this is initial_split(). This function takes at least two arguments: the dataset (in our case, academic_tweets_sample) and the proportion of the data that should go into the training set (this can be a fraction or a decimal, but cannot be more than 1 or less than 0). So if, like below, you set the proportion to 0.8, 80% of the data will be in the training set and 20% will be in the test set.
#library(rsample)
academic_split <- initial_split(academic_tweets_sample, prop = 0.8) #80/20
academic_split #what does this look like?
## <Training/Testing/Total>
## <800/200/1000>
Notice how the object academic_split is now a summary of the split rather than the data itself. If you're interested in extracting the data, you can use the training() and testing() functions.
academic_training <- training(academic_split)
academic_training %>% head(4) #look at the first 4 rows
## # A tibble: 4 x 7
## user_id favorite_count retweet_count followers_count friends~1 statu~2 verif~3
## <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <fct>
## 1 3.27e 9 0 5 12103 2591 58119 FALSE
## 2 1.89e 7 0 10 699 1264 1893 FALSE
## 3 1.03e 8 0 9 310 952 8149 FALSE
## 4 1.35e18 0 5 135 540 546 FALSE
## # ... with abbreviated variable names 1: friends_count, 2: statuses_count,
## # 3: verified
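The two counts printed below come from checking the row counts of each piece; a minimal sketch (the object name academic_testing is my own, since the test set is not extracted in the code shown above):

academic_testing <- testing(academic_split) #extract the held-out 20%
nrow(academic_training) #rows in the training set
nrow(academic_testing) #rows in the test set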
## [1] 800
## [1] 200
As we discussed in class, there are many strategies that you can use to sample your data. One particularly popular one is V-fold cross-validation, which you can apply using the vfold_cv() function (a minimal sketch follows). Learn more about vfold_cv() here. For now, though, we'll use this simple split.
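For reference, a minimal sketch of what that would look like (we do not use these folds anywhere below):

academic_folds <- vfold_cv(academic_training, v = 10) #split the training data into 10 folds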
27.2 Pre-processing, recipes
Like caret, tidymodels is primarily interested in helping users manage the varying hyperparameters for each model. However, they take slightly different approaches. Whereas caret relies on a tuning grid (refer to last week's tutorial for more), tidymodels relies on two separate packages: one that helps wrangle the data (recipes) and one that helps specify the model (parsnip). When we fit a model, we combine the data wrangled with recipes and the model information provided by parsnip. One advantage of this, compared to caret, is control: with caret, you may need to re-wrangle data for use across different machine learning algorithms, but with tidymodels, this is handled once by recipes.
At the heart of the recipes package is the recipe() function, which takes two arguments: the formula (dv ~ iv) and the dataset. Keep in mind that, with a machine learning algorithm, the dv (the variable on the left) refers to the outcome variable (the one you want to predict) and the iv (the variable[s] on the right) are the predictor variables (the ones you think are useful for predicting the dv).
In this example, we’ll focus on building a model to predict retweet count based on the other variables.
academic_recipe_data <- recipe(retweet_count ~ favorite_count + followers_count + statuses_count + verified,
data = academic_training) #notice here that I use the training set
#academic_recipe_data #if you print this object, you'll notice that it simply lists the # of outcome and predictor variables
#summary(academic_recipe_data) #use summary() to learn more
27.2.1 step_ functions
Now that we know how to put together a basic recipe, let's start adding some steps to wrangle this data further! First, since we plan to apply a LASSO regression (which requires normalized data points), we will use the step_normalize() function to normalize the numeric variables. Second, since we have a categorical variable (verified), we will need to construct a dummy variable. Let's do so now.
academic_recipe <- recipe(retweet_count ~ favorite_count + followers_count + statuses_count + verified,
data = academic_training) %>%
step_normalize(retweet_count, favorite_count, followers_count, statuses_count) %>%
step_dummy(verified) #you can also use all_nominal() to focus on all the nominal/categorical data
academic_recipe <- prep(academic_recipe) #prep() estimates what each step needs (e.g., the means and SDs used for normalizing) from the training data
Another interesting option may be to use the all_outcomes() function to focus on the outcome (DV) variables and the all_predictors() function to focus on the IVs (this is useful if you have a lot of IVs). You can see an application of this here, and a brief sketch below.
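As a rough sketch (this alternative recipe is not used below), the same steps could be written with selector functions instead of naming every column:

recipe(retweet_count ~ favorite_count + followers_count + statuses_count + verified,
       data = academic_training) %>%
  step_normalize(all_outcomes(), all_numeric_predictors()) %>% #normalize the numeric outcome and predictors
  step_dummy(all_nominal_predictors()) #dummy-code the factor predictors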
academic_recipe #notice that when I run the object now, there is information about the operations that are done
## Recipe
##
## Inputs:
##
## role #variables
## outcome 1
## predictor 4
##
## Training data contained 800 data points and no missing data.
Now that we have our prepped recipe, we need to actually pull out the wrangled data! The recipe itself is useful, but we won't be able to put it into a model function (it's a recipe object, rather than the data frame it describes). To make use of this data, we often have to juice() it (aka: squeeze the data out of the recipe). juice() is a very straightforward function: it takes a prepped recipe and spits out a data frame.
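A minimal sketch of that step, assuming (based on how the object is used in the fit() calls later on) that the juiced data frame is stored back in academic_recipe_data; the preview that follows comes from the original run:

academic_recipe_data <- juice(academic_recipe) #pull the wrangled training data out of the prepped recipe
head(academic_recipe_data) #preview the first rows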
## # A tibble: 6 x 5
## favorite_count followers_count statuses_count verified retweet_count
## <dbl> <dbl> <dbl> <fct> <dbl>
## 1 0 12103 58119 FALSE 5
## 2 0 699 1893 FALSE 10
## 3 0 310 8149 FALSE 9
## 4 0 135 546 FALSE 5
## 5 14 3063 1157 FALSE 4
## 6 1 6765 18039 FALSE 0
One function we have not used yet, but which may be of interest, is bake(), which is used to apply a prepped recipe to a new dataset; we will return to it when we evaluate our models. Learn more about juice() and bake() here and here. You can also learn more about recipes here.
27.3 Model specification, parsnip
Now that we have our data, let’s proceed with constructing our model! For this tutorial, we will focus on 3 algorithms (random forests, xgboost, and LASSO regression).
27.3.1 Random Forests
Unlike caret, parsnip uses specific functions to refer to specific algorithms. In this instance, we will use rand_forest(), the parsnip function for constructing random forests. There are a couple of different packages that implement random forest algorithms, but we'll use ranger here.
Most parsnip functions require at least 3 pieces of information. The first is the hyperparameters (in this instance, mtry, the number of variables randomly sampled at each split; trees, the number of trees produced; and min_n, the minimum number of data points in a node needed for it to be split further). The second is the mode (classification or regression), which we indicate using the set_mode() function. Finally, we have to indicate the package the algorithm comes from using set_engine(). While most algorithms have only one package version, random forests are popular enough that they appear in a few packages. For the purposes of this analysis, let's use the ranger package (the algorithm also exists in the randomForest package).
random_forest_model <- rand_forest(mtry = 1, trees = 500, min_n = 2) %>%
set_mode("regression") %>%
set_engine("ranger")
While you can add the hyperparameters directly in the function, as we do here, it is also possible to use a tuning grid (which is more similar to caret); a brief sketch follows. If you're interested in this, you can learn more about it here.
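A minimal sketch of a tunable version of the model (not used below; actually tuning it would also require a grid and a resampling object via the tune package):

rand_forest(mtry = tune(), trees = 500, min_n = tune()) %>% #mark hyperparameters to be tuned later
  set_mode("regression") %>%
  set_engine("ranger")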
Great! So now, we have our wrangled data and we have our model, but we have yet to put it together. How do we do that?
Well, we do so using the fit() function. As we discussed in class, fit() takes 3 arguments: the model information (from parsnip), the formula (what is the DV and what are the IVs?), and the data itself (the juiced data from recipes). In our case, random_forest_model contains our hyperparameters and details the package we would like to use, and academic_recipe_data is the wrangled training data.
rf_model <- fit(random_forest_model, retweet_count ~ ., academic_recipe_data)
#note that I use . to tell the computer to treat all the variables from the dataframe (aside from retweet_count) as predictors
## parsnip model object
##
## Ranger result
##
## Call:
## ranger::ranger(x = maybe_data_frame(x), y = y, mtry = min_cols(~1, x), num.trees = ~500, min.node.size = min_rows(~2, x), num.threads = 1, verbose = FALSE, seed = sample.int(10^5, 1))
##
## Type: Regression
## Number of trees: 500
## Sample size: 800
## Number of independent variables: 4
## Mtry: 1
## Target node size: 2
## Variable importance mode: none
## Splitrule: variance
## OOB prediction error (MSE): 44908.56
## R squared (OOB): 0.07728571
You'll notice that when you print the model, it contains some important information about the model fit. The most important ones for a random forest are the OOB prediction error and the OOB R-squared. OOB stands for "out of bag": these metrics are computed on the observations each tree did not see during training. In this instance, the OOB R-squared of roughly 0.08 suggests that this model is not great; it explains only a small fraction of the variance in retweet counts.
Let us now proceed with the XGBoost model.
27.3.2 XGBoost
Our second algorithm is XGBoost (which is available in the xgboost package). As we discussed in class, xgboost is a decision tree-based algorithm that seeks to address overfitting problems.
xgboost_model <-
boost_tree(mtry = 1, trees = 1000, min_n = 2,
loss_reduction = 1) %>%
set_engine("xgboost") %>%
set_mode("regression")
Like the random forest model, we see many similar hyperparameters (xgboost, after all, is based on decision trees). However, we also include a loss_reduction hyperparameter, which sets the minimum reduction in the loss function required to split a node further (and so limits how far each tree grows).
Learn more about xgboost hyperparameters (including several we don’t tune) here!
Now that we have our model, let's proceed with the fit() function, like we did previously; a sketch of that call is below.
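The printed output below comes from a call along these lines (xgb_model is the object name used later in this tutorial; the exact call is my reconstruction):

xgb_model <- fit(xgboost_model, retweet_count ~ ., academic_recipe_data) #same pattern as the random forest fit
xgb_model #print the fitted model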
## parsnip model object
##
## ##### xgb.Booster
## raw: 1.1 Mb
## call:
## xgboost::xgb.train(params = list(eta = 0.3, max_depth = 6, gamma = 1,
## colsample_bytree = 1, colsample_bynode = 0.2, min_child_weight = 2,
## subsample = 1), data = x$data, nrounds = 1000, watchlist = x$watchlist,
## verbose = 0, nthread = 1, objective = "reg:squarederror")
## params (as set within xgb.train):
## eta = "0.3", max_depth = "6", gamma = "1", colsample_bytree = "1", colsample_bynode = "0.2", min_child_weight = "2", subsample = "1", nthread = "1", objective = "reg:squarederror", validate_parameters = "TRUE"
## xgb.attributes:
## niter
## callbacks:
## cb.evaluation.log()
## # of features: 5
## niter: 1000
## nfeatures : 5
## evaluation_log:
## iter training_rmse
## 1 229.79981
## 2 221.08087
## ---
## 999 51.69141
## 1000 51.68561
Unlike the random forest output, xgboost does not print as much goodness-of-fit information. We'll return to validating this model when we use yardstick.
For now, though, we'll move to our last algorithm, the LASSO regression.
27.3.3 LASSO Regression
As we discussed in class, LASSO regressions build on the linear regression model by penalizing less useful independent variables. This is especially useful when you are using a regression for prediction (rather than inference).
To use the LASSO regression, we will use the linear_reg() function in parsnip, which here depends on the glmnet package.
lasso_reg_model <- linear_reg(mixture = 1, penalty = 0.065) %>%
set_engine("glmnet") %>%
set_mode("regression")
As this is a very different model, we have notably different hyperparameters. First, we have the mixture argument. If you use ?linear_reg, you'll notice that mixture refers to the proportion of L1 (lasso) regularization (a 1 is a standard LASSO regression). You can also indicate the total amount of regularization using penalty; the higher the penalty, the more the model is shrunk. In other words, when the penalty is higher, more coefficients are likely to be penalized. A sketch of the fit itself follows.
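The fit follows the same pattern as before; a minimal sketch (lasso_model is the object name used later in this tutorial, and the tidy() call produces the table below):

lasso_model <- fit(lasso_reg_model, retweet_count ~ ., academic_recipe_data) #fit the LASSO to the wrangled training data
tidy(lasso_model) #tidy summary of the penalized coefficients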
## # A tibble: 5 x 3
## term estimate penalty
## <chr> <dbl> <dbl>
## 1 (Intercept) 103. 0.065
## 2 favorite_count -0.368 0.065
## 3 followers_count -0.0000430 0.065
## 4 statuses_count -0.000141 0.065
## 5 verifiedTRUE -43.5 0.065
To learn more about the model, we use tidy(), which provides us with the coefficient estimates. Note that followers_count and statuses_count have been shrunk to nearly zero, suggesting that they do not help explain the retweets much. If you play around with the penalty argument, you'll notice that the higher the penalty you set, the more estimates are shrunk to 0 (try, for example, a penalty of 0.04 or a penalty of 0.2).
Alright! So now that we’ve played around with our models, let’s actually talk about how we evaluate them.
27.4 Evaluation, yardstick
yardstick is the main package used (in tidymodels) to evaluate models. To assess the quality of our random forest, xgboost, and LASSO regression models, we will test the accuracy of the models on the test set, which we split apart at the beginning of this tutorial.
To generate predictions on new data, we use the predict() function on each fitted model; yardstick then compares those predictions against the observed values. Before we can predict, though, we first must wrangle the test data in accordance with our earlier recipe.
While we could rewrite the recipe here, the recipes package includes a neat bake() function, which allows you to apply a previously prepped recipe to a new dataset. bake() takes at least two arguments: the prepped recipe object and the new dataset.
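A minimal sketch of that step (academic_test_data is the object name used in the code below):

academic_test_data <- bake(academic_recipe, new_data = testing(academic_split)) #apply the prepped recipe to the held-out test set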
Now, we are ready to run our predict() functions. We'll run predict() three times, once for each model, but let's first see what the output looks like.
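The preview below comes from a call like the following (which of the three fitted models produced the printed predictions is not labeled in the original, so the choice of rf_model here is an assumption):

predict(rf_model, academic_test_data) %>% head() #first few predicted retweet counts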
## # A tibble: 6 x 1
## .pred
## <dbl>
## 1 570.
## 2 254.
## 3 290.
## 4 43.1
## 5 168.
## 6 125.
As you can see, the output of predict() is a data frame with one column (the prediction, in this case, of the retweet count). To make a comparison, let's attach the actual results (academic_test_data$retweet_count) to each predict() output.
lasso_predictions <- predict(lasso_model, academic_test_data) %>%
mutate(retweet_count = academic_test_data$retweet_count)
rf_predictions <- predict(rf_model, academic_test_data) %>%
mutate(retweet_count = academic_test_data$retweet_count)
xgb_predictions <- predict(xgb_model, academic_test_data) %>%
mutate(retweet_count = academic_test_data$retweet_count)
Now that we have our data, we are ready to evaluate each model using the metrics() function from yardstick; a sketch of those calls follows.
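A sketch of the metrics() calls that produce the three tables below (the pairing of tables to models is my assumption, since the original output is unlabeled):

metrics(lasso_predictions, truth = retweet_count, estimate = .pred) #LASSO regression
metrics(rf_predictions, truth = retweet_count, estimate = .pred) #random forest
metrics(xgb_predictions, truth = retweet_count, estimate = .pred) #xgboost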
## # A tibble: 3 x 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 rmse standard 245.
## 2 rsq standard 0.0191
## 3 mae standard 150.
## # A tibble: 3 x 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 rmse standard 229.
## 2 rsq standard 0.221
## 3 mae standard 135.
## # A tibble: 3 x 3
## .metric .estimator .estimate
## <chr> <chr> <dbl>
## 1 rmse standard 246.
## 2 rsq standard 0.110
## 3 mae standard 152.
The output of metrics() (when the outcome variable, or DV, is continuous) reports three metrics: the root mean squared error, the R-squared, and the mean absolute error. When comparing the rsq and mae values, we can see that all three models are... not the greatest at predicting the retweetability of a message. But, on the bright side, you now have a basic understanding of the two major R packages for supervised machine learning: caret and tidymodels! While we used caret last week for classification and tidymodels this week for regression, it's worth emphasizing that caret also contains regression algorithms and tidymodels also contains classification algorithms.
On the whole, I find that tidymodels tends to be more "robust" (in the sense that there are a lot more ways to tune and adjust your data and model), whereas caret contains a much larger library of supervised machine learning algorithms. Whichever you use, therefore, may be contingent on the goals of your data analysis!
Want to learn more about tidymodels? Check out these links:
- tidymodels website
- tuning grids in tidymodels
- tune() and dials()
- workflows; this is great for streamlining your recipe and model!
- Resampling; aka: cross-validation
- caret v tidymodels