augment() will add column(s) for predictions to the given data.
Usage
# S3 method for model_fit
augment(x, new_data, ...)
Arguments
- x
A model_fit object produced by fit.model_spec() or fit_xy.model_spec().
- new_data
A data frame or matrix.
- ...
Not currently used.
Details
For regression models, a .pred column is added. If x was created using
fit.model_spec() and new_data contains the outcome column, a .resid column is
also added.
For classification models, the results can include a column called
.pred_class as well as class probability columns named .pred_{level}.
Which of these columns appear depends on the prediction types available for the model.
Examples
# Regression: rows 11-32 of `mtcars` train the model, rows 1-10 are held out.
car_trn <- mtcars[11:32, ]
car_tst <- mtcars[1:10, ]
# The same linear model, fit once with a formula (fit()) and once with
# separate predictors/outcome (fit_xy()); column 1 of `mtcars` is `mpg`.
reg_form <-
  linear_reg() %>%
  set_engine("lm") %>%
  fit(mpg ~ ., data = car_trn)
reg_xy <-
  linear_reg() %>%
  set_engine("lm") %>%
  fit_xy(x = car_trn[, -1], y = car_trn$mpg)
# Formula fit with the outcome present in new_data: `.pred` plus `.resid`.
augment(reg_form, new_data = car_tst)
#> # A tibble: 10 × 13
#> mpg cyl disp hp drat wt qsec vs am gear carb .pred
#> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 21 6 160 110 3.9 2.62 16.5 0 1 4 4 23.4
#> 2 21 6 160 110 3.9 2.88 17.0 0 1 4 4 23.3
#> 3 22.8 4 108 93 3.85 2.32 18.6 1 1 4 1 27.6
#> 4 21.4 6 258 110 3.08 3.22 19.4 1 0 3 1 21.5
#> 5 18.7 8 360 175 3.15 3.44 17.0 0 0 3 2 17.6
#> 6 18.1 6 225 105 2.76 3.46 20.2 1 0 3 1 21.6
#> 7 14.3 8 360 245 3.21 3.57 15.8 0 0 3 4 13.9
#> 8 24.4 4 147. 62 3.69 3.19 20 1 0 4 2 21.7
#> 9 22.8 4 141. 95 3.92 3.15 22.9 1 0 4 2 25.6
#> 10 19.2 6 168. 123 3.92 3.44 18.3 1 0 4 4 17.1
#> # … with 1 more variable: .resid <dbl>
# Outcome column dropped from new_data: only `.pred` is added.
augment(reg_form, new_data = car_tst[, -1])
#> # A tibble: 10 × 11
#> cyl disp hp drat wt qsec vs am gear carb .pred
#> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 6 160 110 3.9 2.62 16.5 0 1 4 4 23.4
#> 2 6 160 110 3.9 2.88 17.0 0 1 4 4 23.3
#> 3 4 108 93 3.85 2.32 18.6 1 1 4 1 27.6
#> 4 6 258 110 3.08 3.22 19.4 1 0 3 1 21.5
#> 5 8 360 175 3.15 3.44 17.0 0 0 3 2 17.6
#> 6 6 225 105 2.76 3.46 20.2 1 0 3 1 21.6
#> 7 8 360 245 3.21 3.57 15.8 0 0 3 4 13.9
#> 8 4 147. 62 3.69 3.19 20 1 0 4 2 21.7
#> 9 4 141. 95 3.92 3.15 22.9 1 0 4 2 25.6
#> 10 6 168. 123 3.92 3.44 18.3 1 0 4 4 17.1
# fit_xy() model: `.pred` is added but no `.resid`, even with `mpg` present.
augment(reg_xy, new_data = car_tst)
#> # A tibble: 10 × 12
#> mpg cyl disp hp drat wt qsec vs am gear carb .pred
#> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 21 6 160 110 3.9 2.62 16.5 0 1 4 4 23.4
#> 2 21 6 160 110 3.9 2.88 17.0 0 1 4 4 23.3
#> 3 22.8 4 108 93 3.85 2.32 18.6 1 1 4 1 27.6
#> 4 21.4 6 258 110 3.08 3.22 19.4 1 0 3 1 21.5
#> 5 18.7 8 360 175 3.15 3.44 17.0 0 0 3 2 17.6
#> 6 18.1 6 225 105 2.76 3.46 20.2 1 0 3 1 21.6
#> 7 14.3 8 360 245 3.21 3.57 15.8 0 0 3 4 13.9
#> 8 24.4 4 147. 62 3.69 3.19 20 1 0 4 2 21.7
#> 9 22.8 4 141. 95 3.92 3.15 22.9 1 0 4 2 25.6
#> 10 19.2 6 168. 123 3.92 3.44 18.3 1 0 4 4 17.1
augment(reg_xy, new_data = car_tst[, -1])
#> # A tibble: 10 × 11
#> cyl disp hp drat wt qsec vs am gear carb .pred
#> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#> 1 6 160 110 3.9 2.62 16.5 0 1 4 4 23.4
#> 2 6 160 110 3.9 2.88 17.0 0 1 4 4 23.3
#> 3 4 108 93 3.85 2.32 18.6 1 1 4 1 27.6
#> 4 6 258 110 3.08 3.22 19.4 1 0 3 1 21.5
#> 5 8 360 175 3.15 3.44 17.0 0 0 3 2 17.6
#> 6 6 225 105 2.76 3.46 20.2 1 0 3 1 21.6
#> 7 8 360 245 3.21 3.57 15.8 0 0 3 4 13.9
#> 8 4 147. 62 3.69 3.19 20 1 0 4 2 21.7
#> 9 4 141. 95 3.92 3.15 22.9 1 0 4 2 25.6
#> 10 6 168. 123 3.92 3.44 18.3 1 0 4 4 17.1
# ------------------------------------------------------------------------------
# Classification: rows 1-10 of `two_class_dat` are held out for prediction.
data(two_class_dat, package = "modeldata")
cls_trn <- two_class_dat[-(1:10), ]
cls_tst <- two_class_dat[1:10, ]
# Logistic regression fit via a formula and via fit_xy(); column 3 is `Class`.
cls_form <-
  logistic_reg() %>%
  set_engine("glm") %>%
  fit(Class ~ ., data = cls_trn)
cls_xy <-
  logistic_reg() %>%
  set_engine("glm") %>%
  fit_xy(x = cls_trn[, -3], y = cls_trn$Class)
# `.pred_class` plus one `.pred_{level}` probability column per class level.
augment(cls_form, new_data = cls_tst)
#> # A tibble: 10 × 6
#> A B Class .pred_class .pred_Class1 .pred_Class2
#> <dbl> <dbl> <fct> <fct> <dbl> <dbl>
#> 1 2.07 1.63 Class1 Class1 0.518 0.482
#> 2 2.02 1.04 Class1 Class1 0.909 0.0913
#> 3 1.69 1.37 Class2 Class1 0.648 0.352
#> 4 3.43 1.98 Class2 Class1 0.610 0.390
#> 5 2.88 1.98 Class1 Class2 0.443 0.557
#> 6 3.31 2.41 Class2 Class2 0.206 0.794
#> 7 2.50 1.56 Class2 Class1 0.708 0.292
#> 8 1.98 1.55 Class2 Class1 0.567 0.433
#> 9 2.88 0.580 Class1 Class1 0.994 0.00582
#> 10 3.74 2.74 Class2 Class2 0.108 0.892
# Dropping the outcome column only removes `Class`; predictions are unchanged.
augment(cls_form, new_data = cls_tst[, -3])
#> # A tibble: 10 × 5
#> A B .pred_class .pred_Class1 .pred_Class2
#> <dbl> <dbl> <fct> <dbl> <dbl>
#> 1 2.07 1.63 Class1 0.518 0.482
#> 2 2.02 1.04 Class1 0.909 0.0913
#> 3 1.69 1.37 Class1 0.648 0.352
#> 4 3.43 1.98 Class1 0.610 0.390
#> 5 2.88 1.98 Class2 0.443 0.557
#> 6 3.31 2.41 Class2 0.206 0.794
#> 7 2.50 1.56 Class1 0.708 0.292
#> 8 1.98 1.55 Class1 0.567 0.433
#> 9 2.88 0.580 Class1 0.994 0.00582
#> 10 3.74 2.74 Class2 0.108 0.892
augment(cls_xy, new_data = cls_tst)
#> # A tibble: 10 × 6
#> A B Class .pred_class .pred_Class1 .pred_Class2
#> <dbl> <dbl> <fct> <fct> <dbl> <dbl>
#> 1 2.07 1.63 Class1 Class1 0.518 0.482
#> 2 2.02 1.04 Class1 Class1 0.909 0.0913
#> 3 1.69 1.37 Class2 Class1 0.648 0.352
#> 4 3.43 1.98 Class2 Class1 0.610 0.390
#> 5 2.88 1.98 Class1 Class2 0.443 0.557
#> 6 3.31 2.41 Class2 Class2 0.206 0.794
#> 7 2.50 1.56 Class2 Class1 0.708 0.292
#> 8 1.98 1.55 Class2 Class1 0.567 0.433
#> 9 2.88 0.580 Class1 Class1 0.994 0.00582
#> 10 3.74 2.74 Class2 Class2 0.108 0.892
augment(cls_xy, new_data = cls_tst[, -3])
#> # A tibble: 10 × 5
#> A B .pred_class .pred_Class1 .pred_Class2
#> <dbl> <dbl> <fct> <dbl> <dbl>
#> 1 2.07 1.63 Class1 0.518 0.482
#> 2 2.02 1.04 Class1 0.909 0.0913
#> 3 1.69 1.37 Class1 0.648 0.352
#> 4 3.43 1.98 Class1 0.610 0.390
#> 5 2.88 1.98 Class2 0.443 0.557
#> 6 3.31 2.41 Class2 0.206 0.794
#> 7 2.50 1.56 Class1 0.708 0.292
#> 8 1.98 1.55 Class1 0.567 0.433
#> 9 2.88 0.580 Class1 0.994 0.00582
#> 10 3.74 2.74 Class2 0.108 0.892
