augment() adds column(s) of predictions to the given data.

# S3 method for model_fit
augment(x, new_data, ...)

Arguments

x

A model_fit object produced by fit.model_spec() or fit_xy.model_spec().

new_data

A data frame or matrix.

...

Not currently used.

Details

For regression models, a .pred column is added. If x was created using fit.model_spec() and new_data contains the outcome column, a .resid column is also added.
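The sketch below shows, in base R rather than parsnip, what these two columns hold for the linear model used in the examples. It calls `lm()` and `predict()` directly and assumes `.resid` is the observed outcome minus `.pred` (the usual residual definition). It is an illustration, not parsnip's internal code.

```r
# Same split as the examples: rows 11-32 train, rows 1-10 test
car_trn <- mtcars[11:32, ]
car_tst <- mtcars[1:10, ]

# Plain base-R linear model, standing in for the "lm" engine
fit_lm <- lm(mpg ~ ., data = car_trn)

# .pred: numeric prediction for each row of the new data
car_tst$.pred <- unname(predict(fit_lm, newdata = car_tst))

# .resid: observed outcome minus prediction (only possible because mpg is present)
car_tst$.resid <- car_tst$mpg - car_tst$.pred

head(car_tst[, c("mpg", ".pred", ".resid")])
```

Without the outcome column in the new data there is nothing to subtract the prediction from, which is why `.resid` is only added when the outcome is present.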

For classification models, the results can include a column called .pred_class as well as class probability columns named .pred_{level}. Which of these columns appear depends on the prediction types available for the model.
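A base-R sketch of how the class columns relate for a two-class model. It fits `glm(..., family = binomial)` directly instead of going through parsnip, on a hypothetical outcome `trans` derived from `mtcars$am`. Taking `.pred_class` as the level whose probability exceeds 0.5 is an assumption that matches the usual rule for binary logistic regression, not a statement about parsnip's internals.

```r
# Hypothetical two-level outcome built from mtcars$am
dat <- data.frame(
  wt    = mtcars$wt,
  trans = factor(ifelse(mtcars$am == 1, "manual", "automatic"))
)

# A binomial GLM models the probability of the second factor level ("manual")
fit_glm  <- glm(trans ~ wt, data = dat, family = binomial)
p_second <- unname(predict(fit_glm, newdata = dat, type = "response"))

# One probability column per level, named .pred_{level}
lvls  <- levels(dat$trans)
probs <- data.frame(1 - p_second, p_second)
names(probs) <- paste0(".pred_", lvls)   # .pred_automatic, .pred_manual

# Hard prediction: the level with the larger probability (0.5 cutoff)
pred_class <- factor(lvls[ifelse(p_second > 0.5, 2, 1)], levels = lvls)

out <- cbind(dat, .pred_class = pred_class, probs)
head(out)
```

For a binary outcome the two probability columns sum to one in every row, so `.pred_class` carries no information beyond which probability is larger.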

Examples

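library(parsnip)

# Rows 11-32 for training, rows 1-10 for testing
car_trn <- mtcars[11:32, ]
car_tst <- mtcars[ 1:10, ]

# The same linear model, fit with the formula and the x/y interfaces
reg_form <-
  linear_reg() %>%
  set_engine("lm") %>%
  fit(mpg ~ ., data = car_trn)

reg_xy <-
  linear_reg() %>%
  set_engine("lm") %>%
  fit_xy(car_trn[, -1], car_trn$mpg)

augment(reg_form, car_tst)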
#> # A tibble: 10 × 13
#>      mpg   cyl  disp    hp  drat    wt  qsec    vs    am  gear  carb .pred
#>    <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#>  1  21       6  160    110  3.9   2.62  16.5     0     1     4     4  23.4
#>  2  21       6  160    110  3.9   2.88  17.0     0     1     4     4  23.3
#>  3  22.8     4  108     93  3.85  2.32  18.6     1     1     4     1  27.6
#>  4  21.4     6  258    110  3.08  3.22  19.4     1     0     3     1  21.5
#>  5  18.7     8  360    175  3.15  3.44  17.0     0     0     3     2  17.6
#>  6  18.1     6  225    105  2.76  3.46  20.2     1     0     3     1  21.6
#>  7  14.3     8  360    245  3.21  3.57  15.8     0     0     3     4  13.9
#>  8  24.4     4  147.    62  3.69  3.19  20       1     0     4     2  21.7
#>  9  22.8     4  141.    95  3.92  3.15  22.9     1     0     4     2  25.6
#> 10  19.2     6  168.   123  3.92  3.44  18.3     1     0     4     4  17.1
#> # … with 1 more variable: .resid <dbl>
augment(reg_form, car_tst[, -1])
#> # A tibble: 10 × 11
#>      cyl  disp    hp  drat    wt  qsec    vs    am  gear  carb .pred
#>    <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#>  1     6  160    110  3.9   2.62  16.5     0     1     4     4  23.4
#>  2     6  160    110  3.9   2.88  17.0     0     1     4     4  23.3
#>  3     4  108     93  3.85  2.32  18.6     1     1     4     1  27.6
#>  4     6  258    110  3.08  3.22  19.4     1     0     3     1  21.5
#>  5     8  360    175  3.15  3.44  17.0     0     0     3     2  17.6
#>  6     6  225    105  2.76  3.46  20.2     1     0     3     1  21.6
#>  7     8  360    245  3.21  3.57  15.8     0     0     3     4  13.9
#>  8     4  147.    62  3.69  3.19  20       1     0     4     2  21.7
#>  9     4  141.    95  3.92  3.15  22.9     1     0     4     2  25.6
#> 10     6  168.   123  3.92  3.44  18.3     1     0     4     4  17.1
augment(reg_xy, car_tst)
#> # A tibble: 10 × 12
#>      mpg   cyl  disp    hp  drat    wt  qsec    vs    am  gear  carb .pred
#>    <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#>  1  21       6  160    110  3.9   2.62  16.5     0     1     4     4  23.4
#>  2  21       6  160    110  3.9   2.88  17.0     0     1     4     4  23.3
#>  3  22.8     4  108     93  3.85  2.32  18.6     1     1     4     1  27.6
#>  4  21.4     6  258    110  3.08  3.22  19.4     1     0     3     1  21.5
#>  5  18.7     8  360    175  3.15  3.44  17.0     0     0     3     2  17.6
#>  6  18.1     6  225    105  2.76  3.46  20.2     1     0     3     1  21.6
#>  7  14.3     8  360    245  3.21  3.57  15.8     0     0     3     4  13.9
#>  8  24.4     4  147.    62  3.69  3.19  20       1     0     4     2  21.7
#>  9  22.8     4  141.    95  3.92  3.15  22.9     1     0     4     2  25.6
#> 10  19.2     6  168.   123  3.92  3.44  18.3     1     0     4     4  17.1
augment(reg_xy, car_tst[, -1])
#> # A tibble: 10 × 11
#>      cyl  disp    hp  drat    wt  qsec    vs    am  gear  carb .pred
#>    <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#>  1     6  160    110  3.9   2.62  16.5     0     1     4     4  23.4
#>  2     6  160    110  3.9   2.88  17.0     0     1     4     4  23.3
#>  3     4  108     93  3.85  2.32  18.6     1     1     4     1  27.6
#>  4     6  258    110  3.08  3.22  19.4     1     0     3     1  21.5
#>  5     8  360    175  3.15  3.44  17.0     0     0     3     2  17.6
#>  6     6  225    105  2.76  3.46  20.2     1     0     3     1  21.6
#>  7     8  360    245  3.21  3.57  15.8     0     0     3     4  13.9
#>  8     4  147.    62  3.69  3.19  20       1     0     4     2  21.7
#>  9     4  141.    95  3.92  3.15  22.9     1     0     4     2  25.6
#> 10     6  168.   123  3.92  3.44  18.3     1     0     4     4  17.1
# ------------------------------------------------------------------------------

data(two_class_dat, package = "modeldata")

# Hold out the first 10 rows for testing
cls_trn <- two_class_dat[-(1:10), ]
cls_tst <- two_class_dat[  1:10 , ]

# The same logistic regression, fit with the formula and the x/y interfaces
cls_form <-
  logistic_reg() %>%
  set_engine("glm") %>%
  fit(Class ~ ., data = cls_trn)

cls_xy <-
  logistic_reg() %>%
  set_engine("glm") %>%
  fit_xy(cls_trn[, -3], cls_trn$Class)

augment(cls_form, cls_tst)
#> # A tibble: 10 × 6
#>        A     B Class  .pred_class .pred_Class1 .pred_Class2
#>    <dbl> <dbl> <fct>  <fct>              <dbl>        <dbl>
#>  1  2.07 1.63  Class1 Class1             0.518      0.482
#>  2  2.02 1.04  Class1 Class1             0.909      0.0913
#>  3  1.69 1.37  Class2 Class1             0.648      0.352
#>  4  3.43 1.98  Class2 Class1             0.610      0.390
#>  5  2.88 1.98  Class1 Class2             0.443      0.557
#>  6  3.31 2.41  Class2 Class2             0.206      0.794
#>  7  2.50 1.56  Class2 Class1             0.708      0.292
#>  8  1.98 1.55  Class2 Class1             0.567      0.433
#>  9  2.88 0.580 Class1 Class1             0.994      0.00582
#> 10  3.74 2.74  Class2 Class2             0.108      0.892
augment(cls_form, cls_tst[, -3])
#> # A tibble: 10 × 5
#>        A     B .pred_class .pred_Class1 .pred_Class2
#>    <dbl> <dbl> <fct>              <dbl>        <dbl>
#>  1  2.07 1.63  Class1             0.518      0.482
#>  2  2.02 1.04  Class1             0.909      0.0913
#>  3  1.69 1.37  Class1             0.648      0.352
#>  4  3.43 1.98  Class1             0.610      0.390
#>  5  2.88 1.98  Class2             0.443      0.557
#>  6  3.31 2.41  Class2             0.206      0.794
#>  7  2.50 1.56  Class1             0.708      0.292
#>  8  1.98 1.55  Class1             0.567      0.433
#>  9  2.88 0.580 Class1             0.994      0.00582
#> 10  3.74 2.74  Class2             0.108      0.892
augment(cls_xy, cls_tst)
#> # A tibble: 10 × 6
#>        A     B Class  .pred_class .pred_Class1 .pred_Class2
#>    <dbl> <dbl> <fct>  <fct>              <dbl>        <dbl>
#>  1  2.07 1.63  Class1 Class1             0.518      0.482
#>  2  2.02 1.04  Class1 Class1             0.909      0.0913
#>  3  1.69 1.37  Class2 Class1             0.648      0.352
#>  4  3.43 1.98  Class2 Class1             0.610      0.390
#>  5  2.88 1.98  Class1 Class2             0.443      0.557
#>  6  3.31 2.41  Class2 Class2             0.206      0.794
#>  7  2.50 1.56  Class2 Class1             0.708      0.292
#>  8  1.98 1.55  Class2 Class1             0.567      0.433
#>  9  2.88 0.580 Class1 Class1             0.994      0.00582
#> 10  3.74 2.74  Class2 Class2             0.108      0.892
augment(cls_xy, cls_tst[, -3])
#> # A tibble: 10 × 5
#>        A     B .pred_class .pred_Class1 .pred_Class2
#>    <dbl> <dbl> <fct>              <dbl>        <dbl>
#>  1  2.07 1.63  Class1             0.518      0.482
#>  2  2.02 1.04  Class1             0.909      0.0913
#>  3  1.69 1.37  Class1             0.648      0.352
#>  4  3.43 1.98  Class1             0.610      0.390
#>  5  2.88 1.98  Class2             0.443      0.557
#>  6  3.31 2.41  Class2             0.206      0.794
#>  7  2.50 1.56  Class1             0.708      0.292
#>  8  1.98 1.55  Class1             0.567      0.433
#>  9  2.88 0.580 Class1             0.994      0.00582
#> 10  3.74 2.74  Class2             0.108      0.892