augment() adds column(s) containing model predictions to the data supplied in new_data.

# S3 method for model_fit
augment(x, new_data, ...)

Arguments

x

A model_fit object produced by fit() or fit_xy().

new_data

A data frame or matrix.

...

Not currently used.

Details

For regression models, a .pred column is added. If x was created using fit() and new_data contains the outcome column, a .resid column (the observed outcome minus .pred) is also added.
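As a rough illustration of the regression behavior above, the sketch below reproduces the two columns using base R only. It is not parsnip's implementation: it assumes a plain stats::lm() fit in place of a model_fit object, and the helper augment_sketch() is hypothetical.

```r
# augment_sketch() is a hypothetical helper, not part of parsnip.
# It mimics the regression columns described above for a stats::lm() fit.
augment_sketch <- function(fit, new_data) {
  # Name of the outcome variable, e.g. "mpg" for mpg ~ .
  outcome <- all.vars(formula(fit))[1]
  new_data$.pred <- unname(predict(fit, newdata = new_data))
  # Add .resid (observed minus predicted) only when the outcome is present
  if (outcome %in% names(new_data)) {
    new_data$.resid <- new_data[[outcome]] - new_data$.pred
  }
  new_data
}

car_trn <- mtcars[11:32, ]
car_tst <- mtcars[1:10, ]
lm_fit <- lm(mpg ~ ., data = car_trn)

with_outcome    <- augment_sketch(lm_fit, car_tst)        # gains .pred and .resid
without_outcome <- augment_sketch(lm_fit, car_tst[, -1])  # gains .pred only
```

With the same training split and formula as the examples, the .pred values should match those printed for reg_form, since the "lm" engine fits the model with stats::lm().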

For classification models, the results include a column called .pred_class as well as class probability columns named .pred_{level}.
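The classification columns can be sketched the same way. This base-R version is an illustration under stated assumptions rather than parsnip's code: it uses stats::glm() with a two-level factor outcome built from mtcars$am, an assumed 0.5 probability cutoff for .pred_class, and a hypothetical helper augment_class_sketch().

```r
# augment_class_sketch() is a hypothetical helper, not part of parsnip.
# It mimics .pred_class and .pred_{level} for a binomial stats::glm() fit.
augment_class_sketch <- function(fit, new_data) {
  # Outcome levels; glm() models the probability of the second level
  lvls <- levels(fit$model[[1]])
  p_second <- unname(predict(fit, newdata = new_data, type = "response"))
  # Assumed 0.5 cutoff for the hard class prediction
  new_data$.pred_class <- factor(
    ifelse(p_second >= 0.5, lvls[2], lvls[1]),
    levels = lvls
  )
  probs <- data.frame(1 - p_second, p_second)
  names(probs) <- paste0(".pred_", lvls)   # one .pred_{level} column per level
  cbind(new_data, probs)
}

trans_dat <- mtcars
trans_dat$trans <- factor(trans_dat$am, levels = c(0, 1),
                          labels = c("auto", "manual"))
glm_fit <- glm(trans ~ wt, data = trans_dat, family = binomial)

cls_out <- augment_class_sketch(glm_fit, trans_dat[1:5, c("wt", "trans")])
```

The column order here (.pred_class followed by the probability columns) follows the printed examples; how parsnip itself derives the hard class for other engines may differ from this sketch.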

Examples

library(parsnip)
car_trn <- mtcars[11:32, ]
car_tst <- mtcars[ 1:10, ]
reg_form <-
  linear_reg() %>%
  set_engine("lm") %>%
  fit(mpg ~ ., data = car_trn)
reg_xy <-
  linear_reg() %>%
  set_engine("lm") %>%
  fit_xy(car_trn[, -1], car_trn$mpg)
augment(reg_form, car_tst)
#>                    mpg cyl  disp  hp drat    wt  qsec vs am gear carb    .pred
#> Mazda RX4         21.0   6 160.0 110 3.90 2.620 16.46  0  1    4    4 23.43006
#> Mazda RX4 Wag     21.0   6 160.0 110 3.90 2.875 17.02  0  1    4    4 23.29639
#> Datsun 710        22.8   4 108.0  93 3.85 2.320 18.61  1  1    4    1 27.62870
#> Hornet 4 Drive    21.4   6 258.0 110 3.08 3.215 19.44  1  0    3    1 21.54678
#> Hornet Sportabout 18.7   8 360.0 175 3.15 3.440 17.02  0  0    3    2 17.57028
#> Valiant           18.1   6 225.0 105 2.76 3.460 20.22  1  0    3    1 21.58316
#> Duster 360        14.3   8 360.0 245 3.21 3.570 15.84  0  0    3    4 13.90725
#> Merc 240D         24.4   4 146.7  62 3.69 3.190 20.00  1  0    4    2 21.70176
#> Merc 230          22.8   4 140.8  95 3.92 3.150 22.90  1  0    4    2 25.61122
#> Merc 280          19.2   6 167.6 123 3.92 3.440 18.30  1  0    4    4 17.11206
#>                       .resid
#> Mazda RX4         -2.4300560
#> Mazda RX4 Wag     -2.2963881
#> Datsun 710        -4.8287021
#> Hornet 4 Drive    -0.1467751
#> Hornet Sportabout  1.1297224
#> Valiant           -3.4831608
#> Duster 360         0.3927513
#> Merc 240D          2.6982354
#> Merc 230          -2.8112207
#> Merc 280           2.0879402
augment(reg_form, car_tst[, -1])
#>                   cyl  disp  hp drat    wt  qsec vs am gear carb    .pred
#> Mazda RX4           6 160.0 110 3.90 2.620 16.46  0  1    4    4 23.43006
#> Mazda RX4 Wag       6 160.0 110 3.90 2.875 17.02  0  1    4    4 23.29639
#> Datsun 710          4 108.0  93 3.85 2.320 18.61  1  1    4    1 27.62870
#> Hornet 4 Drive      6 258.0 110 3.08 3.215 19.44  1  0    3    1 21.54678
#> Hornet Sportabout   8 360.0 175 3.15 3.440 17.02  0  0    3    2 17.57028
#> Valiant             6 225.0 105 2.76 3.460 20.22  1  0    3    1 21.58316
#> Duster 360          8 360.0 245 3.21 3.570 15.84  0  0    3    4 13.90725
#> Merc 240D           4 146.7  62 3.69 3.190 20.00  1  0    4    2 21.70176
#> Merc 230            4 140.8  95 3.92 3.150 22.90  1  0    4    2 25.61122
#> Merc 280            6 167.6 123 3.92 3.440 18.30  1  0    4    4 17.11206
augment(reg_xy, car_tst)
#>                    mpg cyl  disp  hp drat    wt  qsec vs am gear carb    .pred
#> Mazda RX4         21.0   6 160.0 110 3.90 2.620 16.46  0  1    4    4 23.43006
#> Mazda RX4 Wag     21.0   6 160.0 110 3.90 2.875 17.02  0  1    4    4 23.29639
#> Datsun 710        22.8   4 108.0  93 3.85 2.320 18.61  1  1    4    1 27.62870
#> Hornet 4 Drive    21.4   6 258.0 110 3.08 3.215 19.44  1  0    3    1 21.54678
#> Hornet Sportabout 18.7   8 360.0 175 3.15 3.440 17.02  0  0    3    2 17.57028
#> Valiant           18.1   6 225.0 105 2.76 3.460 20.22  1  0    3    1 21.58316
#> Duster 360        14.3   8 360.0 245 3.21 3.570 15.84  0  0    3    4 13.90725
#> Merc 240D         24.4   4 146.7  62 3.69 3.190 20.00  1  0    4    2 21.70176
#> Merc 230          22.8   4 140.8  95 3.92 3.150 22.90  1  0    4    2 25.61122
#> Merc 280          19.2   6 167.6 123 3.92 3.440 18.30  1  0    4    4 17.11206
augment(reg_xy, car_tst[, -1])
#>                   cyl  disp  hp drat    wt  qsec vs am gear carb    .pred
#> Mazda RX4           6 160.0 110 3.90 2.620 16.46  0  1    4    4 23.43006
#> Mazda RX4 Wag       6 160.0 110 3.90 2.875 17.02  0  1    4    4 23.29639
#> Datsun 710          4 108.0  93 3.85 2.320 18.61  1  1    4    1 27.62870
#> Hornet 4 Drive      6 258.0 110 3.08 3.215 19.44  1  0    3    1 21.54678
#> Hornet Sportabout   8 360.0 175 3.15 3.440 17.02  0  0    3    2 17.57028
#> Valiant             6 225.0 105 2.76 3.460 20.22  1  0    3    1 21.58316
#> Duster 360          8 360.0 245 3.21 3.570 15.84  0  0    3    4 13.90725
#> Merc 240D           4 146.7  62 3.69 3.190 20.00  1  0    4    2 21.70176
#> Merc 230            4 140.8  95 3.92 3.150 22.90  1  0    4    2 25.61122
#> Merc 280            6 167.6 123 3.92 3.440 18.30  1  0    4    4 17.11206
# ------------------------------------------------------------------------------
data(two_class_dat, package = "modeldata")
cls_trn <- two_class_dat[-(1:10), ]
cls_tst <- two_class_dat[  1:10 , ]
cls_form <-
  logistic_reg() %>%
  set_engine("glm") %>%
  fit(Class ~ ., data = cls_trn)
cls_xy <-
  logistic_reg() %>%
  set_engine("glm") %>%
  fit_xy(cls_trn[, -3], cls_trn$Class)
augment(cls_form, cls_tst)
#> # A tibble: 10 x 6
#>        A     B Class  .pred_class .pred_Class1 .pred_Class2
#>    <dbl> <dbl> <fct>  <fct>              <dbl>        <dbl>
#>  1  2.07 1.63  Class1 Class1             0.518      0.482
#>  2  2.02 1.04  Class1 Class1             0.909      0.0913
#>  3  1.69 1.37  Class2 Class1             0.648      0.352
#>  4  3.43 1.98  Class2 Class1             0.610      0.390
#>  5  2.88 1.98  Class1 Class2             0.443      0.557
#>  6  3.31 2.41  Class2 Class2             0.206      0.794
#>  7  2.50 1.56  Class2 Class1             0.708      0.292
#>  8  1.98 1.55  Class2 Class1             0.567      0.433
#>  9  2.88 0.580 Class1 Class1             0.994      0.00582
#> 10  3.74 2.74  Class2 Class2             0.108      0.892
augment(cls_form, cls_tst[, -3])
#> # A tibble: 10 x 5
#>        A     B .pred_class .pred_Class1 .pred_Class2
#>    <dbl> <dbl> <fct>              <dbl>        <dbl>
#>  1  2.07 1.63  Class1             0.518      0.482
#>  2  2.02 1.04  Class1             0.909      0.0913
#>  3  1.69 1.37  Class1             0.648      0.352
#>  4  3.43 1.98  Class1             0.610      0.390
#>  5  2.88 1.98  Class2             0.443      0.557
#>  6  3.31 2.41  Class2             0.206      0.794
#>  7  2.50 1.56  Class1             0.708      0.292
#>  8  1.98 1.55  Class1             0.567      0.433
#>  9  2.88 0.580 Class1             0.994      0.00582
#> 10  3.74 2.74  Class2             0.108      0.892
augment(cls_xy, cls_tst)
#> # A tibble: 10 x 6
#>        A     B Class  .pred_class .pred_Class1 .pred_Class2
#>    <dbl> <dbl> <fct>  <fct>              <dbl>        <dbl>
#>  1  2.07 1.63  Class1 Class1             0.518      0.482
#>  2  2.02 1.04  Class1 Class1             0.909      0.0913
#>  3  1.69 1.37  Class2 Class1             0.648      0.352
#>  4  3.43 1.98  Class2 Class1             0.610      0.390
#>  5  2.88 1.98  Class1 Class2             0.443      0.557
#>  6  3.31 2.41  Class2 Class2             0.206      0.794
#>  7  2.50 1.56  Class2 Class1             0.708      0.292
#>  8  1.98 1.55  Class2 Class1             0.567      0.433
#>  9  2.88 0.580 Class1 Class1             0.994      0.00582
#> 10  3.74 2.74  Class2 Class2             0.108      0.892
augment(cls_xy, cls_tst[, -3])
#> # A tibble: 10 x 5
#>        A     B .pred_class .pred_Class1 .pred_Class2
#>    <dbl> <dbl> <fct>              <dbl>        <dbl>
#>  1  2.07 1.63  Class1             0.518      0.482
#>  2  2.02 1.04  Class1             0.909      0.0913
#>  3  1.69 1.37  Class1             0.648      0.352
#>  4  3.43 1.98  Class1             0.610      0.390
#>  5  2.88 1.98  Class2             0.443      0.557
#>  6  3.31 2.41  Class2             0.206      0.794
#>  7  2.50 1.56  Class1             0.708      0.292
#>  8  1.98 1.55  Class1             0.567      0.433
#>  9  2.88 0.580 Class1             0.994      0.00582
#> 10  3.74 2.74  Class2             0.108      0.892