Feedback should be sent to goran.milovanovic@datakolektiv.com. These notebooks accompany the Intro to Data Science: Non-Technical Background course 2020/21.
We will first consider cross-validation in classification problems and contrast it with previously introduced approaches to model selection. Then we begin to go beyond Linear Models: we will introduce Decision Trees for classification and regression problems. In this session we go for an intuitive and practical approach to Decision Trees in R; in the next session we will introduce the basic elements of Information Theory and dig deeper into the theory of Decision Trees and even more powerful Random Forests.
install.packages('rpart')
install.packages('rpart.plot')
Grab the HR_comma_sep.csv dataset from Kaggle and place it in your _data directory for this session.
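A quick sanity check that the file is in place (a hypothetical guard, not part of the original workflow; adjust the path if your layout differs):
# - stop early if the dataset is missing
if (!file.exists(paste0(getwd(), "/_data/HR_comma_sep.csv"))) {
  stop("HR_comma_sep.csv not found: download it from Kaggle and place it in _data/.")
}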
dataDir <- paste0(getwd(), "/_data/")
library(tidyverse)
Registered S3 methods overwritten by 'dbplyr':
method from
print.tbl_lazy
print.tbl_sql
── Attaching packages ────────────────────────────────────────────────────────────────────────────── tidyverse 1.3.1 ──
✔ ggplot2 3.3.5 ✔ purrr 0.3.4
✔ tibble 3.1.6 ✔ dplyr 1.0.8
✔ tidyr 1.2.0 ✔ stringr 1.4.0
✔ readr 2.0.2 ✔ forcats 0.5.1
── Conflicts ───────────────────────────────────────────────────────────────────────────────── tidyverse_conflicts() ──
✖ dplyr::filter() masks stats::filter()
✖ dplyr::lag() masks stats::lag()
library(data.table)
Registered S3 method overwritten by 'data.table':
method from
print.data.table
data.table 1.14.2 using 1 threads (see ?getDTthreads). Latest news: r-datatable.com
Attaching package: ‘data.table’
The following objects are masked from ‘package:dplyr’:
between, first, last
The following object is masked from ‘package:purrr’:
transpose
library(rpart)
library(rpart.plot)
Consider the HR_comma_sep.csv dataset:
dataSet <- read.csv(paste0('_data/', 'HR_comma_sep.csv'),
header = T,
check.names = F,
stringsAsFactors = F)
head(dataSet)
table(dataSet$left)
0 1
11428 3571
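So roughly 24% of the employees in the dataset have left. This base rate is worth keeping in mind: a trivial classifier that always predicts left == 0 would already be about 76% accurate. The same table as proportions:
# - class proportions (the majority-class baseline accuracy)
prop.table(table(dataSet$left))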
The task is to predict the value of left (whether the employee has left the company or not) from a set of predictors encompassing the following:
glimpse(dataSet)
Rows: 14,999
Columns: 10
$ satisfaction_level <dbl> 0.38, 0.80, 0.11, 0.72, 0.37, 0.41, 0.10, 0.92, 0.89, 0.42, 0.45, 0.11, 0.84, 0.41, 0.3…
$ last_evaluation <dbl> 0.53, 0.86, 0.88, 0.87, 0.52, 0.50, 0.77, 0.85, 1.00, 0.53, 0.54, 0.81, 0.92, 0.55, 0.5…
$ number_project <int> 2, 5, 7, 5, 2, 2, 6, 5, 5, 2, 2, 6, 4, 2, 2, 2, 2, 4, 2, 5, 6, 2, 6, 2, 2, 5, 4, 2, 2, …
$ average_montly_hours <int> 157, 262, 272, 223, 159, 153, 247, 259, 224, 142, 135, 305, 234, 148, 137, 143, 160, 25…
$ time_spend_company <int> 3, 6, 4, 5, 3, 3, 4, 5, 5, 3, 3, 4, 5, 3, 3, 3, 3, 6, 3, 5, 4, 3, 4, 3, 3, 5, 5, 3, 3, …
$ Work_accident <int> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
$ left <int> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, …
$ promotion_last_5years <int> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, …
$ sales <chr> "sales", "sales", "sales", "sales", "sales", "sales", "sales", "sales", "sales", "sales…
$ salary <chr> "low", "medium", "medium", "low", "low", "low", "low", "low", "low", "low", "low", "low…
Let’s formulate a Binomial Logistic Regression model to try to predict left from satisfaction_level, last_evaluation, sales, and salary:
# - setups
dataSet$left <- factor(dataSet$left,
levels = c(0, 1))
dataSet$salary <- factor(dataSet$salary)
dataSet$salary <- relevel(dataSet$salary,
ref = 'high')
dataSet$sales <- factor(dataSet$sales)
dataSet$sales <- relevel(dataSet$sales,
ref = 'RandD')
# - model
blr_model1 <- glm(left ~ satisfaction_level + last_evaluation + sales + salary,
data = dataSet,
family = "binomial")
modelsummary <- summary(blr_model1)
print(modelsummary)
Call:
glm(formula = left ~ satisfaction_level + last_evaluation + sales +
salary, family = "binomial", data = dataSet)
Deviance Residuals:
Min 1Q Median 3Q Max
-1.7093 -0.6970 -0.4741 -0.1928 2.7917
Coefficients:
Estimate Std. Error z value Pr(>|z|)
(Intercept) -1.46022 0.18380 -7.945 1.95e-15 ***
satisfaction_level -3.87156 0.08923 -43.387 < 2e-16 ***
last_evaluation 0.53218 0.12208 4.359 1.31e-05 ***
salesaccounting 0.67072 0.14023 4.783 1.73e-06 ***
saleshr 0.88149 0.13928 6.329 2.47e-10 ***
salesIT 0.49930 0.13124 3.804 0.000142 ***
salesmanagement 0.26267 0.16531 1.589 0.112068
salesmarketing 0.65029 0.13890 4.682 2.84e-06 ***
salesproduct_mng 0.51352 0.13853 3.707 0.000210 ***
salessales 0.62675 0.11422 5.487 4.08e-08 ***
salessupport 0.66090 0.12001 5.507 3.65e-08 ***
salestechnical 0.67753 0.11769 5.757 8.57e-09 ***
salarylow 1.77870 0.12287 14.476 < 2e-16 ***
salarymedium 1.28847 0.12400 10.391 < 2e-16 ***
---
Signif. codes: 0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1
(Dispersion parameter for binomial family taken to be 1)
Null deviance: 16465 on 14998 degrees of freedom
Residual deviance: 13733 on 14985 degrees of freedom
AIC: 13761
Number of Fisher Scoring iterations: 5
And take a look at the exponentiated regression coefficients, i.e. the odds ratios:
exp(coefficients(blr_model1))
(Intercept) satisfaction_level last_evaluation salesaccounting saleshr salesIT
0.23218577 0.02082584 1.70263484 1.95564264 2.41448902 1.64756456
salesmanagement salesmarketing salesproduct_mng salessales salessupport salestechnical
1.30039267 1.91608809 1.67116051 1.87152647 1.93652695 1.96900961
salarylow salarymedium
5.92213412 3.62722234
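The exponentiated coefficients are odds ratios: salarylow ≈ 5.92, for example, means that, with the other predictors held constant, the odds of leaving are about 5.9 times higher for low-salary employees than for the high-salary reference level. If you also want Wald confidence intervals on the odds-ratio scale, one way (a quick sketch):
# - odds ratios with 95% Wald confidence intervals
round(cbind(OR = exp(coef(blr_model1)),
            exp(confint.default(blr_model1))), 5)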
The Akaike Information Criterion is:
blr_model1$aic
[1] 13761.14
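For a Bernoulli GLM the saturated model has zero log-likelihood, so the residual deviance equals minus twice the maximized log-likelihood, and the AIC reduces to deviance + 2k, with k the number of estimated coefficients (here k = 14). A quick check:
# - AIC = residual deviance + 2 * number of estimated coefficients
blr_model1$deviance + 2 * length(coef(blr_model1))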
Now consider a model encompassing all predictors from HR_comma_sep.csv:
# - setups
dataSet$Work_accident <- factor(dataSet$Work_accident,
levels = c(0, 1))
dataSet$promotion_last_5years <- factor(dataSet$promotion_last_5years,
levels = c(0, 1))
# - model
blr_model2 <- glm(left ~ .,
data = dataSet,
family = "binomial")
modelsummary <- summary(blr_model2)
print(modelsummary)
Call:
glm(formula = left ~ ., family = "binomial", data = dataSet)
Deviance Residuals:
Min 1Q Median 3Q Max
-2.2248 -0.6645 -0.4026 -0.1177 3.0688
Coefficients:
Estimate Std. Error z value Pr(>|z|)
(Intercept) -2.0586521 0.2035965 -10.111 < 2e-16 ***
satisfaction_level -4.1356889 0.0980538 -42.178 < 2e-16 ***
last_evaluation 0.7309032 0.1491787 4.900 9.61e-07 ***
number_project -0.3150787 0.0213248 -14.775 < 2e-16 ***
average_montly_hours 0.0044603 0.0005161 8.643 < 2e-16 ***
time_spend_company 0.2677537 0.0155736 17.193 < 2e-16 ***
Work_accident1 -1.5298283 0.0895473 -17.084 < 2e-16 ***
promotion_last_5years1 -1.4301364 0.2574958 -5.554 2.79e-08 ***
salesaccounting 0.5823659 0.1448848 4.020 5.83e-05 ***
saleshr 0.8147438 0.1439439 5.660 1.51e-08 ***
salesIT 0.4016480 0.1355936 2.962 0.00306 **
salesmanagement 0.1339423 0.1704829 0.786 0.43206
salesmarketing 0.5702777 0.1445326 3.946 7.96e-05 ***
salesproduct_mng 0.4291129 0.1428822 3.003 0.00267 **
salessales 0.5435800 0.1181590 4.600 4.22e-06 ***
salessupport 0.6323910 0.1241337 5.094 3.50e-07 ***
salestechnical 0.6525123 0.1217267 5.360 8.30e-08 ***
salarylow 1.9440627 0.1286272 15.114 < 2e-16 ***
salarymedium 1.4132244 0.1293534 10.925 < 2e-16 ***
---
Signif. codes: 0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1
(Dispersion parameter for binomial family taken to be 1)
Null deviance: 16465 on 14998 degrees of freedom
Residual deviance: 12850 on 14980 degrees of freedom
AIC: 12888
Number of Fisher Scoring iterations: 5
exp(coefficients(blr_model2))
(Intercept) satisfaction_level last_evaluation number_project average_montly_hours
0.12762588 0.01599164 2.07695560 0.72973146 1.00447026
time_spend_company Work_accident1 promotion_last_5years1 salesaccounting saleshr
1.30702513 0.21657284 0.23927628 1.79026897 2.25859684
salesIT salesmanagement salesmarketing salesproduct_mng salessales
1.49428520 1.14332679 1.76875818 1.53589447 1.72216110
salessupport salestechnical salarylow salarymedium
1.88210526 1.92035920 6.98708012 4.10918362
The Akaike Information Criterion is:
blr_model2$aic
[1] 12887.9
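The full model thus has a substantially lower AIC (12887.9 vs 13761.14) and is preferred by this criterion. stats::AIC() performs the comparison directly:
# - compare the two models by AIC
AIC(blr_model1, blr_model2)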
Let’s cross-validate our blr_model1 and blr_model2 now. We will perform the k-fold CV in the following way: for each fold i in folds,
- estimate the model on the remaining dataSet[-i, ] folds taken together,
- predict fold i from the fitted model,
- compute the ROC statistics (accuracy, hit rate, and FA rate) for fold i.
Here we go: define folds first.
# - assign each observation to one of 4 folds at random
# - (call set.seed() beforehand for reproducible folds)
dataSet$fold <- sample(1:4, size = dim(dataSet)[1], replace = T)
table(dataSet$fold)
1 2 3 4
3647 3762 3790 3800
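Note that sample() assigns folds completely at random, so the proportion of left == 1 may vary somewhat from fold to fold. If you wanted each fold to preserve the class balance, a stratified assignment is a small change (a sketch only; it is not used below):
# - sketch: stratified fold assignment, kept in a separate vector
# - so that it does not end up as a predictor in left ~ . later
strat_fold <- numeric(nrow(dataSet))
for (cl in levels(dataSet$left)) {
  idx <- which(dataSet$left == cl)
  strat_fold[idx] <- sample(rep_len(1:4, length(idx)))
}
table(strat_fold, dataSet$left)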
First for the narrower model:
cv1 <- lapply(1:4, function(x) {
# - test and train datasets
test <- dataSet %>%
dplyr::filter(fold == x) %>%
dplyr::select(-fold)
train <- dataSet %>%
dplyr::filter(fold != x) %>%
dplyr::select(-fold)
# - model on the training dataset
blrModel <- glm(left ~ satisfaction_level + last_evaluation + sales + salary,
data = train,
family = "binomial")
# - predict on the test dataset
predictions <- predict(blrModel,
newdata = test,
type = "response")
predictions <- ifelse(predictions > .5, 1, 0)
# - ROC analysis
acc <- sum(test$left == predictions)
acc <- acc/dim(test)[1]
hit <- sum(test$left == 1 & predictions == 1)
hit <- hit/sum(test$left == 1)
fa <- sum(test$left == 0 & predictions == 1)
fa <- fa/sum(test$left == 0)
return(data.frame(acc, hit, fa))
})
cv1 <- rbindlist(cv1)
cv1$fold <- 1:4
cv1 <- tidyr::pivot_longer(cv1,
cols = -fold,
names_to = 'measure',
values_to = 'value')
cv1$model <- 1
print(cv1)
Now for the full model:
cv2 <- lapply(1:4, function(x) {
test <- dataSet %>%
dplyr::filter(fold == x) %>%
dplyr::select(-fold)
train <- dataSet %>%
dplyr::filter(fold != x) %>%
dplyr::select(-fold)
blrModel <- glm(left ~ .,
data = train,
family = "binomial")
predictions <- predict(blrModel,
newdata = test,
type = "response")
predictions <- ifelse(predictions > .5, 1, 0)
acc <- sum(test$left == predictions)
acc <- acc/dim(test)[1]
hit <- sum(test$left == 1 & predictions == 1)
hit <- hit/sum(test$left == 1)
fa <- sum(test$left == 0 & predictions == 1)
fa <- fa/sum(test$left == 0)
return(data.frame(acc, hit, fa))
})
cv2 <- rbindlist(cv2)
cv2$fold <- 1:4
cv2 <- tidyr::pivot_longer(cv2,
cols = -fold,
names_to = 'measure',
values_to = 'value')
cv2$model <- 2
print(cv2)
Compare:
modelSelection <- rbind(cv1, cv2)
modelSelection$model <- ifelse(modelSelection$model == 1,
"Partial", "Full")
modelSelection$model <- factor(modelSelection$model)
ggplot(data = modelSelection,
aes(x = fold,
y = value,
group = model,
color = model,
fill = model)) +
geom_path(size = .5) +
geom_point(size = 3) +
scale_color_manual(values = c('darkred', 'darkorange')) +
ylim(0, 1) +
facet_wrap(~measure) +
theme_bw() +
theme(panel.border = element_blank()) +
theme(legend.text = element_text(size = 20)) +
theme(legend.title = element_text(size = 20)) +
theme(legend.position = "top") +
theme(strip.background = element_blank()) +
theme(strip.text = element_text(size = 20)) +
theme(axis.title.x = element_text(size = 18)) +
theme(axis.title.y = element_text(size = 18)) +
theme(axis.text.x = element_text(size = 17)) +
theme(axis.text.y = element_text(size = 17))
The average ROC statistics from the k-fold CV for both models:
modelSelection %>%
dplyr::select(-fold) %>%
dplyr::group_by(model, measure) %>%
dplyr::summarise(mean = round(mean(value), 5)) %>%
tidyr::pivot_wider(id_cols = model,
names_from = 'measure',
values_from = 'mean')
`summarise()` has grouped output by 'model'. You can override using the `.groups` argument.
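The three statistics above are plain confusion-matrix summaries, and we will recompute them several more times below. If you wanted to avoid the repetition, they factor out naturally into a small helper (roc_stats is a hypothetical name, not used in the chunks that follow):
# - hypothetical helper: accuracy, hit rate (TPR), and FA rate (FPR)
roc_stats <- function(observed, predicted) {
  data.frame(
    acc = mean(observed == predicted),
    hit = sum(observed == 1 & predicted == 1) / sum(observed == 1),
    fa  = sum(observed == 0 & predicted == 1) / sum(observed == 0)
  )
}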
Suppose we set the decision threshold to p = .2. First for the narrow model:
cv1 <- lapply(1:4, function(x) {
# - test and train datasets
test <- dataSet %>%
dplyr::filter(fold == x) %>%
dplyr::select(-fold)
train <- dataSet %>%
dplyr::filter(fold != x) %>%
dplyr::select(-fold)
# - model on the training dataset
blrModel <- glm(left ~ satisfaction_level + last_evaluation + sales + salary,
data = train,
family = "binomial")
# - predict on the test dataset
predictions <- predict(blrModel,
newdata = test,
type = "response")
predictions <- ifelse(predictions > .2, 1, 0)
# - ROC analysis
acc <- sum(test$left == predictions)
acc <- acc/dim(test)[1]
hit <- sum(test$left == 1 & predictions == 1)
hit <- hit/sum(test$left == 1)
fa <- sum(test$left == 0 & predictions == 1)
fa <- fa/sum(test$left == 0)
return(data.frame(acc, hit, fa))
})
cv1 <- rbindlist(cv1)
cv1$fold <- 1:4
cv1 <- tidyr::pivot_longer(cv1,
cols = -fold,
names_to = 'measure',
values_to = 'value')
cv1$model <- 1
Now for the full model, with the decision threshold at p = .2:
cv2 <- lapply(1:4, function(x) {
test <- dataSet %>%
dplyr::filter(fold == x) %>%
dplyr::select(-fold)
train <- dataSet %>%
dplyr::filter(fold != x) %>%
dplyr::select(-fold)
blrModel <- glm(left ~ .,
data = train,
family = "binomial")
predictions <- predict(blrModel,
newdata = test,
type = "response")
predictions <- ifelse(predictions > .2, 1, 0)
acc <- sum(test$left == predictions)
acc <- acc/dim(test)[1]
hit <- sum(test$left == 1 & predictions == 1)
hit <- hit/sum(test$left == 1)
fa <- sum(test$left == 0 & predictions == 1)
fa <- fa/sum(test$left == 0)
return(data.frame(acc, hit, fa))
})
cv2 <- rbindlist(cv2)
cv2$fold <- 1:4
cv2 <- tidyr::pivot_longer(cv2,
cols = -fold,
names_to = 'measure',
values_to = 'value')
cv2$model <- 2
Compare:
modelSelection <- rbind(cv1, cv2)
modelSelection$model <- ifelse(modelSelection$model == 1,
"Partial", "Full")
modelSelection$model <- factor(modelSelection$model)
ggplot(data = modelSelection,
aes(x = fold,
y = value,
group = model,
color = model,
fill = model)) +
geom_path(size = .5) +
geom_point(size = 3) +
scale_color_manual(values = c('darkred', 'darkorange')) +
ylim(0, 1) +
facet_wrap(~measure) +
theme_bw() +
theme(panel.border = element_blank()) +
theme(legend.text = element_text(size = 20)) +
theme(legend.title = element_text(size = 20)) +
theme(legend.position = "top") +
theme(strip.background = element_blank()) +
theme(strip.text = element_text(size = 20)) +
theme(axis.title.x = element_text(size = 18)) +
theme(axis.title.y = element_text(size = 18)) +
theme(axis.text.x = element_text(size = 17)) +
theme(axis.text.y = element_text(size = 17))
Now across a range of decision criteria, dec_criterion <- seq(.01, .99, .01), first for the initial model with four predictors:
cv1 <- lapply(1:4, function(x) {
# - test and train datasets
test <- dataSet %>%
dplyr::filter(fold == x) %>%
dplyr::select(-fold)
train <- dataSet %>%
dplyr::filter(fold != x) %>%
dplyr::select(-fold)
# - model on the training dataset
blrModel <- glm(left ~ satisfaction_level + last_evaluation + sales + salary,
data = train,
family = "binomial")
# - predict on the test dataset
predictions <- predict(blrModel,
newdata = test,
type = "response")
dec_criterion <- seq(.01, .99, .01)
predictions <- lapply(dec_criterion, function(y) {
return(
ifelse(predictions > y, 1, 0)
)
})
predictions <- t(Reduce(rbind, predictions))
roc <- apply(predictions, 2, function(y) {
# - ROC analysis
acc <- sum(test$left == y)
acc <- acc/dim(test)[1]
hit <- sum(test$left == 1 & y == 1)
hit <- hit/sum(test$left == 1)
fa <- sum(test$left == 0 & y == 1)
fa <- fa/sum(test$left == 0)
return(data.frame(acc, hit, fa))
})
roc <- rbindlist(roc)
roc$dec_criterion <- dec_criterion
roc$fold <- x
return(roc)
})
cv1 <- rbindlist(cv1)
cv1 <- cv1 %>%
dplyr::group_by(dec_criterion) %>%
dplyr::summarise(acc = mean(acc),
hit = mean(hit),
fa = mean(fa))
cv1 <- tidyr::pivot_longer(cv1,
cols = -dec_criterion,
names_to = 'measure',
values_to = 'value')
cv1$model <- 1
And for the full model, across the same range dec_criterion <- seq(.01, .99, .01):
cv2 <- lapply(1:4, function(x) {
# - test and train datasets
test <- dataSet %>%
dplyr::filter(fold == x) %>%
dplyr::select(-fold)
train <- dataSet %>%
dplyr::filter(fold != x) %>%
dplyr::select(-fold)
# - model on the training dataset
blrModel <- glm(left ~ .,
data = train,
family = "binomial")
# - predict on the test dataset
predictions <- predict(blrModel,
newdata = test,
type = "response")
dec_criterion <- seq(.01, .99, .01)
predictions <- lapply(dec_criterion, function(y) {
return(
ifelse(predictions > y, 1, 0)
)
})
predictions <- t(Reduce(rbind, predictions))
roc <- apply(predictions, 2, function(y) {
# - ROC analysis
acc <- sum(test$left == y)
acc <- acc/dim(test)[1]
hit <- sum(test$left == 1 & y == 1)
hit <- hit/sum(test$left == 1)
fa <- sum(test$left == 0 & y == 1)
fa <- fa/sum(test$left == 0)
return(data.frame(acc, hit, fa))
})
roc <- rbindlist(roc)
roc$dec_criterion <- dec_criterion
roc$fold <- x
return(roc)
})
cv2 <- rbindlist(cv2)
cv2 <- cv2 %>%
dplyr::group_by(dec_criterion) %>%
dplyr::summarise(acc = mean(acc),
hit = mean(hit),
fa = mean(fa))
cv2 <- tidyr::pivot_longer(cv2,
cols = -dec_criterion,
names_to = 'measure',
values_to = 'value')
cv2$model <- 2
Compare ROC curves:
ROC_results <- rbind(cv1, cv2)
ROC_results$model <- ifelse(ROC_results$model == 1,
"Partial", "Full")
ROC_results <- ROC_results %>%
pivot_wider(id_cols = c('dec_criterion', 'model'),
names_from = measure,
values_from = value)
ROC_results$model <- factor(ROC_results$model)
ggplot(data = ROC_results,
aes(x = fa,
y = hit,
group = model,
color = model,
fill = model)) +
ylab("Hit Rate (TPR)") +
xlab("FA Rate (FPR)") +
geom_point(size = 1) + geom_path(size = .1) +
geom_abline(intercept = 0, slope = 1, size = .5) +
ggtitle("ROC analysis for the Binomial Regression Model") +
theme_bw() +
theme(panel.border = element_blank()) +
theme(plot.title = element_text(hjust = .5, size = 20)) +
theme(legend.text = element_text(size = 20)) +
theme(legend.title = element_text(size = 20)) +
theme(legend.position = "top") +
theme(axis.title.x = element_text(size = 18)) +
theme(axis.title.y = element_text(size = 18)) +
theme(axis.text.x = element_text(size = 17)) +
theme(axis.text.y = element_text(size = 17))
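As a side note, the area under each empirical ROC curve (AUC) can be approximated from these points by the trapezoidal rule; a minimal sketch, assuming ROC_results as constructed above (the endpoints (0, 0) and (1, 1) are not included, so this slightly underestimates the true AUC):
# - approximate AUC per model via the trapezoidal rule
ROC_results %>%
  dplyr::group_by(model) %>%
  dplyr::arrange(fa, .by_group = TRUE) %>%
  dplyr::summarise(AUC = sum(diff(fa) * (head(hit, -1) + tail(hit, -1)) / 2))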
What is a Decision Tree classifier? Let’s introduce the Decision Tree by an example before diving into theory in the next session. We will use the HR_comma_sep.csv dataset again:
# - load HR_comma_sep.csv again
dataSet <- read.csv(paste0('_data/', 'HR_comma_sep.csv'),
header = T,
check.names = F,
stringsAsFactors = F)
Let’s split dataSet into training and test subsets:
# - Test and Train data:
ix <- rbinom(dim(dataSet)[1], 1, .5)
table(ix)/sum(table(ix))
ix
0 1
0.4942329 0.5057671
train <- dataSet[ix == 1, ]
test <- dataSet[ix == 0, ]
Train one Decision Tree on train:
# - Base Model
classTree <- rpart(left ~ .,
data = train,
method = "class")
Visualize the model with prp():
prp(classTree,
cex = .8)
Decision Trees can easily overfit because of the intrinsic complexity of the model. Pruning is one of the methods to prevent a Decision Tree from overfitting: we prune the tree by relying on the complexity parameter (CP) to discard branches that were grown to fit potentially idiosyncratic information present in the data. CP controls tree growth during training: a split that does not decrease the overall lack of fit by at least a factor of CP is not attempted.
The CP parameter is tied to an internal cross-validation procedure performed by {rpart} during the training of a Decision Tree model (to be explained in our live session).
# - Base Model
classTree <- rpart(left ~ .,
data = train,
method = "class",
control = rpart.control(cp = 0))
# - Inspect model:
prp(classTree,
cex = .8)
# - Examine the complexity plot
cptable <- as.data.frame(classTree$cptable)
print(cptable)
plotcp(classTree)
The CP value with the least cross-validated error (xerror) is the optimal one:
cptable[which.min(cptable$xerror), ]
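An alternative to simply minimizing xerror is the widely used 1-SE rule: take the largest CP (i.e. the smallest tree) whose xerror lies within one standard error (xstd) of the minimum. A sketch, using the cptable computed above:
# - 1-SE rule: smallest tree within one xstd of the minimal xerror
minRow <- which.min(cptable$xerror)
cp_1se <- max(cptable$CP[cptable$xerror <= cptable$xerror[minRow] + cptable$xstd[minRow]])
print(cp_1se)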
ROC analysis for the base model:
# - Base Model Accuracy
test$pred <- predict(classTree,
test,
type = "class")
# - predict() returns a factor; convert back to numeric for comparison with test$left
test$pred <- as.numeric(as.character(test$pred))
base_accuracy <- mean(test$pred == test$left)
print(paste0("Base model acc: ", base_accuracy))
[1] "Base model acc: 0.971266693646297"
Note that the hit, FA, and miss rates below are computed relative to the total number of test observations (i.e. as joint proportions), and not relative to the number of actual positives and negatives as in the cross-validation code above.
# - Base Model ROC
test$hit <- ifelse(test$pred == 1 & test$left == 1, T, F)
test$FA <- ifelse(test$pred == 1 & test$left == 0, T, F)
hitRate <- sum(test$hit)/length(test$hit)
print(paste0("Base model Hit rate: ", hitRate))
[1] "Base model Hit rate: 0.222851746931067"
FARate <- sum(test$FA)/length(test$FA)
print(paste0("Base model FA rate: ", FARate))
[1] "Base model FA rate: 0.0078240928099285"
test$miss <- ifelse(test$pred == 0 & test$left == 1, T, F)
missRate <- sum(test$miss)/length(test$miss)
print(paste0("Base model Miss rate: ", missRate))
[1] "Base model Miss rate: 0.0209092135437745"
# - Prune the classTree based on the optimal cp value
optimal_cp <- cptable$CP[which.min(cptable$xerror)]
classTree_pruned <- prune(classTree,
cp = optimal_cp)
prp(classTree_pruned,
cex = .75)
# - The accuracy of the pruned tree
test$pred <- predict(classTree_pruned,
test,
type = "class")
accuracy_postprun <- mean(test$pred == test$left)
print(paste0("Pruned model acc: ", accuracy_postprun))
[1] "Pruned model acc: 0.971131795494402"
# - Pruned Model ROC
test$hit <- ifelse(test$pred == 1 & test$left == 1, T, F)
test$FA <- ifelse(test$pred == 1 & test$left == 0, T, F)
hitRate <- sum(test$hit)/length(test$hit)
print(paste0("Pruned Hit rate: ", hitRate))
[1] "Pruned Hit rate: 0.219614191285579"
FARate <- sum(test$FA)/length(test$FA)
print(paste0("Pruned FA rate: ", FARate))
[1] "Pruned FA rate: 0.00472143531633617"
test$miss <- ifelse(test$pred == 0 & test$left == 1, T, F)
missRate <- sum(test$miss)/length(test$miss)
print(paste0("Pruned Miss rate: ", missRate))
[1] "Pruned Miss rate: 0.0241467691892621"
test$CR <- ifelse(test$pred == 0 & test$left == 0, T, F)
CRRate <- sum(test$CR)/length(test$CR)
print(paste0("Pruned CR rate: ", CRRate))
[1] "Pruned CR rate: 0.751517604208822"
For pruning Decision Trees with {rpart} in R, see the following Stack Overflow discussion: Selecting cp value for decision tree pruning using rpart.
Goran S. Milovanović
DataKolektiv, 2020/21
contact: goran.milovanovic@datakolektiv.com
License: GPLv3 This Notebook is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. This Notebook is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with this Notebook. If not, see http://www.gnu.org/licenses/.