1 Introduction

DALEX explainers can be used to see what kind of relation a model is able to learn, or what relation it has actually learned.

If the ground truth is known, we can verify whether a model is capable of learning a particular type of relation.

2 Simulated data

Let’s simulate a response as an additive function of five arguments:

\[ y = (2x_1-1)^2 - 0.5 + \sin(10 x_2) + x_3^{6} + (2 x_4 - 1) + |2x_5-1| \]

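set.seed(13)
N <- 250
# five independent predictors, uniform on [0, 1]
X1 <- runif(N)
X2 <- runif(N)
X3 <- runif(N)
X4 <- runif(N)
X5 <- runif(N)

# ground-truth response; each term depends on a single predictor
f <- function(x1, x2, x3, x4, x5) {
  ((x1 - 0.5) * 2)^2 - 0.5 + sin(x2 * 10) + x3^6 + (x4 - 0.5) * 2 + abs(2 * x5 - 1)
}
y <- f(X1, X2, X3, X4, X5)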
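Since the inputs are independent and uniform on [0, 1], we can work out beforehand which terms a straight line is able to capture. The sketch below (base R only; `components` and `linear_r2()` are hypothetical names) computes, on a fine grid, the squared correlation between each input and its own component, i.e. the share of that component's variance explained by the best linear fit. Analytically this is 0 for the terms in x1 and x5 (both symmetric around 0.5), about 0.005 for sin(10 x2), about 0.61 for x3^6 and exactly 1 for the linear term in x4.

```r
# One component per input; the constant -0.5 is dropped because it shifts
# the level of the response but not the shape of any term
components <- list(
  X1 = function(x) (2 * x - 1)^2,
  X2 = function(x) sin(10 * x),
  X3 = function(x) x^6,
  X4 = function(x) 2 * x - 1,
  X5 = function(x) abs(2 * x - 1)
)

# Hypothetical helper: R^2 of the best straight-line fit of g(x) on [0, 1],
# approximated on a fine regular grid (R^2 of a simple regression = cor^2)
linear_r2 <- function(g, grid = seq(0, 1, length.out = 10001)) {
  cor(grid, g(grid))^2
}

r2 <- sapply(components, linear_r2)
round(r2, 3)
```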

3 Model fits

Let’s compare four fitted models (random forest, SVM, a linear model and a linear model with restricted cubic splines from rms) with the ground truth.

library(randomForest)
library(DALEX)
library(e1071)
library(rms)

df <- data.frame(y, X1, X2, X3, X4, X5)

model_rf <- randomForest(y ~ ., df)
model_svm <- svm(y ~ ., df)
model_lm <- lm(y ~ ., df)

# thanks to https://github.com/pbiecek/DALEX/issues/24
## important setup step required for use of rms functions
dd <- datadist(df)
options(datadist="dd")
## add restricted cubic spline (rcs) terms to the linear model
## a convenient way to allow non-linear effects without choosing their shape by hand
## still a "linear" model: it is linear in the coefficients and additive in the predictors
model_rms <- ols(y ~ rcs(X1) + rcs(X2) + rcs(X3) + rcs(X4) + rcs(X5), df)
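The same idea can be sketched without rms. Restricted cubic splines are natural cubic splines (linear beyond the outer knots), and `splines::ns()` builds such a basis; its knot placement differs from `rms::rcs()`, and `df = 4` is an arbitrary choice for this sketch. The model is still fitted by ordinary least squares, and because the spline space contains straight lines, its R^2 can never be lower than that of the plain linear model on the same data.

```r
library(splines)  # ships with R; ns() builds a natural cubic spline basis

# Re-create the simulated data (same seed and draw order as above)
set.seed(13)
N <- 250
df <- data.frame(X1 = runif(N), X2 = runif(N), X3 = runif(N),
                 X4 = runif(N), X5 = runif(N))
df$y <- with(df, ((X1 - 0.5) * 2)^2 - 0.5 + sin(X2 * 10) + X3^6 +
                 (X4 - 0.5) * 2 + abs(2 * X5 - 1))

# Plain linear model: one slope per predictor
fit_lin <- lm(y ~ X1 + X2 + X3 + X4 + X5, data = df)

# Spline-expanded model: each predictor becomes 4 basis columns, yet the
# model remains linear in its coefficients and additive across predictors
fit_spl <- lm(y ~ ns(X1, df = 4) + ns(X2, df = 4) + ns(X3, df = 4) +
                  ns(X4, df = 4) + ns(X5, df = 4), data = df)

c(linear = summary(fit_lin)$r.squared, splines = summary(fit_spl)$r.squared)
```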

ex_rf <- explain(model_rf, data = df, y = df$y)
ex_svm <- explain(model_svm, data = df, y = df$y)
ex_lm <- explain(model_lm, data = df, y = df$y)
ex_rms <- explain(model_rms, label = "rms", data = df, y = df$y)
ex_tr <- explain(NULL, data = df[,-1], 
                 predict_function = function(m, x) f(x[,1], x[,2], x[,3], x[,4], x[,5]), 
                 label = "True Model")
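The ground-truth explainer wraps `f` in a `predict_function` that ignores its model argument and selects columns by position, so it relies on `df[,-1]` keeping the order X1 to X5. A slightly more robust variant, sketched below with a hypothetical `predict_truth()`, looks columns up by name instead; assuming DALEX passes a data frame to the predict function, it could be supplied to `explain()` as `predict_function = predict_truth` in the same way. The check itself uses base R only.

```r
# Equivalent to the ground-truth f from the simulation
f <- function(x1, x2, x3, x4, x5) {
  ((x1 - 0.5) * 2)^2 - 0.5 + sin(x2 * 10) + x3^6 + (x4 - 0.5) * 2 + abs(2 * x5 - 1)
}

# Hypothetical name-based predict function: `m` is ignored (there is no fitted
# model) and columns are found by name, so column order does not matter
predict_truth <- function(m, x) {
  f(x$X1, x$X2, x$X3, x$X4, x$X5)
}

set.seed(13)
newdata <- data.frame(X1 = runif(5), X2 = runif(5), X3 = runif(5),
                      X4 = runif(5), X5 = runif(5))
shuffled <- newdata[, c(5, 3, 1, 4, 2)]  # same rows, columns reordered

p1 <- predict_truth(NULL, newdata)
p2 <- predict_truth(NULL, shuffled)
```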

4 Explainers

For X1 the profile should follow (2*x1 - 1)^2.

The linear model cannot capture this relation without preprocessing X1, the random forest picks up part of the shape, and the SVM comes closest.
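`model_profile()` with its default `type = "partial"` computes partial-dependence profiles. The hand-rolled version below (base R; `partial_dependence()` and `truth()` are hypothetical helpers, and DALEX may additionally subsample rows and choose its own grid) illustrates two facts behind the comments above: for the additive ground truth, the X1 profile is (2*x1 - 1)^2 shifted by a constant, while for `lm` the profile is always a straight line, whatever the data look like.

```r
# Hypothetical helper: for each grid value v, set `variable` to v in every
# row of `data`, predict, and average the predictions
partial_dependence <- function(predict_fun, data, variable, grid) {
  sapply(grid, function(v) {
    data[[variable]] <- v
    mean(predict_fun(data))
  })
}

# Re-create the simulated inputs (same seed and draw order as above)
set.seed(13)
N <- 250
df <- data.frame(X1 = runif(N), X2 = runif(N), X3 = runif(N),
                 X4 = runif(N), X5 = runif(N))

truth <- function(d) {
  (2 * d$X1 - 1)^2 - 0.5 + sin(10 * d$X2) + d$X3^6 + (2 * d$X4 - 1) + abs(2 * d$X5 - 1)
}
df$y <- truth(df)

grid <- seq(0, 1, length.out = 21)

# Additive truth: profile minus the X1 component is constant along the grid
pd_truth <- partial_dependence(truth, df, "X1", grid)
shift <- pd_truth - (2 * grid - 1)^2

# Linear model: the profile is a line, so its second differences vanish
fit_lm <- lm(y ~ X1 + X2 + X3 + X4 + X5, data = df)
pd_lm <- partial_dependence(function(d) predict(fit_lm, newdata = d), df, "X1", grid)
```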

library(ggplot2)
plot(model_profile(ex_rf, "X1"),
     model_profile(ex_svm, "X1"),
     model_profile(ex_lm, "X1"),
     model_profile(ex_rms, "X1"),
     model_profile(ex_tr, "X1")) +
  ggtitle("Responses for X1. Truth: y ~ (2*x1 - 1)^2")

For X2 the profile should follow sin(10 * x2).

The random forest recovers the shape, the SVM is not flexible enough to follow it fully, and the linear model finds no relation at all.

plot(model_profile(ex_rf, "X2"),
     model_profile(ex_svm, "X2"),
     model_profile(ex_lm, "X2"),
     model_profile(ex_rms, "X2"),
     model_profile(ex_tr, "X2")) +
  ggtitle("Responses for X2. Truth: y ~ sin(10 * x2)")

For X3 the profile should follow x3^6.

The random forest is still able to recover the shape, and the SVM and the linear model come close.

plot(model_profile(ex_rf, "X3"),
     model_profile(ex_svm, "X3"),
     model_profile(ex_lm, "X3"),
     model_profile(ex_rms, "X3"),
     model_profile(ex_tr, "X3")) +
  ggtitle("Responses for X3. Truth: y ~ x3^6")

For X4 the profile should follow 2 * x4 - 1.

The linear model does the best job (as expected), the SVM is still quite good, and the random forest profile is flattened towards the mean.
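Why the linear model does well here: the other four components are independent of X4, so ordinary least squares treats them as noise and the X4 slope estimates the true value 2. A base-R sketch that re-creates the simulated data (same seed and draw order as above) and inspects the fitted coefficients:

```r
set.seed(13)
N <- 250
df <- data.frame(X1 = runif(N), X2 = runif(N), X3 = runif(N),
                 X4 = runif(N), X5 = runif(N))
df$y <- with(df, ((X1 - 0.5) * 2)^2 - 0.5 + sin(X2 * 10) + X3^6 +
                 (X4 - 0.5) * 2 + abs(2 * X5 - 1))

fit <- lm(y ~ X1 + X2 + X3 + X4 + X5, data = df)

# The true slope for X4 is 2; the non-linear terms only inflate the residuals
round(coef(fit), 2)
round(confint(fit), 2)  # 95% confidence intervals
```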

plot(model_profile(ex_rf, "X4"),
     model_profile(ex_svm, "X4"),
     model_profile(ex_lm, "X4"),
     model_profile(ex_rms, "X4"),
     model_profile(ex_tr, "X4")) +
  ggtitle("Responses for X4. Truth: y ~ (2 * x4 - 1)")

For X5 the profile should follow |2 * x5 - 1|.

All models except the linear model (lm) capture the shape.
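The remark that the linear model cannot find these relations without prior preprocessing can be made concrete. If the true transformations are supplied as features (which requires knowing the ground truth, so this is only a sanity check and not a modelling strategy), `lm()` reproduces the response exactly: the intercept becomes -0.5 - 1 = -1.5, the X4 slope 2, and every other coefficient 1.

```r
set.seed(13)
N <- 250
df <- data.frame(X1 = runif(N), X2 = runif(N), X3 = runif(N),
                 X4 = runif(N), X5 = runif(N))
df$y <- with(df, ((X1 - 0.5) * 2)^2 - 0.5 + sin(X2 * 10) + X3^6 +
                 (X4 - 0.5) * 2 + abs(2 * X5 - 1))

# Hand-crafted features matching the true components exactly
fit_fe <- lm(y ~ I((2 * X1 - 1)^2) + sin(10 * X2) + I(X3^6) + X4 + abs(2 * X5 - 1),
             data = df)

# Intercept -1.5, slope 2 for X4, 1 for every other term; residuals are ~0
b <- coef(fit_fe)
round(b, 6)
```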

plot(model_profile(ex_rf, "X5"),
     model_profile(ex_svm, "X5"),
     model_profile(ex_lm, "X5"),
     model_profile(ex_rms, "X5"),
     model_profile(ex_tr, "X5")) +
  ggtitle("Responses for X5. Truth: y ~ |2 * x5 - 1|")

sessionInfo()
## R version 3.6.3 (2020-02-29)
## Platform: x86_64-w64-mingw32/x64 (64-bit)
## Running under: Windows 10 x64 (build 18363)
## 
## Matrix products: default
## 
## locale:
## [1] LC_COLLATE=Polish_Poland.1250  LC_CTYPE=Polish_Poland.1250   
## [3] LC_MONETARY=Polish_Poland.1250 LC_NUMERIC=C                  
## [5] LC_TIME=Polish_Poland.1250    
## 
## attached base packages:
## [1] stats     graphics  grDevices utils     datasets  methods   base     
## 
## other attached packages:
##  [1] rms_5.1-3           SparseM_1.78        Hmisc_4.3-1        
##  [4] ggplot2_3.3.0       Formula_1.2-3       survival_3.1-8     
##  [7] lattice_0.20-38     e1071_1.7-3         DALEX_2.0.1        
## [10] randomForest_4.6-14
## 
## loaded via a namespace (and not attached):
##  [1] Rcpp_1.0.4          mvtnorm_1.1-0       png_0.1-7          
##  [4] class_7.3-15        zoo_1.8-7           digest_0.6.25      
##  [7] R6_2.4.1            backports_1.1.5     acepack_1.4.1      
## [10] MatrixModels_0.4-1  evaluate_0.14       pillar_1.4.3       
## [13] rlang_0.4.6         multcomp_1.4-12     rstudioapi_0.11    
## [16] data.table_1.12.8   rpart_4.1-15        Matrix_1.2-18      
## [19] checkmate_2.0.0     rmarkdown_2.1       labeling_0.3       
## [22] splines_3.6.3       stringr_1.4.0       foreign_0.8-76     
## [25] htmlwidgets_1.5.1   munsell_0.5.0       compiler_3.6.3     
## [28] xfun_0.12           pkgconfig_2.0.3     base64enc_0.1-3    
## [31] htmltools_0.4.0     nnet_7.3-12         tidyselect_1.1.0   
## [34] tibble_2.1.3        gridExtra_2.3       htmlTable_1.13.3   
## [37] codetools_0.2-16    crayon_1.3.4        dplyr_1.0.0        
## [40] withr_2.1.2         MASS_7.3-51.5       grid_3.6.3         
## [43] nlme_3.1-144        polspline_1.1.17    gtable_0.3.0       
## [46] lifecycle_0.2.0     magrittr_1.5        scales_1.1.0       
## [49] stringi_1.4.6       farver_2.0.3        latticeExtra_0.6-29
## [52] generics_0.0.2      vctrs_0.3.1         sandwich_2.5-1     
## [55] TH.data_1.0-10      RColorBrewer_1.1-2  tools_3.6.3        
## [58] glue_1.3.2          purrr_0.3.3         ingredients_2.0    
## [61] jpeg_0.1-8.1        yaml_2.2.1          colorspace_1.4-1   
## [64] cluster_2.1.0       knitr_1.28          quantreg_5.54