Chapter 24 kNN
The first algorithm we will learn is k-Nearest Neighbors (kNN), which classifies a new data point by looking at the labeled data points closest to it. The assumption of this model is that similar data points sit close to one another (this is similar to the logic underlying community detection or k-means clustering).
One advantage of kNN is that it is relatively simple (it requires very few hyperparameters), and it can be especially useful for complex categorical data (that is, cases where a variable has more than 2 labels). However, kNN can be slow to classify new data points, and its results vary greatly with its one hyperparameter (k).
To construct a kNN algorithm, we will use the train() function in the caret package. This is the workhorse function of the package: anytime you are training a new model, you will use the train() function. train() generally requires three types of information: (1) the data (x and y), (2) the algorithm (method), and (3) the hyperparameters that are unique to each supervised machine learning algorithm (tuneGrid). In kNN, there is only one hyperparameter (k), the number of neighbors the algorithm will look at (for simplicity, we will use 2).
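Note that the call below also passes a resampling strategy (trControl). That object is not defined in this section; a minimal sketch that would reproduce the bootstrapped resampling shown in the model output below (25 reps, which is also caret's default) is:
library(caret) #load the caret package
trctrl <- trainControl(method = "boot", number = 25) #bootstrap resampling, 25 repetitions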
knn_model_con <- caret::train(x = tw_to_train, #training data
                              y = as.factor(conservative_code), #labeled data
                              method = "knn", #the algorithm
                              trControl = trctrl, #the resampling strategy we will use
                              tuneGrid = data.frame(k = 2) #the hyperparameter
)
print(knn_model_con) #print this model
## k-Nearest Neighbors
##
## 92 samples
## 1035 predictors
## 2 classes: '0', '1'
##
## No pre-processing
## Resampling: Bootstrapped (25 reps)
## Summary of sample sizes: 92, 92, 92, 92, 92, 92, ...
## Resampling results:
##
## Accuracy Kappa
## 0.4986282 -0.02114475
##
## Tuning parameter 'k' was held constant at a value of 2
Based on the information from knn_model_con, we know the model was able to learn from 92 tweets (the training set), which had 1035 predictors (features) and 2 labels (0 and 1).
What we don't know from this information, however, is the quality of the algorithm. To assess that, we will have to turn to the test data.
24.1 Testing the Model
To apply this algorithm to the test data, let's use predict(), which we learned about in our Advanced Linear Regression tutorial (Week 10).
knn_predict <- predict(knn_model_con, newdata = tw_to_test)
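Before scoring the model, it can be worth a quick sanity check on the predictions themselves. A small sketch (head() and table() are base R functions; knn_predict is the factor of predicted labels created above):
head(knn_predict) #first few predicted labels ('0' or '1')
table(knn_predict) #how many tweets were assigned to each label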
Instead of checking the percent accuracy of these predictions by hand, however, we will learn to use a function from the caret package: confusionMatrix(). confusionMatrix() is useful because it provides more than just the percent accuracy measure: it also reports measures that account for random chance (such as Cohen's kappa), as well as the F1 score, a common measure of accuracy in supervised machine learning.
knn_confusion_matrix <- caret::confusionMatrix(knn_predict, conservative_data$conservative[-trainIndex], mode = "prec_recall")
knn_confusion_matrix
## Confusion Matrix and Statistics
##
## Reference
## Prediction 0 1
## 0 3 1
## 1 14 20
##
## Accuracy : 0.6053
## 95% CI : (0.4339, 0.7596)
## No Information Rate : 0.5526
## P-Value [Acc > NIR] : 0.314161
##
## Kappa : 0.139
##
## Mcnemar's Test P-Value : 0.001946
##
## Precision : 0.75000
## Recall : 0.17647
## F1 : 0.28571
## Prevalence : 0.44737
## Detection Rate : 0.07895
## Detection Prevalence : 0.10526
## Balanced Accuracy : 0.56443
##
## 'Positive' Class : 0
##
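To see where these numbers come from, here is a quick sketch that recomputes precision, recall, and F1 by hand from the confusion matrix above (note that caret is treating 0 as the 'positive' class here):
tp <- 3 #predicted 0, actually 0 (true positives)
fp <- 1 #predicted 0, actually 1 (false positives)
fn <- 14 #predicted 1, actually 0 (false negatives)
precision <- tp / (tp + fp) #3/4 = 0.75
recall <- tp / (tp + fn) #3/17 = 0.17647
f1 <- 2 * precision * recall / (precision + recall) #0.28571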
As you can see here, the accuracy of this supervised machine learning model is pretty weak (61%), and its 95% confidence interval includes the no-information rate (55%), so we cannot be confident it beats simply guessing the most common label. Once you account for the class imbalance, the balanced accuracy (56%) and the low kappa (0.14) show the model does little better than random guessing. This is confirmed by the F1 score (0.29), which is not great: the model almost never predicts the 0 class (recall of 0.18), which drags the F1 score down.
Remember when I said that kNN varies by its hyperparameter? Try this out for yourself by changing the value of k in the model above and checking the accuracy and F1 scores; a sketch of how to compare several values of k at once follows below.
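Rather than re-running the model once per value, you can also ask train() to compare several values of k in one call by giving tuneGrid a data frame with one row per candidate. A sketch (knn_model_tuned is a name I am introducing here; the other objects are the ones defined above):
knn_model_tuned <- caret::train(x = tw_to_train, #training data
                                y = as.factor(conservative_code), #labeled data
                                method = "knn", #the algorithm
                                trControl = trctrl, #the resampling strategy
                                tuneGrid = data.frame(k = seq(1, 15, by = 2))) #odd values of k from 1 to 15
print(knn_model_tuned) #reports accuracy and kappa for each k
caret will report the resampled accuracy for each candidate k and automatically select the best-performing value.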
Want to learn more about kNN? Check out these tutorials:
* Towards Data Science tutorial
* kNN for dummies explanation
* kNN with non-text data