augment() adds column(s) for predictions to the given data.
Usage
# S3 method for model_fit
augment(x, new_data, eval_time = NULL, ...)
Arguments
- x
A model_fit object produced by fit.model_spec() or fit_xy.model_spec().
- new_data
A data frame or matrix.
- eval_time
For censored regression models, a vector of time points at which the survival probability is estimated.
- ...
Not currently used.
Details
Regression
For regression models, a .pred column is added. If x was created using fit.model_spec() and new_data contains a regression outcome column, a .resid column is also added.
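The two regression columns can be reproduced with base R alone, which shows what they contain. The sketch below uses stats::lm() and predict() instead of parsnip. It assumes that the "lm" engine's predictions match predict.lm() and that .resid is the observed outcome minus .pred, which is consistent with the example output for this page. The helper augment_sketch() is hypothetical and is not part of parsnip.

```r
# Hypothetical helper that mimics the regression columns added by augment().
# It uses base R only and is not parsnip's implementation.
augment_sketch <- function(fit, new_data, outcome) {
  pred <- unname(predict(fit, newdata = new_data))
  out <- data.frame(.pred = pred)
  # A residual can only be computed when the outcome column is present
  if (outcome %in% names(new_data)) {
    out$.resid <- new_data[[outcome]] - pred
  }
  # Prediction columns come first, followed by the original columns
  cbind(out, new_data)
}

car_trn <- mtcars[11:32, ]
car_tst <- mtcars[1:10, ]
fit <- lm(mpg ~ ., data = car_trn)

with_outcome    <- augment_sketch(fit, car_tst, "mpg")
without_outcome <- augment_sketch(fit, car_tst[, -1], "mpg")
```

When the outcome column is dropped (car_tst[, -1]), the result has only .pred. This matches the column counts in the formula-based examples (13 columns with mpg, 11 without).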
Classification
For classification models, the results can include a hard class prediction column called .pred_class, as well as one class probability column per outcome level, named .pred_{level}. Which of these columns appear depends on the prediction types that the model supports.
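The next sketch shows how the classification columns relate to one another. It uses base R's glm() on mtcars with a hypothetical two-level factor outcome, trans, built from am. It assumes that the hard class is the level with the larger probability, which means a 0.5 cutoff for two classes. The column names .pred_automatic and .pred_manual follow the .pred_{level} pattern; they are not produced by parsnip here.

```r
# Hypothetical two-level factor outcome; glm() models P(second level)
cars_df <- mtcars
cars_df$trans <- factor(ifelse(cars_df$am == 1, "manual", "automatic"),
                        levels = c("automatic", "manual"))

fit <- glm(trans ~ wt, data = cars_df, family = binomial)

# Probability of the second level ("manual") for each row
p_manual <- unname(predict(fit, newdata = cars_df, type = "response"))

# One probability column per level, named .pred_{level}
probs <- data.frame(.pred_automatic = 1 - p_manual,
                    .pred_manual    = p_manual)

# Hard class prediction: the level with the larger probability
pred_class <- factor(
  ifelse(probs$.pred_manual > 0.5, "manual", "automatic"),
  levels = levels(cars_df$trans)
)

augmented <- cbind(data.frame(.pred_class = pred_class), probs,
                   cars_df[, c("wt", "trans")])
```

The probability columns sum to one in each row. .pred_class is a factor with the same levels as the outcome, as in the two_class_dat examples.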
Censored Regression
For these models, predictions of the expected time and the survival probability are created if the model engine supports them. If the model supports survival prediction, the eval_time argument is required.
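To make "survival probability at each eval_time" concrete, the sketch below evaluates a fitted Weibull survival curve at a few time points. The model is fitted directly with survival::survreg() rather than through parsnip. It uses the Weibull parameterization documented for survreg: shape = 1 / scale and scale = exp(linear predictor). The evaluation times are arbitrary illustrative values, and this is not how parsnip computes the predictions internally.

```r
library(survival)

# Parametric (Weibull) survival model on the lung data shipped with survival
fit <- survreg(Surv(time, status) ~ age, data = lung, dist = "weibull")

# Illustrative evaluation times, in days (hypothetical choice)
eval_time <- c(100, 200, 365)

# Linear predictor for a few observations
new_data <- lung[1:5, c("time", "status", "age")]
lp <- predict(fit, newdata = new_data, type = "lp")

# S(t) = P(T > t) under survreg's Weibull parameterization
surv_prob <- sapply(eval_time, function(t) {
  pweibull(t, shape = 1 / fit$scale, scale = exp(lp), lower.tail = FALSE)
})
colnames(surv_prob) <- paste0("t_", eval_time)
```

Each row of surv_prob is one observation and each column is one evaluation time. The probabilities do not increase along a row, which is why a vector of eval_time values produces one survival estimate per time point.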
If survival predictions are created and new_data contains a survival::Surv() object, additional columns for inverse probability of censoring weights (IPCW) are also created (see the tidymodels.org page in the references below). These weights allow users to compute survival performance metrics with the yardstick package.
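The idea behind IPCW can be sketched with the survival package. The censoring distribution G(t) = P(C > t) is estimated with a "reverse" Kaplan-Meier fit that treats censoring as the event. Each observation then receives a weight at the evaluation time. This is a simplified version of a common weighting scheme for censored performance metrics. The tidymodels implementation may differ in details (for example, how times are adjusted), and the variable names here are illustrative.

```r
library(survival)

time  <- lung$time
event <- as.integer(lung$status == 2)  # 1 = death observed, 0 = censored

# Reverse Kaplan-Meier: censoring is the "event", giving G(t) = P(C > t)
cens_fit <- survfit(Surv(time, 1 - event) ~ 1)
G_right <- stepfun(cens_fit$time, c(1, cens_fit$surv))               # G(t)
G_left  <- stepfun(cens_fit$time, c(1, cens_fit$surv), right = TRUE) # G(t-)

eval_time <- 365

# Simplified weights at eval_time:
#   event observed by eval_time -> 1 / G(t_i-)
#   still at risk at eval_time  -> 1 / G(eval_time)
#   censored before eval_time   -> 0 (status unknown at eval_time)
ipcw <- ifelse(
  time <= eval_time & event == 1, 1 / G_left(time),
  ifelse(time > eval_time, 1 / G_right(eval_time), 0)
)
```

Observations censored before the evaluation time get zero weight. The remaining observations are up-weighted to account for the subjects lost to censoring.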
Examples
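library(parsnip)

# Rows 11-32 of mtcars for training, rows 1-10 for testing
car_trn <- mtcars[11:32, ]
car_tst <- mtcars[1:10, ]

# Fit with the formula interface
reg_form <-
  linear_reg() %>%
  set_engine("lm") %>%
  fit(mpg ~ ., data = car_trn)

# Fit with the x/y interface (column 1 is the outcome, mpg)
reg_xy <-
  linear_reg() %>%
  set_engine("lm") %>%
  fit_xy(car_trn[, -1], car_trn$mpg)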
augment(reg_form, car_tst)
#> # A tibble: 10 × 13
#> .pred .resid mpg cyl disp hp drat wt qsec vs am
#> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 23.4 -2.43 21 6 160 110 3.9 2.62 16.5 0 1
#> 2 23.3 -2.30 21 6 160 110 3.9 2.88 17.0 0 1
#> 3 27.6 -4.83 22.8 4 108 93 3.85 2.32 18.6 1 1
#> 4 21.5 -0.147 21.4 6 258 110 3.08 3.22 19.4 1 0
#> 5 17.6 1.13 18.7 8 360 175 3.15 3.44 17.0 0 0
#> 6 21.6 -3.48 18.1 6 225 105 2.76 3.46 20.2 1 0
#> 7 13.9 0.393 14.3 8 360 245 3.21 3.57 15.8 0 0
#> 8 21.7 2.70 24.4 4 147. 62 3.69 3.19 20 1 0
#> 9 25.6 -2.81 22.8 4 141. 95 3.92 3.15 22.9 1 0
#> 10 17.1 2.09 19.2 6 168. 123 3.92 3.44 18.3 1 0
#> # ℹ 2 more variables: gear <dbl>, carb <dbl>
augment(reg_form, car_tst[, -1])
#> # A tibble: 10 × 11
#> .pred cyl disp hp drat wt qsec vs am gear carb
#> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 23.4 6 160 110 3.9 2.62 16.5 0 1 4 4
#> 2 23.3 6 160 110 3.9 2.88 17.0 0 1 4 4
#> 3 27.6 4 108 93 3.85 2.32 18.6 1 1 4 1
#> 4 21.5 6 258 110 3.08 3.22 19.4 1 0 3 1
#> 5 17.6 8 360 175 3.15 3.44 17.0 0 0 3 2
#> 6 21.6 6 225 105 2.76 3.46 20.2 1 0 3 1
#> 7 13.9 8 360 245 3.21 3.57 15.8 0 0 3 4
#> 8 21.7 4 147. 62 3.69 3.19 20 1 0 4 2
#> 9 25.6 4 141. 95 3.92 3.15 22.9 1 0 4 2
#> 10 17.1 6 168. 123 3.92 3.44 18.3 1 0 4 4
augment(reg_xy, car_tst)
#> # A tibble: 10 × 12
#> .pred mpg cyl disp hp drat wt qsec vs am gear carb
#> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 23.4 21 6 160 110 3.9 2.62 16.5 0 1 4 4
#> 2 23.3 21 6 160 110 3.9 2.88 17.0 0 1 4 4
#> 3 27.6 22.8 4 108 93 3.85 2.32 18.6 1 1 4 1
#> 4 21.5 21.4 6 258 110 3.08 3.22 19.4 1 0 3 1
#> 5 17.6 18.7 8 360 175 3.15 3.44 17.0 0 0 3 2
#> 6 21.6 18.1 6 225 105 2.76 3.46 20.2 1 0 3 1
#> 7 13.9 14.3 8 360 245 3.21 3.57 15.8 0 0 3 4
#> 8 21.7 24.4 4 147. 62 3.69 3.19 20 1 0 4 2
#> 9 25.6 22.8 4 141. 95 3.92 3.15 22.9 1 0 4 2
#> 10 17.1 19.2 6 168. 123 3.92 3.44 18.3 1 0 4 4
augment(reg_xy, car_tst[, -1])
#> # A tibble: 10 × 11
#> .pred cyl disp hp drat wt qsec vs am gear carb
#> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 23.4 6 160 110 3.9 2.62 16.5 0 1 4 4
#> 2 23.3 6 160 110 3.9 2.88 17.0 0 1 4 4
#> 3 27.6 4 108 93 3.85 2.32 18.6 1 1 4 1
#> 4 21.5 6 258 110 3.08 3.22 19.4 1 0 3 1
#> 5 17.6 8 360 175 3.15 3.44 17.0 0 0 3 2
#> 6 21.6 6 225 105 2.76 3.46 20.2 1 0 3 1
#> 7 13.9 8 360 245 3.21 3.57 15.8 0 0 3 4
#> 8 21.7 4 147. 62 3.69 3.19 20 1 0 4 2
#> 9 25.6 4 141. 95 3.92 3.15 22.9 1 0 4 2
#> 10 17.1 6 168. 123 3.92 3.44 18.3 1 0 4 4
# ------------------------------------------------------------------------------
data(two_class_dat, package = "modeldata")
cls_trn <- two_class_dat[-(1:10), ]
cls_tst <- two_class_dat[1:10, ]

# Fit with the formula interface
cls_form <-
  logistic_reg() %>%
  set_engine("glm") %>%
  fit(Class ~ ., data = cls_trn)

# Fit with the x/y interface (column 3 is the outcome, Class)
cls_xy <-
  logistic_reg() %>%
  set_engine("glm") %>%
  fit_xy(cls_trn[, -3], cls_trn$Class)
augment(cls_form, cls_tst)
#> # A tibble: 10 × 6
#> .pred_class .pred_Class1 .pred_Class2 A B Class
#> <fct> <dbl> <dbl> <dbl> <dbl> <fct>
#> 1 Class1 0.518 0.482 2.07 1.63 Class1
#> 2 Class1 0.909 0.0913 2.02 1.04 Class1
#> 3 Class1 0.648 0.352 1.69 1.37 Class2
#> 4 Class1 0.610 0.390 3.43 1.98 Class2
#> 5 Class2 0.443 0.557 2.88 1.98 Class1
#> 6 Class2 0.206 0.794 3.31 2.41 Class2
#> 7 Class1 0.708 0.292 2.50 1.56 Class2
#> 8 Class1 0.567 0.433 1.98 1.55 Class2
#> 9 Class1 0.994 0.00582 2.88 0.580 Class1
#> 10 Class2 0.108 0.892 3.74 2.74 Class2
augment(cls_form, cls_tst[, -3])
#> # A tibble: 10 × 5
#> .pred_class .pred_Class1 .pred_Class2 A B
#> <fct> <dbl> <dbl> <dbl> <dbl>
#> 1 Class1 0.518 0.482 2.07 1.63
#> 2 Class1 0.909 0.0913 2.02 1.04
#> 3 Class1 0.648 0.352 1.69 1.37
#> 4 Class1 0.610 0.390 3.43 1.98
#> 5 Class2 0.443 0.557 2.88 1.98
#> 6 Class2 0.206 0.794 3.31 2.41
#> 7 Class1 0.708 0.292 2.50 1.56
#> 8 Class1 0.567 0.433 1.98 1.55
#> 9 Class1 0.994 0.00582 2.88 0.580
#> 10 Class2 0.108 0.892 3.74 2.74
augment(cls_xy, cls_tst)
#> # A tibble: 10 × 6
#> .pred_class .pred_Class1 .pred_Class2 A B Class
#> <fct> <dbl> <dbl> <dbl> <dbl> <fct>
#> 1 Class1 0.518 0.482 2.07 1.63 Class1
#> 2 Class1 0.909 0.0913 2.02 1.04 Class1
#> 3 Class1 0.648 0.352 1.69 1.37 Class2
#> 4 Class1 0.610 0.390 3.43 1.98 Class2
#> 5 Class2 0.443 0.557 2.88 1.98 Class1
#> 6 Class2 0.206 0.794 3.31 2.41 Class2
#> 7 Class1 0.708 0.292 2.50 1.56 Class2
#> 8 Class1 0.567 0.433 1.98 1.55 Class2
#> 9 Class1 0.994 0.00582 2.88 0.580 Class1
#> 10 Class2 0.108 0.892 3.74 2.74 Class2
augment(cls_xy, cls_tst[, -3])
#> # A tibble: 10 × 5
#> .pred_class .pred_Class1 .pred_Class2 A B
#> <fct> <dbl> <dbl> <dbl> <dbl>
#> 1 Class1 0.518 0.482 2.07 1.63
#> 2 Class1 0.909 0.0913 2.02 1.04
#> 3 Class1 0.648 0.352 1.69 1.37
#> 4 Class1 0.610 0.390 3.43 1.98
#> 5 Class2 0.443 0.557 2.88 1.98
#> 6 Class2 0.206 0.794 3.31 2.41
#> 7 Class1 0.708 0.292 2.50 1.56
#> 8 Class1 0.567 0.433 1.98 1.55
#> 9 Class1 0.994 0.00582 2.88 0.580
#> 10 Class2 0.108 0.892 3.74 2.74