The almighty Caret!
In this practical you’ll use the caret
(Classification And REgression Training) package to automate many aspects of the machine learning process. By the end of this practical you will know:
You’ll use four datasets in this practical: boston_train.csv and boston_test.csv, heartdisease_train.csv and heartdisease_test.csv. They are available in the file provided through the main course page. If you haven’t already, download the folder and unzip it to get the files. We also recommend saving them in a folder called data in your working directory!
For this practical, you will need the following packages. If you don’t have them installed, you’ll need to install them with install.packages().
# Packages you'll need for this practical
The following set of example code will take you through the basic steps of machine learning using the amazing caret package.
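# Load packages
library(tidyverse)  # dplyr, ggplot2 (includes the diamonds data), tibble
library(caret)      # train(), trainControl(), postResample()
# ---------------
# Setup
# ---------------
# Split diamonds data into separate training and test datasets
diamonds <- diamonds %>%
  sample_frac(1) # Randomly sort
# Create separate training and test data
# (the row ranges below are assumed; the original values were lost)
diamonds_train <- diamonds %>%
  slice(1:500)
diamonds_test <- diamonds %>%
  slice(501:1000)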
# ---------------
# Explore
# ---------------
# Explore columns with summarizeColumns
mlr::summarizeColumns(diamonds_train)
# Visualise relationships with ggpairs
GGally::ggpairs(diamonds_train)
# ---------------
# Train
# ---------------
# Set up control values
rep10_control <- trainControl(method = "repeatedcv",
number = 10, # 10 folds
repeats = 5) # Repeat 5 times
# Predict price with linear regression
diamonds_lm_train <- train(form = price ~ .,
data = diamonds_train,
method = "lm",
trControl = rep10_control)
# Explore the train object
diamonds_lm_train
summary(diamonds_lm_train)
# Look at variable importance with varImp
caret::varImp(diamonds_lm_train)
# Save final predictions
diamonds_lm_predictions <- predict(diamonds_lm_train,
newdata = diamonds_test)
# ---------------
# Evaluate
# ---------------
# Plot relationship between predictions and truth
performance_data <- tibble(predictions = diamonds_lm_predictions,
criterion = diamonds_test$price)
ggplot(data = performance_data,
aes(x = predictions, y = criterion)) +
geom_point() + # Add points
geom_abline(slope = 1, intercept = 0, col = "blue", size = 2) +
labs(title = "Regression prediction accuracy",
subtitle = "Blue line is perfect prediction!")
# Look at final prediction performance!
postResample(pred = diamonds_lm_predictions,
obs = diamonds_test$price)
The boston
dataset contains housing data for 506 census tracts of Boston from the 1970 census.
Here is a description of each of the variables in the dataset
Column | Description |
crim | per capita crime rate by town |
zn | proportion of residential land zoned for lots over 25,000 sq.ft |
indus | proportion of non-retail business acres per town |
chas | Charles River dummy variable (= 1 if tract bounds river; 0 otherwise) |
nox | nitric oxides concentration (parts per 10 million) |
rm | average number of rooms per dwelling |
age | proportion of owner-occupied units built prior to 1940 |
dis | weighted distances to five Boston employment centres |
rad | index of accessibility to radial highways |
tax | full-value property-tax rate per USD 10,000 |
ptratio | pupil-teacher ratio by town |
b | 1000(B - 0.63)^2 where B is the proportion of blacks by town |
lstat | percentage of lower status of the population |
medv | median value of owner-occupied homes in USD 1000’s |
Load the boston_train.csv data into R as a new object with the name boston_train using read_csv().
boston_train <- read_csv("data/boston_train.csv")
Explore the data using functions such as View(), names() and summary().
# A tibble: 6 x 14
crim zn indus chas nox rm age dis rad tax ptratio
<dbl> <dbl> <dbl> <int> <dbl> <dbl> <dbl> <dbl> <int> <int> <dbl>
1 3.54 0. 19.6 1 0.871 6.15 82.6 1.75 5 403 14.7
2 0.340 0. 21.9 0 0.624 6.46 98.9 2.12 4 437 21.2
3 0.0798 40. 6.41 0 0.447 6.48 32.1 4.14 4 254 17.6
4 0.773 0. 8.14 0 0.538 6.50 94.4 4.45 4 307 21.0
5 0.330 0. 6.20 0 0.507 6.09 61.5 3.65 8 307 17.4
6 0.103 30. 4.93 0 0.428 6.36 52.9 7.04 6 300 16.6
# ... with 3 more variables: b <dbl>, lstat <dbl>, medv <dbl>
Use the summarizeColumns() function from the mlr package to create numerical summaries of each column. Do you notice any strange columns?
mlr::summarizeColumns(boston_train)
name type na mean disp median mad
1 crim numeric 0 4.017157 10.6362176 0.239745 0.3043481
2 zn numeric 0 12.835000 24.4253196 0.000000 0.0000000
3 indus numeric 0 11.644100 7.5021709 9.900000 11.4011940
4 chas integer 0 0.090000 0.2876235 0.000000 0.0000000
5 nox numeric 0 0.568709 0.1260641 0.544000 0.1319514
6 rm numeric 0 6.346150 0.7002630 6.278000 0.4885167
7 age numeric 0 70.893000 27.4259748 82.200000 24.2405100
8 dis numeric 0 3.493716 2.0120118 2.800100 1.4820070
9 rad integer 0 9.670000 8.8740400 5.000000 2.9652000
10 tax integer 0 415.710000 173.7767809 364.000000 139.3644000
11 ptratio numeric 0 18.289000 2.2834448 18.600000 2.3721600
12 b numeric 0 362.635300 84.7712980 390.720000 9.1624680
13 lstat numeric 0 12.349100 7.1234023 10.585000 6.1083120
14 medv numeric 0 23.326000 9.6243758 21.800000 5.9304000
min max nlevs
1 0.01096 73.5341 0
2 0.00000 95.0000 0
3 0.46000 27.7400 0
4 0.00000 1.0000 0
5 0.38900 0.8710 0
6 4.13800 8.3980 0
7 7.80000 100.0000 0
8 1.17420 9.2203 0
9 1.00000 24.0000 0
10 188.00000 711.0000 0
11 12.60000 21.2000 0
12 16.45000 396.9000 0
13 2.47000 34.7700 0
14 5.00000 50.0000 0
# Nope they all look ok!
Can you think of a reason why you should remove any of the columns? If so, remove them using select()! There’s no need to include unnecessary data in a machine learning algorithm :)
Visualise the data using the ggpairs()
function from the GGally
package. Do you notice any interesting patterns?
We will use 10-fold cross validation to train the models. To set up the fitting process, we’ll need to use the trainControl
function. Look at the help menu for this function and look through the examples to see how it works.
Using the trainControl()
function, create an object called rep10_control
that will conduct 10-fold cross validation.
rep10_control <- caret::trainControl(method = "repeatedcv",
number = 10,
search = "grid") # Find optimal parameters
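To make the idea behind 10-fold cross validation concrete, here is a minimal base R sketch of the procedure done by hand. It is not how caret implements it internally; the mtcars data and the formula mpg ~ wt + hp are stand-ins chosen only so that the code runs without any of the practical's files.

```r
# Manual 10-fold cross validation of a linear model, base R only.
# mtcars and the formula mpg ~ wt + hp are illustrative assumptions.
set.seed(100)

k <- 10
n <- nrow(mtcars)

# Randomly assign each row to one of k folds of (nearly) equal size
fold_id <- sample(rep(seq_len(k), length.out = n))

fold_rmse <- numeric(k)

for (i in seq_len(k)) {
  train_rows <- mtcars[fold_id != i, ]  # k - 1 folds used for fitting
  test_rows  <- mtcars[fold_id == i, ]  # held-out fold used for evaluation

  fit  <- lm(mpg ~ wt + hp, data = train_rows)
  pred <- predict(fit, newdata = test_rows)

  fold_rmse[i] <- sqrt(mean((pred - test_rows$mpg)^2))
}

# The cross-validated error is the average error over the held-out folds
cv_rmse <- mean(fold_rmse)
cv_rmse
```

Every row is used for evaluation exactly once and for fitting k - 1 times, which is why the averaged error is a fairer estimate of out-of-sample performance than the error on the data the model was fit to.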
Models in caret are fit with the train() function. Look at the help menu for the train() function to see all of its lovely arguments and examples.
?train
Using train(), create an object called boston_lm_train where you predict a town’s median home value medv based on all other variables. In order to conduct linear regression, include the arguments method = 'lm' and trControl = rep10_control.
boston_lm_train <- caret::train(form = medv ~ .,
data = boston_train,
method = "lm",
trControl = rep10_control)
Explore your boston_lm_train object using the generic functions names(), summary(), and plot(). What is stored in this object?
boston_lm_train
Linear Regression
100 samples
13 predictor
No pre-processing
Resampling: Cross-Validated (10 fold, repeated 1 times)
Summary of sample sizes: 90, 90, 91, 89, 90, 91, ...
Resampling results:
RMSE Rsquared MAE
5.896511 0.667127 4.522078
Tuning parameter 'intercept' was held constant at a value of TRUE
lm(formula = .outcome ~ ., data = dat)
Min 1Q Median 3Q Max
-10.548 -2.999 -0.452 2.584 18.646
Estimate Std. Error t value Pr(>|t|)
(Intercept) 53.236754 11.969055 4.448 2.58e-05 ***
crim -0.150641 0.062897 -2.395 0.018790 *
zn 0.097746 0.036441 2.682 0.008767 **
indus -0.064170 0.139644 -0.460 0.647019
chas 0.871467 1.984693 0.439 0.661694
nox -18.843021 9.448875 -1.994 0.049297 *
rm 1.983625 1.006865 1.970 0.052045 .
age 0.039229 0.043431 0.903 0.368918
dis -2.214633 0.579946 -3.819 0.000253 ***
rad 0.308793 0.139984 2.206 0.030056 *
tax -0.004699 0.008022 -0.586 0.559580
ptratio -1.254161 0.332930 -3.767 0.000302 ***
b 0.012235 0.007149 1.712 0.090587 .
lstat -0.756288 0.118553 -6.379 8.61e-09 ***
Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
Residual standard error: 5.017 on 86 degrees of freedom
Multiple R-squared: 0.764, Adjusted R-squared: 0.7283
F-statistic: 21.41 on 13 and 86 DF, p-value: < 2.2e-16
Using the varImp() function, look at the variable importance of each predictor. Which predictors seem to be the most important? After you’ve run the function, try plotting the object with plot() to visualise the results!
caret::varImp(boston_lm_train)
lm variable importance
lstat 100.0000
dis 56.8933
ptratio 56.0239
zn 37.7630
crim 32.9271
rad 29.7434
nox 26.1794
rm 25.7735
b 21.4207
age 7.8137
tax 2.4688
indus 0.3439
chas 0.0000
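The importance scores above are consistent with taking the absolute t value of each predictor from the regression summary and rescaling it so that the largest becomes 100 and the smallest 0 (for example, dis has |t| = 3.819, which rescales to about 56.9). The base R sketch below reproduces that calculation by hand; the mtcars data and the chosen predictors are assumptions made purely for illustration.

```r
# Hand-rolled variable importance for a linear model:
# absolute t statistics rescaled to the range 0-100.
# mtcars and the predictors wt, hp, qsec are illustrative assumptions.
fit <- lm(mpg ~ wt + hp + qsec, data = mtcars)

# Absolute t value of every predictor (drop the intercept row)
t_abs <- abs(coef(summary(fit))[-1, "t value"])

# Rescale so the weakest predictor scores 0 and the strongest 100
importance <- 100 * (t_abs - min(t_abs)) / (max(t_abs) - min(t_abs))

sort(importance, decreasing = TRUE)
```

Because of the rescaling, a score of 0 does not mean a predictor is useless, only that it had the smallest absolute t value in this particular model.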
Load the boston_test.csv data into R as a new object called boston_test using read_csv().
boston_test <- read_csv("data/boston_test.csv")
Using the predict() function, predict the criterion values for the test dataset boston_test. In the object argument, use your boston_lm_train model; in the newdata argument, use your boston_test test data. Save the results as the vector boston_lm_predictions.
boston_lm_predictions <- predict(boston_lm_train,
newdata = boston_test)
# Combine predictions and criterion in one tibble
performance_data <- tibble(predictions = boston_lm_predictions,
criterion = boston_test$medv)
# Plot results
ggplot(data = performance_data,
aes(x = predictions, y = criterion)) +
geom_point() + # Add points
geom_abline(slope = 1, intercept = 0, col = "blue", size = 2) +
labs(title = "Regression prediction accuracy",
subtitle = "Blue line is perfect prediction!")
Now evaluate how accurate the predictions are. Start by looking at the help menu for the postResample() function to see its arguments.
?postResample
Use postResample() to calculate the model’s prediction performance on the test data. You should only specify the pred and obs arguments as pred = boston_lm_predictions and obs = boston_test$medv.
postResample(pred = boston_lm_predictions,
obs = boston_test$medv)
RMSE Rsquared MAE
5.1690046 0.6859913 3.8336613
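The three numbers reported for a regression problem can be computed by hand, which helps when interpreting them. The sketch below uses two short made-up vectors (not values from the Boston data) and the usual definitions: root mean squared error, squared correlation between predictions and observations, and mean absolute error.

```r
# Regression performance measures computed by hand in base R.
# pred and obs are made-up numbers used only for illustration.
pred <- c(21.5, 30.2, 18.9, 25.0, 14.3)
obs  <- c(20.0, 33.1, 17.5, 27.4, 15.0)

rmse <- sqrt(mean((pred - obs)^2))  # penalises large errors more heavily
rsq  <- cor(pred, obs)^2            # squared correlation, between 0 and 1
mae  <- mean(abs(pred - obs))       # average size of an error

c(RMSE = rmse, Rsquared = rsq, MAE = mae)
```

RMSE and MAE are in the units of the criterion (here, USD 1000s for medv), so lower is better; RMSE is never smaller than MAE, and the gap grows when a few predictions are far off.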
Now, instead of method = 'lm', we’ll use so-called ‘ridge regression’. Ridge regression is a variation of ‘standard’ linear regression that penalises large coefficients in order to avoid overfitting and thus make better predictions. How can you use ridge regression in caret? Try to answer this yourself by looking at the list of all available models on the caret help site: search for ‘ridge’ and find the method value.
# It's method = "ridge" :)
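To see what "a variation of standard linear regression" means in practice, here is a small base R sketch of the textbook closed-form ridge estimate, which adds a penalty lambda to the diagonal of X'X before solving. This illustrates the idea only and makes no claim about how caret fits method = "ridge" internally; the mtcars data, the chosen predictors, and lambda = 10 are all assumptions.

```r
# Textbook ridge regression in closed form, base R only.
# mtcars, the predictors, and lambda = 10 are illustrative assumptions.
X <- scale(as.matrix(mtcars[, c("wt", "hp", "disp")]))  # standardise predictors
y <- mtcars$mpg - mean(mtcars$mpg)                       # centre the outcome

# Solve (X'X + lambda * I) beta = X'y for beta
ridge_coef <- function(X, y, lambda) {
  p <- ncol(X)
  drop(solve(crossprod(X) + lambda * diag(p), crossprod(X, y)))
}

beta_ols   <- ridge_coef(X, y, lambda = 0)   # no penalty: ordinary least squares
beta_ridge <- ridge_coef(X, y, lambda = 10)  # penalised: coefficients shrink

rbind(ols = beta_ols, ridge = beta_ridge)

# The penalty pulls the coefficient vector towards zero
sum(beta_ridge^2) < sum(beta_ols^2)
```

With lambda = 0 the estimate is ordinary least squares; as lambda grows the coefficients shrink, trading a little bias for lower variance, which is where the protection against overfitting comes from.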
Repeat the training, prediction and evaluation steps above using ridge regression. Name your new objects with _ridge_ instead of _lm_ so you know which object is which! As you are creating your new objects, make sure to stop and explore them (e.g., with summary(), plot(), print()) to see how they look different from the previous objects.
boston_ridge_train <- caret::train(form = medv ~ .,
data = boston_train,
method = "ridge",
trControl = rep10_control)
# VarImp() doesn't work for ridge regression
# caret::varImp(boston_ridge_train)
boston_ridge_predictions <- predict(boston_ridge_train,
newdata = boston_test)
# Combine predictions and criterion in one tibble
performance_data <- tibble(predictions = boston_ridge_predictions,
criterion = boston_test$medv)
# Plot results
ggplot(data = performance_data,
aes(x = predictions, y = criterion)) +
geom_point() + # Add points
geom_abline(slope = 1, intercept = 0, col = "blue", size = 2) +
labs(title = "Ridge Regression prediction accuracy",
subtitle = "Blue line is perfect prediction!")
postResample(pred = boston_ridge_predictions,
obs = boston_test$medv)
RMSE Rsquared MAE
4.9313886 0.7125253 3.5123396
# Looks like ridge regression did better, with an RMSE of 4.93 compared to 5.17 for standard regression
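Now repeat the same steps once more, this time using random forests (method = "rf"), and compare the results on the test data.
boston_rf_train <- caret::train(form = medv ~ .,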
data = boston_train,
method = "rf",
trControl = rep10_control)
# VarImp() doesn't work for rf regression
# caret::varImp(boston_rf_train)
boston_rf_predictions <- predict(boston_rf_train,
newdata = boston_test)
# Combine predictions and criterion in one tibble
performance_data <- tibble(predictions = boston_rf_predictions,
criterion = boston_test$medv)
# Plot results
ggplot(data = performance_data,
aes(x = predictions, y = criterion)) +
geom_point() + # Add points
geom_abline(slope = 1, intercept = 0, col = "blue", size = 2) +
labs(title = "Random Forests prediction accuracy",
subtitle = "Blue line is perfect prediction!")
postResample(pred = boston_rf_predictions,
obs = boston_test$medv)
RMSE Rsquared MAE
4.2500196 0.7846074 2.9331980
Now it’s time to try a classification problem. We’ll do this with the heartdisease data. The heartdisease dataset contains data from 303 patients suspected of having heart disease. The objective is to predict the diagnosis column, which indicates whether a patient has heart disease or not.
Here is a description of each of the variables in the dataset
Column | Description |
age | Age |
sex | Sex, 1 = male, 0 = female |
cp | Chest pain type: ta = typical angina, aa = atypical angina, np = non-anginal pain, a = asymptomatic |
trestbps | Resting blood pressure (in mm Hg on admission to the hospital) |
chol | Serum cholesterol in mg/dl |
fbs | Fasting blood sugar > 120 mg/dl: 1 = true, 0 = false |
restecg | Resting electrocardiographic results. “normal” = normal, “abnormal” = having ST-T wave abnormality (T wave inversions and/or ST elevation or depression of > 0.05 mV), “hypertrophy” = showing probable or definite left ventricular hypertrophy by Estes’ criteria. |
thalach | Maximum heart rate achieved |
exang | Exercise induced angina: 1 = yes, 0 = no |
oldpeak | ST depression induced by exercise relative to rest |
slope | The slope of the peak exercise ST segment. |
ca | Number of major vessels (0-3) colored by fluoroscopy |
thal | “normal” = normal, “fd” = fixed defect, “rd” = reversible defect |
diagnosis | 1 = Heart disease, 0 = No Heart disease |
Load the heartdisease_train.csv data into R as a new object with the name heartdisease_train using read_csv().
heartdisease_train <- read_csv("data/heartdisease_train.csv")
Convert the diagnosis column to a factor. This will be necessary for some machine learning models to work!
# Convert diagnosis to a factor
# This is important for some machine learning models to do
# classification analyses
heartdisease_train <- heartdisease_train %>%
mutate(diagnosis = factor(diagnosis)) # Convert diagnosis to a factor
# A tibble: 100 x 14
age sex cp trestbps chol fbs restecg thalach exang oldpeak
<int> <int> <chr> <int> <int> <int> <chr> <int> <int> <dbl>
1 54 0 np 160 201 0 normal 163 0 0.
2 54 0 np 135 304 1 normal 170 0 0.
3 34 0 aa 118 210 0 normal 192 0 0.700
4 57 1 a 152 274 0 normal 88 1 1.20
5 60 0 a 150 258 0 hypertrop… 157 0 2.60
6 56 1 a 125 249 1 hypertrop… 144 1 1.20
7 63 1 a 130 254 0 hypertrop… 147 0 1.40
8 64 1 np 125 309 0 normal 131 1 1.80
9 58 0 a 170 225 1 hypertrop… 146 1 2.80
10 63 1 ta 145 233 1 hypertrop… 150 0 2.30
# ... with 90 more rows, and 4 more variables: slope <chr>, ca <int>,
# thal <chr>, diagnosis <fct>
name type na mean disp median mad min max
1 age integer 0 54.340 9.0299390 56.50 9.63690 34 77.0
2 sex integer 0 0.720 0.4512609 1.00 0.00000 0 1.0
3 cp character 0 NA 0.4900000 NA NA 10 51.0
4 trestbps integer 0 133.100 16.2552245 130.00 14.82600 100 178.0
5 chol integer 0 235.840 42.1099379 233.00 42.99540 149 340.0
6 fbs integer 0 0.150 0.3588703 0.00 0.00000 0 1.0
7 restecg character 0 NA 0.4800000 NA NA 1 52.0
8 thalach integer 0 149.250 22.1880735 151.50 25.94550 88 194.0
9 exang integer 0 0.410 0.4943111 0.00 0.00000 0 1.0
10 oldpeak numeric 0 1.064 1.1966722 0.65 0.96369 0 4.2
11 slope character 0 NA 0.5200000 NA NA 7 48.0
12 ca integer 0 0.760 0.9547330 0.00 0.00000 0 3.0
13 thal character 0 NA 0.4500000 NA NA 8 55.0
14 diagnosis factor 0 NA 0.4500000 NA NA 45 55.0
nlevs
1 0
2 0
3 4
4 0
5 0
6 0
7 3
8 0
9 0
10 0
11 3
12 0
13 3
14 2
heartdisease_test <- read_csv("data/heartdisease_test.csv") %>%
  mutate(diagnosis = factor(diagnosis)) # Convert diagnosis to a factor
heartdisease_rpart_train <- caret::train(form = diagnosis ~ .,
data = heartdisease_train,
method = "rpart",
trControl = rep10_control)
heartdisease_rpart_predictions <- predict(heartdisease_rpart_train,
newdata = heartdisease_test)
When you load the heartdisease_test data, make sure to convert the diagnosis column to a factor like you did for the heartdisease_train data.
# -----------------------------------------------------------
# Visualise the relationship between two nominal variables
# -----------------------------------------------------------
# Combine predictions and criterion in one tibble
performance_data <- tibble(predictions = heartdisease_rpart_predictions,
criterion = heartdisease_test$diagnosis)
# Plot results
ggplot(data = performance_data,
aes(x = predictions, ..count..)) +
geom_bar(aes(fill = criterion), position = "dodge") +
labs(title = "CART prediction accuracy")
postResample(pred = heartdisease_rpart_predictions,
obs = factor(heartdisease_test$diagnosis))
heartdisease_rf_train <- caret::train(form = diagnosis ~ .,
data = heartdisease_train,
method = "rf",
trControl = rep10_control)
heartdisease_rf_predictions <- predict(heartdisease_rf_train,
newdata = heartdisease_test)
# Combine predictions and criterion in one tibble
performance_data <- tibble(predictions = heartdisease_rf_predictions,
criterion = heartdisease_test$diagnosis)
# Plot results
ggplot(data = performance_data,
aes(x = predictions, ..count..)) +
geom_bar(aes(fill = criterion), position = "dodge") +
labs(title = "rf prediction accuracy")
postResample(pred = heartdisease_rf_predictions,
obs = factor(heartdisease_test$diagnosis))
Accuracy Kappa
0.7980296 0.5943364
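For classification, the two numbers reported are accuracy (the proportion of correct predictions) and Cohen's kappa (agreement beyond what would be expected by chance). The base R sketch below computes both from a confusion matrix; the ten predictions and observations are made up for illustration and are not taken from the heartdisease data.

```r
# Accuracy and Cohen's kappa computed by hand from a confusion matrix.
# pred and obs are made-up labels used only for illustration.
pred <- factor(c(1, 0, 1, 1, 0, 0, 1, 0, 1, 1), levels = c(0, 1))
obs  <- factor(c(1, 0, 0, 1, 0, 1, 1, 0, 1, 0), levels = c(0, 1))

# Rows are predictions, columns are the true values
conf_mat <- table(predicted = pred, observed = obs)
n <- sum(conf_mat)

# Observed agreement: share of cases on the diagonal
accuracy <- sum(diag(conf_mat)) / n

# Agreement expected by chance, from the row and column totals
expected <- sum(rowSums(conf_mat) * colSums(conf_mat)) / n^2

# Kappa rescales accuracy so that 0 = chance level and 1 = perfect
cohen_kappa <- (accuracy - expected) / (1 - expected)

conf_mat
c(Accuracy = accuracy, Kappa = cohen_kappa)
```

Kappa is the more informative number when the classes are unbalanced, because a model that always predicts the majority class can reach a high accuracy while its kappa stays at 0.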
caret supports many more models than the ones you’ve used here (again, check out the list of available models on the caret help site to see them all). Look for the craziest looking models you can find, and see how well they do on the two datasets in this practical. Can you find one that does much better than regression, random forests, or decision trees?
Max Kuhn, the author of caret, has written a fantastic overview of the package. If you like the caret package as much as we do, be sure to go through it in detail.
Max Kuhn is also the co-author of a fantastic book on machine learning called Applied Predictive Modeling.