Feedback should be sent to goran.milovanovic@datakolektiv.com. These notebooks accompany the Intro to Data Science: Non-Technical Background course 2020/21.
In this session we introduce an additional, important method to control for overfitting in regression models: regularization. Both the L1 (Lasso Regression) and L2 (Ridge Regression) norms will be discussed, and the {glmnet} R package will be used to fit regularized linear models. This session wraps up our journey through the field of statistical machine learning in the scope of this course. In the second part of the session we will learn how to run R code in parallel with {snowfall}. Then we dive into the first of three case studies and begin a synthesis of everything that we have learned thus far.
install.packages('glmnet')
install.packages('snowfall')
Grab the HR_comma_sep.csv
dataset from Kaggle and place it in your _data
directory for this session. We will also use the Boston Housing Dataset: BostonHousing.csv.
dataDir <- paste0(getwd(), "/_data/")
library(tidyverse)
library(data.table)
library(glmnet)
Regularization serves several purposes: it controls for overfitting, reduces model complexity (and with it the effects of multicollinearity), and, in some of its forms, performs feature selection. Consider the following:
\[Loss = \sum_{i=1}^{n}{\big(y_i-\hat{y_i}\big)^2}\]
It is just the ordinary sum of squares, i.e. the loss (or cost) function of the Linear Regression Model. The parameters \(\beta\) of the Linear Model are estimated by minimizing the above quantity - the model’s cost function.
Now, consider the following formulation of the cost function that includes the penalty term:
\[Loss_{L2} = \sum_{i=1}^{n}{\big(y_i-\hat{y_i}\big)^2}+\lambda\sum_{j=1}^{p}\beta_j^2\]
\(\lambda\sum_{j=1}^{p}\beta_j^2\) is the penalty term: it increases the value of the cost function by an amount determined by the sum of the squared model coefficients. In the expression above, \(p\) stands for the number of model parameters (coefficients, in the case of Linear Regression) and \(\beta\) are the coefficients themselves. The above expression represents the \(L2\) regularization for the Linear Model, also known as Ridge Regression. If \(\lambda\) is zero, we get back the Ordinary Least Squares model that we have already learned about. If \(\lambda\) is too large, it will add too much weight to the cost function, which will lead to underfitting. The underlying idea is to penalize large coefficients: the cost function is penalized whenever the coefficients take large values. In effect, Ridge regression shrinks the coefficients so as to reduce model complexity, and it also reduces the effects of multicollinearity.
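For the Linear Model this \(L2\)-penalized cost function even has a closed-form minimizer (a standard result; we leave the intercept aside and assume standardized predictors), which makes the shrinkage effect explicit:
\[\hat{\beta}^{Ridge} = (X^TX+\lambda I)^{-1}X^Ty\]
As \(\lambda\) grows, the \(\lambda I\) term dominates \(X^TX\) and the estimated coefficients are pulled towards zero; at \(\lambda = 0\) the expression reduces to the familiar OLS solution \((X^TX)^{-1}X^Ty\).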
Consider now the following modification of the cost function:
\[Loss_{L1} = \sum_{i=1}^{n}{\big(y_i-\hat{y_i}\big)^2}+\lambda\sum_{j=1}^{p}|\beta_j|\]
The above cost function is based on a weighted sum of the absolute values of the model coefficients. This approach is known as \(L1\) or Lasso Regression.
The most important difference between Ridge and Lasso regularization is that in Lasso some coefficients can be shrunk all the way to zero and thus completely eliminated from the model. Lasso regularization therefore not only prevents overfitting but also performs feature selection - a very handy property when considering large linear models.
Finally, consider the following formulation of a loss (or cost) function:
\[Loss_{ElasticNet} = \frac{\sum_{i=1}^{n}(y_i-\hat{y_i})^2}{2n}+\lambda(\frac{1-\alpha}{2}\sum_{j=1}^{m}\hat{\beta_j}^2+\alpha\sum_{j=1}^{m}|\hat{\beta_j}|)\]
In the expression above - the Elastic Net regression - \(\alpha\) is the mixing parameter between the Lasso (\(\alpha=1\)) and Ridge (\(\alpha=0\)) approaches. While Ridge regression shrinks the coefficients of correlated predictors towards each other, Lasso regression tends to pick one of them and shrink the others. The Elastic Net penalty is a mixture of the two approaches, controlled by \(\alpha\).
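For orientation, here is a minimal sketch of what an Elastic Net fit looks like in practice; the predictor matrix X and outcome y below are simulated placeholders, and alpha = .5 is an arbitrary mixing value:
library(glmnet)
# - simulated placeholders: 100 observations, 5 predictors
X <- matrix(rnorm(100 * 5), ncol = 5)
y <- rnorm(100)
# - alpha = .5 mixes the Ridge and Lasso penalties in equal parts
enet_model <- glmnet(x = X, y = y, alpha = .5)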
We will use {glmnet} to fit regularized linear regression models in R. We begin with Lasso in the Boston Housing dataset:
library(glmnet)
dataSet <- read.csv(paste0('_data/', 'BostonHousing.csv'),
header = T,
check.names = F,
stringsAsFactors = F)
head(dataSet)
Here goes the \(L1\) (Lasso) regularization in a Linear Model:
predictors <- dataSet %>%
dplyr::select(-medv)
l1_Model <- glmnet(as.matrix(predictors),
y = dataSet$medv)
Pay attention: the formulation of the Linear Model is somewhat different in {glmnet} compared to what we have learned thus far about the R formula interface (in lm()
, for example): we need to split the predictors
from the outcome, which is passed in the y
argument, as in the above chunk. The glmnet
function actually solves the Elastic Net regression and uses \(\alpha=1\) - the Lasso (L1) regression - as its default.
glmnet()
returns a set of models for us:
dim(as.matrix(l1_Model$beta))
[1] 13 76
as.matrix(l1_Model$beta)
s0 s1 s2 s3 s4 s5 s6 s7 s8 s9 s10
crim 0 0.00000000 0.0000000 0.0000000 0.0000000 0.0000000 0.000000 0.0000000 0.0000000 0.0000000 0.0000000
zn 0 0.00000000 0.0000000 0.0000000 0.0000000 0.0000000 0.000000 0.0000000 0.0000000 0.0000000 0.0000000
indus 0 0.00000000 0.0000000 0.0000000 0.0000000 0.0000000 0.000000 0.0000000 0.0000000 0.0000000 0.0000000
chas 0 0.00000000 0.0000000 0.0000000 0.0000000 0.0000000 0.000000 0.0000000 0.0000000 0.0000000 0.0000000
nox 0 0.00000000 0.0000000 0.0000000 0.0000000 0.0000000 0.000000 0.0000000 0.0000000 0.0000000 0.0000000
rm 0 0.00000000 0.1278413 0.5694424 0.9714620 1.3377669 1.671530 1.9756431 2.2527393 2.4795400 2.6603223
age 0 0.00000000 0.0000000 0.0000000 0.0000000 0.0000000 0.000000 0.0000000 0.0000000 0.0000000 0.0000000
dis 0 0.00000000 0.0000000 0.0000000 0.0000000 0.0000000 0.000000 0.0000000 0.0000000 0.0000000 0.0000000
rad 0 0.00000000 0.0000000 0.0000000 0.0000000 0.0000000 0.000000 0.0000000 0.0000000 0.0000000 0.0000000
tax 0 0.00000000 0.0000000 0.0000000 0.0000000 0.0000000 0.000000 0.0000000 0.0000000 0.0000000 0.0000000
ptratio 0 0.00000000 0.0000000 0.0000000 0.0000000 0.0000000 0.000000 0.0000000 0.0000000 -0.0401687 -0.1192839
b 0 0.00000000 0.0000000 0.0000000 0.0000000 0.0000000 0.000000 0.0000000 0.0000000 0.0000000 0.0000000
lstat 0 -0.08439977 -0.1535809 -0.1969814 -0.2365474 -0.2725985 -0.305447 -0.3353772 -0.3626486 -0.3844930 -0.4011381
s11 s12 s13 s14 s15 s16 s17 s18 s19 s20
crim 0.0000000 0.0000000 0.0000000 0.0000000 0.0000000 0.0000000 0.0000000 0.0000000 0.0000000000 0.000000000
zn 0.0000000 0.0000000 0.0000000 0.0000000 0.0000000 0.0000000 0.0000000 0.0000000 0.0000000000 0.000000000
indus 0.0000000 0.0000000 0.0000000 0.0000000 0.0000000 0.0000000 0.0000000 0.0000000 0.0000000000 0.000000000
chas 0.0000000 0.0000000 0.0000000 0.0000000 0.0000000 0.0000000 0.0000000 0.0000000 0.0000000000 0.000000000
nox 0.0000000 0.0000000 0.0000000 0.0000000 0.0000000 0.0000000 0.0000000 0.0000000 0.0000000000 0.000000000
rm 2.8251240 2.9752853 3.1119189 3.2365996 3.3502065 3.4537210 3.5480395 3.6339790 3.7260255824 3.816165372
age 0.0000000 0.0000000 0.0000000 0.0000000 0.0000000 0.0000000 0.0000000 0.0000000 0.0000000000 0.000000000
dis 0.0000000 0.0000000 0.0000000 0.0000000 0.0000000 0.0000000 0.0000000 0.0000000 0.0000000000 0.000000000
rad 0.0000000 0.0000000 0.0000000 0.0000000 0.0000000 0.0000000 0.0000000 0.0000000 0.0000000000 0.000000000
tax 0.0000000 0.0000000 0.0000000 0.0000000 0.0000000 0.0000000 0.0000000 0.0000000 0.0000000000 0.000000000
ptratio -0.1913698 -0.2570519 -0.3168659 -0.3713991 -0.4210879 -0.4663624 -0.5076149 -0.5452026 -0.5774794639 -0.606026131
b 0.0000000 0.0000000 0.0000000 0.0000000 0.0000000 0.0000000 0.0000000 0.0000000 0.0006479786 0.001518042
lstat -0.4162997 -0.4301145 -0.4427170 -0.4541851 -0.4646341 -0.4741550 -0.4828300 -0.4907343 -0.4942974836 -0.495954410
s21 s22 s23 s24 s25 s26 s27 s28 s29
crim 0.000000000 0.000000000 0.00000000 0.000000000 -0.001465439 -0.004750339 -0.00877571 -0.013309174 -0.016820727
zn 0.000000000 0.000000000 0.00000000 0.000000000 0.000000000 0.000000000 0.00000000 0.000000000 0.000000000
indus 0.000000000 0.000000000 0.00000000 0.000000000 0.000000000 0.000000000 0.00000000 0.000000000 0.000000000
chas 0.128706012 0.412363298 0.67067358 0.906034673 1.119327267 1.313492706 1.45712935 1.562819143 1.675294130
nox 0.000000000 0.000000000 0.00000000 0.000000000 0.000000000 0.000000000 0.00000000 0.000000000 -0.708755104
rm 3.896100545 3.962718568 4.02397156 4.079788153 4.130188268 4.181880613 4.21733517 4.237264635 4.251922442
age 0.000000000 0.000000000 0.00000000 0.000000000 0.000000000 0.000000000 0.00000000 0.000000000 0.000000000
dis 0.000000000 0.000000000 0.00000000 0.000000000 0.000000000 0.000000000 -0.02942942 -0.080074542 -0.149254585
rad 0.000000000 0.000000000 0.00000000 0.000000000 0.000000000 0.000000000 0.00000000 0.000000000 0.000000000
tax 0.000000000 0.000000000 0.00000000 0.000000000 0.000000000 0.000000000 0.00000000 0.000000000 0.000000000
ptratio -0.630523516 -0.650724354 -0.66916199 -0.685961897 -0.700311753 -0.712574008 -0.72558922 -0.738845222 -0.754242569
b 0.002301817 0.002993802 0.00362614 0.004202317 0.004689042 0.005099843 0.00552638 0.005949482 0.006241931
lstat -0.497563634 -0.499430302 -0.50108591 -0.502594040 -0.503481512 -0.503117204 -0.50684543 -0.513735694 -0.517173282
s30 s31 s32 s33 s34 s35 s36 s37 s38
crim -0.019439916 -0.021826808 -0.024001671 -0.026508930 -0.029579467 -0.032386551 -0.034943926 -0.037274105 -0.040195642
zn 0.000000000 0.000000000 0.000000000 0.001530975 0.005117086 0.008396934 0.011385002 0.014107613 0.016474798
indus 0.000000000 0.000000000 0.000000000 0.000000000 0.000000000 0.000000000 0.000000000 0.000000000 -0.003689017
chas 1.791658202 1.897718334 1.994355941 2.084417208 2.168428935 2.244958387 2.314692508 2.378231692 2.438397212
nox -2.085387000 -3.341293682 -4.485651769 -5.537820918 -6.478368167 -7.332177446 -8.110065854 -8.818849049 -9.442384307
rm 4.259590139 4.266279454 4.272373008 4.271661527 4.261458164 4.251522557 4.242473848 4.234229178 4.218091225
age 0.000000000 0.000000000 0.000000000 0.000000000 0.000000000 0.000000000 0.000000000 0.000000000 0.000000000
dis -0.237030301 -0.317122207 -0.390100181 -0.467615327 -0.553162685 -0.631213118 -0.702323093 -0.767115778 -0.829027660
rad 0.000000000 0.000000000 0.000000000 0.000000000 0.000000000 0.000000000 0.000000000 0.000000000 0.002518770
tax 0.000000000 0.000000000 0.000000000 0.000000000 0.000000000 0.000000000 0.000000000 0.000000000 0.000000000
ptratio -0.771520796 -0.787265491 -0.801611503 -0.810807359 -0.813351925 -0.815597187 -0.817644211 -0.819509409 -0.822487965
b 0.006407868 0.006558204 0.006695178 0.006827001 0.006958142 0.007076584 0.007184524 0.007282875 0.007384673
lstat -0.517661470 -0.518129464 -0.518555953 -0.518896378 -0.519363373 -0.519864563 -0.520321166 -0.520737188 -0.521141737
s39 s40 s41 s42 s43 s44 s45 s46
crim -0.046237349 -0.051377303 -0.056304773 -0.060894515 -0.065039903 -0.068811070 -7.228841e-02 -0.075482988
zn 0.018268212 0.020223247 0.022498680 0.024615792 0.026538212 0.028287918 2.989001e-02 0.031316408
indus -0.009639918 -0.013065442 -0.010005006 -0.007297361 -0.004900772 -0.002722959 -6.609862e-04 0.000000000
chas 2.491724181 2.530033156 2.544478370 2.557346555 2.569388357 2.580413325 2.590094e+00 2.600826187
nox -10.310559367 -10.962782637 -11.552370435 -12.099449201 -12.595647203 -13.047554986 -1.346160e+01 -13.822312447
rm 4.191928657 4.163603455 4.133062337 4.104769537 4.079222023 4.055992084 4.034553e+00 4.014242371
age 0.000000000 0.000000000 0.000000000 0.000000000 0.000000000 0.000000000 0.000000e+00 0.000000000
dis -0.888732000 -0.941973485 -0.988853860 -1.032364998 -1.071991058 -1.108078809 -1.140984e+00 -1.171827121
rad 0.014467604 0.028808391 0.053040699 0.075474154 0.095699233 0.114095544 1.311007e-01 0.146267904
tax 0.000000000 -0.000320548 -0.001373734 -0.002346088 -0.003222644 -0.004019917 -4.757059e-03 -0.005396298
ptratio -0.837298632 -0.847897521 -0.857096797 -0.865536672 -0.873139214 -0.880058856 -8.864579e-01 -0.892024291
b 0.007575262 0.007735573 0.007876045 0.008004311 0.008120518 0.008226330 8.323480e-03 0.008410550
lstat -0.520866936 -0.521327300 -0.521622410 -0.521820562 -0.521999136 -0.522160914 -5.223123e-01 -0.522360535
s47 s48 s49 s50 s51 s52 s53 s54
crim -0.078366680 -0.081032219 -0.083417938 -0.085633397 -0.087610024 -0.089451980 -0.091088609 -0.092620214
zn 0.032600338 0.033776311 0.034838515 0.035815478 0.036695023 0.037506924 0.038234464 0.038909419
indus 0.000000000 0.000000000 0.000000000 0.000000000 0.000000000 0.000000000 0.000000000 0.000000000
chas 2.611564574 2.621102886 2.630086779 2.637986594 2.645483035 2.652020942 2.658283708 2.663692147
nox -14.137164919 -14.424681243 -14.685749763 -14.924469002 -15.141064439 -15.339292006 -15.518854904 -15.683480443
rm 3.995671804 3.978450381 3.963110388 3.948789628 3.936096872 3.924182781 3.913687750 3.903773622
age 0.000000000 0.000000000 0.000000000 0.000000000 0.000000000 0.000000000 0.000000000 0.000000000
dis -1.200249778 -1.226220504 -1.249766769 -1.271335668 -1.290853295 -1.308770114 -1.324936435 -1.339822837
rad 0.159679592 0.172093205 0.183187538 0.193506383 0.202697632 0.211277976 0.218888599 0.226024047
tax -0.005954901 -0.006471554 -0.006933612 -0.007363076 -0.007745828 -0.008102936 -0.008419827 -0.008716805
ptratio -0.896802232 -0.901213017 -0.905175711 -0.908840187 -0.912130301 -0.915174988 -0.917906736 -0.920436045
b 0.008488203 0.008559462 0.008623886 0.008683074 0.008736552 0.008785715 0.008830110 0.008870944
lstat -0.522376782 -0.522393577 -0.522406130 -0.522420631 -0.522430050 -0.522442620 -0.522450040 -0.522460947
s55 s56 s57 s58 s59 s60 s61 s62
crim -0.093973998 -0.095247648 -0.096365803 -0.097424931 -0.098346337 -0.099226869 -0.099983351 -0.100714832
zn 0.039510386 0.040071698 0.040567190 0.041034124 0.041441612 0.041830020 0.042163911 0.042486737
indus 0.000000000 0.000000000 0.000000000 0.000000000 0.000000000 0.000000000 0.000000000 0.000000000
chas 2.668932088 2.673404963 2.677797325 2.681496880 2.685187736 2.688250324 2.691362170 2.693903097
nox -15.832157974 -15.968911958 -16.091776269 -16.205420450 -16.306645226 -16.401122000 -16.484121376 -16.562664327
rm 3.895103346 3.886852696 3.879698087 3.872832939 3.866938877 3.861229965 3.856387746 3.851646315
age 0.000000000 0.000000000 0.000000000 0.000000000 0.000000000 0.000000000 0.000000000 0.000000000
dis -1.353199455 -1.365571281 -1.376623666 -1.386908238 -1.396021614 -1.404571749 -1.412063666 -1.419168850
rad 0.232320994 0.238255004 0.243459038 0.248393350 0.252686500 0.256788019 0.260319555 0.263725830
tax -0.008978974 -0.009225957 -0.009442628 -0.009648008 -0.009826799 -0.009997514 -0.010144690 -0.010286456
ptratio -0.922703692 -0.924804463 -0.926685654 -0.928430214 -0.929988658 -0.931437290 -0.932724941 -0.933927773
b 0.008907803 0.008941714 0.008972315 0.009000473 0.009025875 0.009049252 0.009070332 0.009089735
lstat -0.522467523 -0.522476864 -0.522483752 -0.522491506 -0.522499839 -0.522505968 -0.522516834 -0.522521473
s63 s64 s65 s66 s67 s68 s69 s70
crim -0.101393479 -0.101959353 -0.102513087 -1.030082e-01 -0.10342081 -0.103808651 -0.104132869 -0.104455334
zn 0.042784125 0.043032597 0.043276692 4.350918e-02 0.04376785 0.043992065 0.044184458 0.044370196
indus 0.000000000 0.000000000 0.000000000 3.012702e-04 0.00189287 0.003524133 0.004946237 0.006304191
chas 2.696136602 2.698519469 2.700442679 2.701696e+00 2.70075820 2.699788652 2.699226636 2.698467386
nox -16.634667957 -16.696340287 -16.755818227 -1.681424e+01 -16.88367428 -16.953707131 -17.015390844 -17.075226985
rm 3.847231871 3.843603202 3.840016222 3.836972e+00 3.83495875 3.833180781 3.831857527 3.830427989
age 0.000000000 0.000000000 0.000000000 0.000000e+00 0.00000000 0.000000000 0.000000000 0.000000000
dis -1.425687466 -1.431281553 -1.436656218 -1.441368e+00 -1.44482611 -1.447744714 -1.450300595 -1.452751779
rad 0.266888579 0.269539160 0.272117316 2.745231e-01 0.27705039 0.279513680 0.281571774 0.283610857
tax -0.010418000 -0.010528583 -0.010635851 -1.074078e-02 -0.01087205 -0.010998316 -0.011104901 -0.011209046
ptratio -0.935038260 -0.936008402 -0.936920066 -9.377945e-01 -0.93881745 -0.939945869 -0.940920095 -0.941874453
b 0.009107537 0.009123506 0.009138192 9.151675e-03 0.00916544 0.009178978 0.009190891 0.009202124
lstat -0.522525376 -0.522537970 -0.522541068 -5.225658e-01 -0.52270020 -0.522817656 -0.522911887 -0.523000031
s71 s72 s73 s74 s75
crim -0.104719398 -0.104983546 -0.105238892 -0.105437872 -0.10563564
zn 0.044527873 0.044681863 0.044825720 0.044943263 0.04505959
indus 0.007458193 0.008579712 0.009625910 0.010477671 0.01132070
chas 2.698073139 2.697463261 2.696800526 2.696540388 2.69609610
nox -17.125645158 -17.174963265 -17.221436713 -17.258979779 -17.29604866
rm 3.829379107 3.828210101 3.827051956 3.826281634 3.82541641
age 0.000000000 0.000000000 0.000000000 0.000000000 0.00000000
dis -1.454866978 -1.456900793 -1.458793919 -1.460386386 -1.46192454
rad 0.285276396 0.286950451 0.288555789 0.289797548 0.29105088
tax -0.011295472 -0.011381204 -0.011462622 -0.011526962 -0.01159130
ptratio -0.942664365 -0.943445664 -0.944193251 -0.944783173 -0.94536634
b 0.009211899 0.009221159 0.009229812 0.009237192 0.00924415
lstat -0.523075319 -0.523149200 -0.523214971 -0.523268957 -0.52332406
There are as many estimated models as there were different values of \(\lambda\) - the parameter that controls the penalization - tried out.
length(l1_Model$lambda)
[1] 76
The values of lambda tried out in \(L1\) Lasso regularization were:
l1_Model$lambda
[1] 6.777653645 6.175545575 5.626927127 5.127046430 4.671573757 4.256564020 3.878422604 3.533874230 3.219934583 2.933884470
[11] 2.673246260 2.435762430 2.219376009 2.022212762 1.842564953 1.678876561 1.529729795 1.393832816 1.270008551 1.157184491
[21] 1.054383411 0.960714894 0.875367631 0.797602383 0.726745586 0.662183511 0.603356953 0.549756384 0.500917542 0.456417409
[31] 0.415870544 0.378925751 0.345263032 0.314590816 0.286643435 0.261178822 0.237976415 0.216835246 0.197572201 0.180020431
[41] 0.164027912 0.149456124 0.136178854 0.124081100 0.113058077 0.103014309 0.093862802 0.085524289 0.077926547 0.071003768
[51] 0.064695989 0.058948575 0.053711746 0.048940143 0.044592435 0.040630966 0.037021423 0.033732542 0.030735836 0.028005349
[61] 0.025517431 0.023250533 0.021185020 0.019303001 0.017588175 0.016025690 0.014602012 0.013304810 0.012122847 0.011045887
[71] 0.010064601 0.009170489 0.008355808 0.007613501 0.006937139 0.006320863
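A convenient way to inspect all of these models at once is the coefficient path plot that {glmnet} provides; each curve shows how one coefficient changes as \(\lambda\) varies (a quick sketch):
# - coefficient paths across the lambda sequence (log scale on the x-axis)
plot(l1_Model, xvar = "lambda", label = TRUE)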
Because the estimation in {glmnet} is done over a penalized likelihood function, the optimal model is the one with the minimal deviance (as in GLMs):
which.min(deviance(l1_Model))
[1] 76
l1_Model$lambda[76]
[1] 0.006320863
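If you want to inspect the coefficients of any particular model on the path, coef() accepts the desired value of \(\lambda\) through its s argument; a small sketch:
# - coefficients of the model with the smallest deviance on the path
optimal_lambda <- l1_Model$lambda[which.min(deviance(l1_Model))]
coef(l1_Model, s = optimal_lambda)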
In {glmnet}, cross-validation across \(\lambda\) can be performed by cv.glmnet()
:
cv_l1Model <- cv.glmnet(x = as.matrix(predictors),
y = dataSet$medv,
nfolds = 10)
plot(cv_l1Model)
cv_l1Model$lambda
[1] 6.777653645 6.175545575 5.626927127 5.127046430 4.671573757 4.256564020 3.878422604 3.533874230 3.219934583 2.933884470
[11] 2.673246260 2.435762430 2.219376009 2.022212762 1.842564953 1.678876561 1.529729795 1.393832816 1.270008551 1.157184491
[21] 1.054383411 0.960714894 0.875367631 0.797602383 0.726745586 0.662183511 0.603356953 0.549756384 0.500917542 0.456417409
[31] 0.415870544 0.378925751 0.345263032 0.314590816 0.286643435 0.261178822 0.237976415 0.216835246 0.197572201 0.180020431
[41] 0.164027912 0.149456124 0.136178854 0.124081100 0.113058077 0.103014309 0.093862802 0.085524289 0.077926547 0.071003768
[51] 0.064695989 0.058948575 0.053711746 0.048940143 0.044592435 0.040630966 0.037021423 0.033732542 0.030735836 0.028005349
[61] 0.025517431 0.023250533 0.021185020 0.019303001 0.017588175 0.016025690 0.014602012 0.013304810 0.012122847 0.011045887
[71] 0.010064601 0.009170489 0.008355808 0.007613501 0.006937139 0.006320863
The mean cross-validated error for each of the values of \(\lambda\) tried out is:
cv_l1Model$cvm
[1] 84.64485 77.19867 70.42595 64.04553 58.47547 53.85140 50.01429 46.83039 44.17979 41.77450 39.44154 37.47906 35.85059
[14] 34.49895 33.37742 32.44686 31.67511 31.03853 30.50641 30.04296 29.64760 29.27726 28.89847 28.55098 28.26547 28.03096
[27] 27.84064 27.65532 27.40133 27.11650 26.77961 26.45077 26.17824 25.93905 25.72151 25.51628 25.33942 25.19726 25.06641
[40] 24.91258 24.73175 24.52345 24.31860 24.14642 24.00223 23.88135 23.78013 23.69518 23.62390 23.56374 23.51367 23.47127
[53] 23.43585 23.40711 23.38283 23.36281 23.34696 23.33677 23.32809 23.32351 23.31966 23.31730 23.31541 23.31427 23.31363
[66] 23.31401 23.31444 23.31571 23.31670 23.31760 23.31829 23.31904 23.31981 23.32041 23.32178 23.32314
And the respective value of \(\lambda\) is:
cv_l1Model$lambda.min
[1] 0.01758818
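cv.glmnet() also reports a more conservative choice, lambda.1se: the largest \(\lambda\) whose cross-validated error is within one standard error of the minimum, often preferred when a sparser model is desired:
cv_l1Model$lambda.1se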
So the optimal Lasso model would be:
l1Optim_Model <- glmnet(as.matrix(predictors),
y = dataSet$medv,
lambda = cv_l1Model$lambda.min)
predicted_medv <- predict(l1Optim_Model,
newx = as.matrix(predictors))
predictFrame <- data.frame(predicted_medv = predicted_medv,
observed_medv = dataSet$medv)
ggplot(data = predictFrame,
aes(x = predicted_medv,
y = observed_medv)) +
geom_smooth(method = "lm", size = .25, color = "red") +
geom_point(size = 1.5, color = "black") +
geom_point(size = .75, color = "white") +
ggtitle("LASSO: Observed vs Predicted\nBoston Housing Dataset") +
theme_bw() +
theme(panel.border = element_blank()) +
theme(plot.title = element_text(hjust = .5, size = 8))
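Beyond the visual check, we can also quantify the fit; a minimal sketch (note that these are in-sample figures, computed on the same data used to fit the model):
# - in-sample RMSE and squared correlation of the optimal Lasso model
rmse <- sqrt(mean((dataSet$medv - as.numeric(predicted_medv))^2))
rsq <- cor(dataSet$medv, as.numeric(predicted_medv))^2
print(c(RMSE = rmse, R2 = rsq))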
The \(L2\) or Ridge regression is obtained from {glmnet} by setting alpha
to zero:
predictors <- dataSet %>%
dplyr::select(-medv)
l2_Model <- glmnet(as.matrix(predictors),
y = dataSet$medv,
alpha = 0)
l2_Model$lambda
[1] 6777.6536446 6175.5455754 5626.9271274 5127.0464303 4671.5737566 4256.5640198 3878.4226042 3533.8742298 3219.9345832
[10] 2933.8844696 2673.2462597 2435.7624300 2219.3760091 2022.2127615 1842.5649534 1678.8765614 1529.7297950 1393.8328162
[19] 1270.0085505 1157.1844913 1054.3834105 960.7148944 875.3676311 797.6023834 726.7455860 662.1835112 603.3569532
[28] 549.7563844 500.9175425 456.4174086 415.8705440 378.9257511 345.2630318 314.5908156 286.6434347 261.1788220
[37] 237.9764153 216.8352465 197.5722008 180.0204310 164.0279121 149.4561245 136.1788543 124.0811002 113.0580773
[46] 103.0143093 93.8628020 85.5242894 77.9265472 71.0037676 64.6959885 58.9485752 53.7117463 48.9401428
[55] 44.5924354 40.6309663 37.0214233 33.7325421 30.7358360 28.0053490 25.5174310 23.2505328 21.1850195
[64] 19.3030008 17.5881754 16.0256904 14.6020122 13.3048097 12.1228471 11.0458868 10.0646006 9.1704892
[73] 8.3558082 7.6135013 6.9371388 6.3208625 5.7593345 5.2476911 4.7815007 4.3567253 3.9696859
[82] 3.6170299 3.2957030 3.0029218 2.7361505 2.4930784 2.2716002 2.0697975 1.8859224 1.7183823
[91] 1.5657259 1.4266311 1.2998932 1.1844142 1.0791941 0.9833215 0.8959659 0.8163708 0.7438467
[100] 0.6777654
Remember: the alpha
parameter (default value: 1
) is the Elastic Net mixing parameter between Ridge (\(\alpha=0\)) and Lasso (\(\alpha=1\)).
The {glmnet} package will perform an automatic cross-validation across \(\lambda\), as demonstrated, but you will need to develop your own cross-validation for the mixing parameter \(\alpha\).
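A minimal sketch of how such a cross-validation over \(\alpha\) could look (the grid of \(\alpha\) values is an arbitrary choice here; for a strict comparison the same fold assignment should be reused across fits, which we leave aside for brevity):
alphas <- seq(0, 1, by = .1)
cv_errors <- sapply(alphas, function(a) {
  # - cv.glmnet() still handles the lambda path for each alpha
  cv_fit <- cv.glmnet(x = as.matrix(predictors),
                      y = dataSet$medv,
                      alpha = a,
                      nfolds = 10)
  min(cv_fit$cvm)
})
best_alpha <- alphas[which.min(cv_errors)]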
Let’s clean up our environment first, set the dataDir
again, and import the {snowfall} library:
rm(list = ls())
dataDir <- paste0(getwd(), "/_data/")
library(snowfall)
library(randomForest)
We choose to parallelize R code with {snowfall} because it arguably presents the simplest interface for parallel computation in R.
We are getting back to the HR_comma_sep.csv
dataset and the Random Forest model:
dataSet <- read.csv(paste0(getwd(), "/_data/HR_comma_sep.csv"),
header = T,
check.names = F,
stringsAsFactors = F)
glimpse(dataSet)
Rows: 14,999
Columns: 10
$ satisfaction_level <dbl> 0.38, 0.80, 0.11, 0.72, 0.37, 0.41, 0.10, 0.92, 0.89, 0.42, 0.45, 0.11, 0.84, 0.41, 0.36, 0.38...
$ last_evaluation <dbl> 0.53, 0.86, 0.88, 0.87, 0.52, 0.50, 0.77, 0.85, 1.00, 0.53, 0.54, 0.81, 0.92, 0.55, 0.56, 0.54...
$ number_project <int> 2, 5, 7, 5, 2, 2, 6, 5, 5, 2, 2, 6, 4, 2, 2, 2, 2, 4, 2, 5, 6, 2, 6, 2, 2, 5, 4, 2, 2, 2, 6, 2...
$ average_montly_hours <int> 157, 262, 272, 223, 159, 153, 247, 259, 224, 142, 135, 305, 234, 148, 137, 143, 160, 255, 160,...
$ time_spend_company <int> 3, 6, 4, 5, 3, 3, 4, 5, 5, 3, 3, 4, 5, 3, 3, 3, 3, 6, 3, 5, 4, 3, 4, 3, 3, 5, 5, 3, 3, 3, 4, 3...
$ Work_accident <int> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0...
$ left <int> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1...
$ promotion_last_5years <int> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0...
$ sales <chr> "sales", "sales", "sales", "sales", "sales", "sales", "sales", "sales", "sales", "sales", "sal...
$ salary <chr> "low", "medium", "medium", "low", "low", "low", "low", "low", "low", "low", "low", "low", "low...
We will immediately add the ix
column to the dataset to prepare it for a 5-fold cross-validation:
dataSet$ix <- sample(1:5, dim(dataSet)[1], replace = T)
table(dataSet$ix)
1 2 3 4 5
3081 2957 3000 2992 2969
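Note that sample() makes this fold assignment random; if you want it to be reproducible across runs, fix the random seed first (an optional tweak):
set.seed(42) # any fixed seed will do; 42 is an arbitrary choice
dataSet$ix <- sample(1:5, dim(dataSet)[1], replace = T)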
Recall that in our previous session (Session 21) it took approximately 17 minutes to cross-validate a set of Random Forest models for this dataset in a 5-fold CV across the following ranges of the ntree
and mtry
parameters:
ntree <- seq(250, 1000, by = 250)
mtry <- 1:(dim(dataSet)[2]-2)
folds <- 1:5
Now we want to make the CV procedure way more efficient by growing many Random Forest models in parallel. This is how we approach the problem in a typical {snowfall} workflow:
cv_design <- expand.grid(ntree, mtry, folds)
colnames(cv_design) <- c('ntree', 'mtry', 'fold')
head(cv_design)
The expand.grid()
function combines several factors into a design matrix, ordered by how fast they change. This time it helped us plan our CV procedure by efficiently combining all levels of ntree
, mtry
, and fold
. Each of the 180 rows in the cv_design
data.frame describes exactly one Random Forest model that we need to grow: a particular value of ntree
, a value of mtry
, and the fold
that will be used as the test dataset. Now we want to turn cv_design
into a list:
cv_design <- apply(cv_design, 1, function(x) {
list(ntree = x[1],
mtry = x[2],
fold = x[3])
})
Now we can do things such as:
cv_design[[1]]
$ntree
ntree
250
$mtry
mtry
1
$fold
fold
1
and:
cv_design[[1]]$ntree
ntree
250
In {snowfall}, we use a special family of parallelized *apply()
functions. But before we demonstrate that on our problem, we need a bit of preparation. We will now turn your computer into a cluster (!). All modern CPUs have multiple threads that can be used as separate logical units for computation. In this approach, we will have each thread available on your system work on a separate Random Forest model in parallel, thus speeding up the CV procedure.
The package {parallel}
is present by default in your R installation. Call parallel::detectCores()
to detect the number of threads that you can use:
parallel::detectCores()
[1] 8
Assign parallel::detectCores() - 1
to a numeric variable numCores
:
numCores <- parallel::detectCores() - 1
numCores
[1] 7
Now we need to initialize a cluster with numCores
workers:
sfInit(parallel = TRUE, cpus = numCores)
R Version: R version 4.0.0 (2020-04-24)
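With the cluster up, here is a tiny toy illustration of the parallelized *apply() family before the real workload (a minimal sketch; the function needs no exported data, so it runs as-is):
# - sfLapply() is the parallel counterpart of lapply()
squares <- sfLapply(1:10, function(x) x^2)
unlist(squares)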
Next, we need to export the data and the functions that we will use to the cluster:
sfExport('dataSet')
sfLibrary(dplyr)
Library dplyr loaded.
sfLibrary(randomForest)
Library randomForest loaded.
Finally, we need to code our CV procedure to run in parallel. Here is how we can do that:
tstart <- Sys.time()
rfModels <- sfClusterApplyLB(cv_design, function(x) {
# - pick up parameters from x
ntree <- x$ntree
mtry <- x$mtry
fold <- x$fold
# - split training and test sets
testIx <- fold
trainIx <- setdiff(1:5, testIx)
trainSet <- dataSet %>%
dplyr::filter(ix %in% trainIx) %>%
dplyr::select(-ix)
testSet <- dataSet %>%
dplyr::filter(ix %in% testIx) %>%
dplyr::select(-ix)
# - `left` to factor for classification w. randomForest()
trainSet$left <- as.factor(trainSet$left)
testSet$left <- as.factor(testSet$left)
# - Random Forest:
model <- randomForest::randomForest(formula = left ~ .,
data = trainSet,
ntree = ntree,
mtry = mtry
)
# - ROC analysis:
predictions <- predict(model,
newdata = testSet)
hit <- sum(ifelse(predictions == 1 & testSet$left == 1, 1, 0))
hit <- hit/sum(testSet$left == 1)
fa <- sum(ifelse(predictions == 1 & testSet$left == 0, 1, 0))
fa <- fa/sum(testSet$left == 0)
acc <- sum(predictions == testSet$left)
acc <- acc/length(testSet$left)
# - Output:
return(
data.frame(ntree = ntree,
mtry = mtry,
fold = fold,
hit = hit,
fa = fa,
acc = acc)
)
})
# - collect all results
rfModels <- rbindlist(rfModels)
write.csv(rfModels,
paste0(getwd(), "/rfModels.csv"))
# - Report timing:
print(paste0("The estimation took: ",
difftime(Sys.time(), tstart, units = "mins"),
" minutes."))
[1] "The estimation took: 9.98846626679103 minutes."
Do not forget to shut down the cluster:
sfStop()
And here are the average results across the values of ntree
and mtry
:
rfModels <- rfModels %>%
group_by(ntree, mtry) %>%
summarise(hit = mean(hit),
fa = mean(fa),
acc = mean(acc))
rfModels
rfModels$ntree <- factor(rfModels$ntree)
rfModels$mtry <- factor(rfModels$mtry)
ggplot(data = rfModels,
aes(x = mtry,
y = acc,
group = ntree,
color = ntree,
fill = ntree,
label = round(acc, 2))
) +
geom_path(size = .25) +
geom_point(size = 1.5) +
ggtitle("Random Forests CV: Accuracy") +
theme_bw() +
theme(panel.border = element_blank()) +
theme(plot.title = element_text(hjust = .5, size = 8))
ggplot(data = rfModels,
aes(x = mtry,
y = hit,
group = ntree,
color = ntree,
fill = ntree,
label = round(acc, 2))
) +
geom_path(size = .25) +
geom_point(size = 1.5) +
ggtitle("Random Forests CV: Hit Rate") +
theme_bw() +
theme(panel.border = element_blank()) +
theme(plot.title = element_text(hjust = .5, size = 8))
ggplot(data = rfModels,
aes(x = mtry,
y = fa,
group = ntree,
color = ntree,
fill = ntree,
label = round(acc, 2))
) +
geom_path(size = .25) +
geom_point(size = 1.5) +
ggtitle("Random Forests CV: FA Rate") +
theme_bw() +
theme(panel.border = element_blank()) +
theme(plot.title = element_text(hjust = .5, size = 8))
L1 and L2 Regularization Methods, Anuja Nagpal, Towards Data Science
L2 and L1 Regularization in Machine Learning, Neelam Tyagi, analyticsteps
An Introduction to glmnet, Trevor Hastie, Junyang Qian, Kenneth Tay, February 22, 2021
Developing parallel programs using snowfall, Jochen Knaus, 2010-03-04
Goran S. Milovanović
DataKolektiv, 2020/21
contact: goran.milovanovic@datakolektiv.com
License: GPLv3 This Notebook is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. This Notebook is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with this Notebook. If not, see http://www.gnu.org/licenses/.