A Predictor object holds any machine learning model (mlr, caret, randomForest, ...) and the data to be used for analyzing the model. The interpretation methods in the iml package need the machine learning model to be wrapped in a Predictor object.
A Predictor object is a container for the prediction model and the data. This ensures that the machine learning model can be analyzed in a robust way.
Note: For classification, the model should return one column per class, containing the predicted probability of that class.
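To make this requirement concrete, the sketch below builds two prediction data.frames by hand and checks their shape. The helper check_class_probs() is hypothetical and not part of iml. The probability values are illustrative, not model output. The helper only tests for one numeric column per class and rows that sum to 1.

```r
# Hypothetical helper (not part of iml): checks that a classification
# prediction has one numeric probability column per class and that
# every row sums to 1 (within a tolerance).
check_class_probs <- function(pred, classes, tol = 1e-8) {
  is.data.frame(pred) &&
    setequal(colnames(pred), classes) &&
    all(vapply(pred, is.numeric, logical(1))) &&
    all(abs(rowSums(pred) - 1) < tol)
}

classes <- levels(iris$Species)

# Expected shape: one column per class holding class probabilities
good <- data.frame(setosa = c(1, 0), versicolor = c(0, 0.95),
                   virginica = c(0, 0.05))

# Not the expected shape: a single column of predicted labels
bad <- data.frame(class = factor(c("setosa", "versicolor"), levels = classes))

check_class_probs(good, classes)
#> [1] TRUE
check_class_probs(bad, classes)
#> [1] FALSE
```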
data (data.frame): Data object with the data for the model interpretation.
model (any): The machine learning model.
batch.size (numeric(1)): The number of rows passed to the model for prediction at once.
class (character(1)): The class column to be returned.
prediction.colnames (character): The column names of the predictions.
prediction.function (function): The function to predict newdata.
task (character(1)): The inferred prediction task: "classification" or "regression".
new(): Create a Predictor object.
Predictor$new(
model = NULL,
data = NULL,
predict.function = NULL,
y = NULL,
class = NULL,
type = NULL,
batch.size = 1000
)
model (any): The machine learning model. Models from mlr and caret are recommended. Other machine learning models with an S3 predict method also work, but less robustly (e.g. randomForest).
data (data.frame): The data to be used for analyzing the prediction model. Allowed column classes are numeric, factor, integer, ordered and character. For some models the data can be extracted automatically; Predictor$new() throws an error when it cannot extract the data automatically.
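Before calling Predictor$new(), it can help to check that every column of data has one of the allowed classes. The helper unsupported_columns() below is a hypothetical pre-flight check, not part of iml, and iml's own validation may differ. In R, integer columns count as numeric and ordered factors count as factors, so three predicates cover all five allowed classes.

```r
# Hypothetical pre-flight check (not part of iml): returns the names of
# columns whose class is not numeric, integer, factor, ordered or character.
unsupported_columns <- function(data) {
  ok <- vapply(data, function(col) {
    # is.numeric() is TRUE for integer; is.factor() is TRUE for ordered
    is.numeric(col) || is.factor(col) || is.character(col)
  }, logical(1))
  names(data)[!ok]
}

unsupported_columns(iris)
#> character(0)

# A Date column is not among the allowed classes
df <- data.frame(x = 1:3, day = as.Date("2024-01-01") + 0:2)
unsupported_columns(df)
#> [1] "day"
```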
predict.function (function): The function to predict newdata. Only needed if model is not a model from the mlr or caret packages. The first argument of predict.function has to be the model and the second the newdata, i.e. function(model, newdata).
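As an illustration, the sketch below writes a predict.function for a base-R logistic regression (stats::glm) fitted on a two-class subset of iris. The function takes the model and newdata and returns one probability column per class, as required for classification. The feature choice and the names fit and pred_fun are illustrative assumptions. The Predictor$new() call runs only if iml is installed.

```r
# Two-class subset of iris: versicolor vs. virginica
iris2 <- droplevels(subset(iris, Species != "setosa"))
fit <- glm(Species ~ Petal.Length + Petal.Width, data = iris2,
           family = binomial())

# predict.function: first argument the model, second the newdata.
# For a binomial glm, type = "response" returns P(second factor level),
# here P(virginica); the versicolor column is its complement.
pred_fun <- function(model, newdata) {
  p <- predict(model, newdata = newdata, type = "response")
  data.frame(versicolor = 1 - p, virginica = p)
}

# Row 1 of iris2 is a versicolor flower, row 51 a virginica flower
probs <- pred_fun(fit, iris2[c(1, 51), ])
print(probs)

if (requireNamespace("iml", quietly = TRUE)) {
  predictor <- iml::Predictor$new(fit, data = iris2, y = "Species",
                                  predict.function = pred_fun)
  print(predictor$predict(iris2[c(1, 51), ]))
}
```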
y (character(1) | numeric | factor): The target vector or (preferably) the name of the target column in the data argument. Predictor tries to infer the target automatically from the model.
class (character(1)): The class column to be returned. You should use the column name of the predicted class, e.g. class = "setosa".
type (character(1)): This argument is passed to the prediction function of the model. For regression models you usually don't have to provide the type argument. The classic use case is type = "prob" for classification models. Consult the documentation of the machine learning package you use to find out which type options are available. If both predict.function and type are used, type is passed as an argument to predict.function.
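To see why type matters, the base-R sketch below (independent of iml) calls predict() on a binomial glm with two different type values. The default "link" returns log-odds, while "response" returns probabilities. With the wrong type, the interpretation methods would silently receive values on a different scale. The model and data subset are illustrative choices.

```r
iris2 <- droplevels(subset(iris, Species != "setosa"))
fit <- glm(Species ~ Petal.Length + Petal.Width, data = iris2,
           family = binomial())
newdata <- iris2[1:3, ]

# Default type = "link": linear predictor on the log-odds scale
link <- predict(fit, newdata = newdata)
# type = "response": probability of the second level (virginica)
prob <- predict(fit, newdata = newdata, type = "response")

# The two scales are related by the logistic function
all.equal(plogis(link), prob)
#> [1] TRUE
```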
batch.size (numeric(1)): The maximum number of rows passed to the model for prediction at once. Currently only respected for FeatureImp, Partial and Interaction.
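The idea behind batch.size can be sketched in base R: split newdata into chunks of at most batch.size rows, predict each chunk, and stack the results. The helper predict_in_batches() is hypothetical and only mirrors the concept. It is not iml's internal implementation and assumes a model whose predict() returns a plain vector (here a regression lm).

```r
# Hypothetical helper (not iml's implementation): predict in chunks of
# at most batch.size rows and concatenate the resulting vectors.
predict_in_batches <- function(model, newdata, batch.size = 1000) {
  n <- nrow(newdata)
  stopifnot(n > 0, batch.size >= 1)
  starts <- seq(1, n, by = batch.size)
  chunks <- lapply(starts, function(s) {
    rows <- s:min(s + batch.size - 1, n)
    predict(model, newdata = newdata[rows, , drop = FALSE])
  })
  unlist(chunks)
}

fit <- lm(Sepal.Length ~ ., data = iris)

# 150 rows in batches of 40: four predict() calls with 40, 40, 40, 30 rows
batched <- predict_in_batches(fit, iris, batch.size = 40)
full <- predict(fit, newdata = iris)
all.equal(batched, full)
#> [1] TRUE
```

Batching only bounds the number of rows per call. It does not change the predictions themselves, which is why the batched and unbatched results agree.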
predict(): Predict new data with the machine learning model.
newdata (data.frame): Data to predict on.
library("mlr")
task <- makeClassifTask(data = iris, target = "Species")
learner <- makeLearner("classif.rpart", minsplit = 7, predict.type = "prob")
mod.mlr <- train(learner, task)
mod <- Predictor$new(mod.mlr, data = iris)
mod$predict(iris[1:5, ])
#> setosa versicolor virginica
#> 1 1 0 0
#> 2 1 0 0
#> 3 1 0 0
#> 4 1 0 0
#> 5 1 0 0
mod <- Predictor$new(mod.mlr, data = iris, class = "setosa")
mod$predict(iris[1:5, ])
#> setosa
#> 1 1
#> 2 1
#> 3 1
#> 4 1
#> 5 1
library("randomForest")
rf <- randomForest(Species ~ ., data = iris, ntree = 20)
mod <- Predictor$new(rf, data = iris, type = "prob")
mod$predict(iris[50:55, ])
#> setosa versicolor virginica
#> 1 1 0.00 0.00
#> 2 0 0.95 0.05
#> 3 0 0.95 0.05
#> 4 0 0.95 0.05
#> 5 0 1.00 0.00
#> 6 0 1.00 0.00
# Feature importance needs the target vector, which needs to be supplied:
mod <- Predictor$new(rf, data = iris, y = "Species", type = "prob")