Skip to content

Apply a model to create different types of predictions. predict() can be used for all types of models and uses the "type" argument for more specificity.

Usage

# S3 method for model_fit
predict(object, new_data, type = NULL, opts = list(), ...)

# S3 method for model_fit
predict_raw(object, new_data, opts = list(), ...)

predict_raw(object, ...)

Arguments

object

An object of class model_fit.

new_data

A rectangular data object, such as a data frame.

type

A single character value or NULL. Possible values are "numeric", "class", "prob", "conf_int", "pred_int", "quantile", "time", "hazard", "survival", or "raw". When NULL, predict() will choose an appropriate value based on the model's mode.

opts

A list of optional arguments to the underlying predict function that will be used when type = "raw". The list should not include options for the model object or the new data being predicted.

...

Additional parsnip-related options, depending on the value of type. Arguments to the underlying model's prediction function cannot be passed here (use the opts argument instead). Possible arguments are:

  • interval: for type equal to "survival" or "quantile", should interval estimates be added, if available? Options are "none" and "confidence".

  • level: for type equal to "conf_int", "pred_int", or "survival", this is the parameter for the tail area of the intervals (e.g. confidence level for confidence intervals). Default value is 0.95.

  • std_error: for type equal to "conf_int" or "pred_int", add the standard error of fit or prediction (on the scale of the linear predictors). Default value is FALSE.

  • quantile: for type equal to quantile, the quantiles of the distribution. Default is (1:9)/10.

  • eval_time: for type equal to "survival" or "hazard", the time points at which the survival probability or hazard is estimated.

Value

With the exception of type = "raw", the result of predict.model_fit()

  • is a tibble

  • has as many rows as there are rows in new_data

  • has standardized column names, see below:

For type = "numeric", the tibble has a .pred column for a single outcome and .pred_Yname columns for a multivariate outcome.

For type = "class", the tibble has a .pred_class column.

For type = "prob", the tibble has .pred_classlevel columns.

For type = "conf_int" and type = "pred_int", the tibble has .pred_lower and .pred_upper columns with an attribute for the confidence level. In the case where intervals can be produces for class probabilities (or other non-scalar outputs), the columns are named .pred_lower_classlevel and so on.

For type = "quantile", the tibble has a .pred column, which is a list-column. Each list element contains a tibble with columns .pred and .quantile (and perhaps other columns).

For type = "time", the tibble has a .pred_time column.

For type = "survival", the tibble has a .pred column, which is a list-column. Each list element contains a tibble with columns .eval_time and .pred_survival (and perhaps other columns).

For type = "hazard", the tibble has a .pred column, which is a list-column. Each list element contains a tibble with columns .eval_time and .pred_hazard (and perhaps other columns).

Using type = "raw" with predict.model_fit() will return the unadulterated results of the prediction function.

In the case of Spark-based models, since table columns cannot contain dots, the same convention is used except 1) no dots appear in names and 2) vectors are never returned but type-specific prediction functions.

When the model fit failed and the error was captured, the predict() function will return the same structure as above but filled with missing values. This does not currently work for multivariate models.

Details

For type = NULL, predict() uses

  • type = "numeric" for regression models,

  • type = "class" for classification, and

  • type = "time" for censored regression.

Interval predictions

When using type = "conf_int" and type = "pred_int", the options level and std_error can be used. The latter is a logical for an extra column of standard error values (if available).

Censored regression predictions

For censored regression, a numeric vector for eval_time is required when survival or hazard probabilities are requested. The time values are required to be unique, finite, non-missing, and non-negative. The predict() functions will adjust the values to fit this specification by removing offending points (with a warning).

predict.model_fit() does not require the outcome to be present. For performance metrics on the predicted survival probability, inverse probability of censoring weights (IPCW) are required (see the tidymodels.org reference below). Those require the outcome and are thus not returned by predict(). They can be added via augment.model_fit() if new_data contains a column with the outcome as a Surv object.

Also, when type = "linear_pred", censored regression models will by default be formatted such that the linear predictor increases with time. This may have the opposite sign as what the underlying model's predict() method produces. Set increasing = FALSE to suppress this behavior.

Examples

library(dplyr)

lm_model <-
  linear_reg() %>%
  set_engine("lm") %>%
  fit(mpg ~ ., data = mtcars %>% dplyr::slice(11:32))

pred_cars <-
  mtcars %>%
  dplyr::slice(1:10) %>%
  dplyr::select(-mpg)

predict(lm_model, pred_cars)
#> # A tibble: 10 × 1
#>    .pred
#>    <dbl>
#>  1  23.4
#>  2  23.3
#>  3  27.6
#>  4  21.5
#>  5  17.6
#>  6  21.6
#>  7  13.9
#>  8  21.7
#>  9  25.6
#> 10  17.1

predict(
  lm_model,
  pred_cars,
  type = "conf_int",
  level = 0.90
)
#> # A tibble: 10 × 2
#>    .pred_lower .pred_upper
#>          <dbl>       <dbl>
#>  1       17.9         29.0
#>  2       18.1         28.5
#>  3       24.0         31.3
#>  4       17.5         25.6
#>  5       14.3         20.8
#>  6       17.0         26.2
#>  7        9.65        18.2
#>  8       16.2         27.2
#>  9       14.2         37.0
#> 10       11.5         22.7

predict(
  lm_model,
  pred_cars,
  type = "raw",
  opts = list(type = "terms")
)
#>                            cyl       disp         hp        drat
#> Mazda RX4         -0.001433177 -0.8113275  0.6303467 -0.06120265
#> Mazda RX4 Wag     -0.001433177 -0.8113275  0.6303467 -0.06120265
#> Datsun 710        -0.009315653 -1.3336453  0.8557288 -0.05014798
#> Hornet 4 Drive    -0.001433177  0.1730406  0.6303467  0.12009386
#> Hornet Sportabout  0.006449298  1.1975870 -0.2314083  0.10461733
#> Valiant           -0.001433177 -0.1584303  0.6966356  0.19084372
#> Duster 360         0.006449298  1.1975870 -1.1594522  0.09135173
#> Merc 240D         -0.009315653 -0.9449204  1.2667197 -0.01477305
#> Merc 230          -0.009315653 -1.0041833  0.8292133 -0.06562451
#> Merc 280          -0.001433177 -0.7349888  0.4579957 -0.06562451
#>                           wt      qsec         vs       am        gear
#> Mazda RX4          2.4139815 -1.567729  0.2006406  2.88774  0.02512680
#> Mazda RX4 Wag      1.4488706 -0.736286  0.2006406  2.88774  0.02512680
#> Datsun 710         3.5494061  1.624418 -0.3511210  2.88774  0.02512680
#> Hornet 4 Drive     0.1620561  2.856736 -0.3511210 -2.40645 -0.06700481
#> Hornet Sportabout -0.6895124 -0.736286  0.2006406 -2.40645 -0.06700481
#> Valiant           -0.7652074  4.014817 -0.3511210 -2.40645 -0.06700481
#> Duster 360        -1.1815297 -2.488255  0.2006406 -2.40645 -0.06700481
#> Merc 240D          0.2566748  3.688179 -0.3511210 -2.40645  0.02512680
#> Merc 230           0.4080647  7.993866 -0.3511210 -2.40645  0.02512680
#> Merc 280          -0.6895124  1.164155 -0.3511210 -2.40645  0.02512680
#>                         carb
#> Mazda RX4         -0.2497240
#> Mazda RX4 Wag     -0.2497240
#> Datsun 710         0.4668753
#> Hornet 4 Drive     0.4668753
#> Hornet Sportabout  0.2280089
#> Valiant            0.4668753
#> Duster 360        -0.2497240
#> Merc 240D          0.2280089
#> Merc 230           0.2280089
#> Merc 280          -0.2497240
#> attr(,"constant")
#> [1] 19.96364