Augment data with predictions

augment() adds column(s) of predictions to the given data.

Usage

# S3 method for model_fit
augment(x, new_data, ...)

Arguments

x

A model_fit object produced by fit.model_spec() or fit_xy.model_spec().

new_data

A data frame or matrix.

...

Not currently used.

Details

For regression models, a .pred column is added. If x was created using fit.model_spec() and new_data contains the outcome column, a .resid column is also added.

For classification models, the results can include a column called .pred_class as well as class probability columns named .pred_{level}. Which of these columns appear depends on the prediction types available for the model.

Examples

# Hold out the first 10 rows of mtcars for testing; train on the remaining rows.
car_tst <- mtcars[1:10, ]
car_trn <- mtcars[-(1:10), ]

# One linear-regression specification (lm engine), reused for both fits so
# the formula and x/y interfaces can be compared on the same model.
lm_spec <-
  linear_reg() %>%
  set_engine("lm")

# Formula interface: mpg is the outcome; all other columns are predictors.
reg_form <-
  lm_spec %>%
  fit(mpg ~ ., data = car_trn)

# x/y interface: drop the outcome (column 1, mpg) from the predictors and
# supply it separately.
reg_xy <-
  lm_spec %>%
  fit_xy(x = car_trn[, -1], y = car_trn$mpg)

# Formula fit with the outcome (mpg) present in new_data: both .pred and
# .resid are added.
augment(reg_form, car_tst)
#> # A tibble: 10 × 13
#>      mpg   cyl  disp    hp  drat    wt  qsec    vs    am  gear  carb .pred
#>    <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#>  1  21       6  160    110  3.9   2.62  16.5     0     1     4     4  23.4
#>  2  21       6  160    110  3.9   2.88  17.0     0     1     4     4  23.3
#>  3  22.8     4  108     93  3.85  2.32  18.6     1     1     4     1  27.6
#>  4  21.4     6  258    110  3.08  3.22  19.4     1     0     3     1  21.5
#>  5  18.7     8  360    175  3.15  3.44  17.0     0     0     3     2  17.6
#>  6  18.1     6  225    105  2.76  3.46  20.2     1     0     3     1  21.6
#>  7  14.3     8  360    245  3.21  3.57  15.8     0     0     3     4  13.9
#>  8  24.4     4  147.    62  3.69  3.19  20       1     0     4     2  21.7
#>  9  22.8     4  141.    95  3.92  3.15  22.9     1     0     4     2  25.6
#> 10  19.2     6  168.   123  3.92  3.44  18.3     1     0     4     4  17.1
#> # … with 1 more variable: .resid <dbl>
# Outcome column dropped (car_tst[, -1] removes mpg): only .pred is added.
augment(reg_form, car_tst[, -1])
#> # A tibble: 10 × 11
#>      cyl  disp    hp  drat    wt  qsec    vs    am  gear  carb .pred
#>    <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#>  1     6  160    110  3.9   2.62  16.5     0     1     4     4  23.4
#>  2     6  160    110  3.9   2.88  17.0     0     1     4     4  23.3
#>  3     4  108     93  3.85  2.32  18.6     1     1     4     1  27.6
#>  4     6  258    110  3.08  3.22  19.4     1     0     3     1  21.5
#>  5     8  360    175  3.15  3.44  17.0     0     0     3     2  17.6
#>  6     6  225    105  2.76  3.46  20.2     1     0     3     1  21.6
#>  7     8  360    245  3.21  3.57  15.8     0     0     3     4  13.9
#>  8     4  147.    62  3.69  3.19  20       1     0     4     2  21.7
#>  9     4  141.    95  3.92  3.15  22.9     1     0     4     2  25.6
#> 10     6  168.   123  3.92  3.44  18.3     1     0     4     4  17.1

# x/y fit: only .pred is added, even though mpg is present in new_data.
augment(reg_xy, car_tst)
#> # A tibble: 10 × 12
#>      mpg   cyl  disp    hp  drat    wt  qsec    vs    am  gear  carb .pred
#>    <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#>  1  21       6  160    110  3.9   2.62  16.5     0     1     4     4  23.4
#>  2  21       6  160    110  3.9   2.88  17.0     0     1     4     4  23.3
#>  3  22.8     4  108     93  3.85  2.32  18.6     1     1     4     1  27.6
#>  4  21.4     6  258    110  3.08  3.22  19.4     1     0     3     1  21.5
#>  5  18.7     8  360    175  3.15  3.44  17.0     0     0     3     2  17.6
#>  6  18.1     6  225    105  2.76  3.46  20.2     1     0     3     1  21.6
#>  7  14.3     8  360    245  3.21  3.57  15.8     0     0     3     4  13.9
#>  8  24.4     4  147.    62  3.69  3.19  20       1     0     4     2  21.7
#>  9  22.8     4  141.    95  3.92  3.15  22.9     1     0     4     2  25.6
#> 10  19.2     6  168.   123  3.92  3.44  18.3     1     0     4     4  17.1
# x/y fit without the outcome column: only .pred is added.
augment(reg_xy, car_tst[, -1])
#> # A tibble: 10 × 11
#>      cyl  disp    hp  drat    wt  qsec    vs    am  gear  carb .pred
#>    <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
#>  1     6  160    110  3.9   2.62  16.5     0     1     4     4  23.4
#>  2     6  160    110  3.9   2.88  17.0     0     1     4     4  23.3
#>  3     4  108     93  3.85  2.32  18.6     1     1     4     1  27.6
#>  4     6  258    110  3.08  3.22  19.4     1     0     3     1  21.5
#>  5     8  360    175  3.15  3.44  17.0     0     0     3     2  17.6
#>  6     6  225    105  2.76  3.46  20.2     1     0     3     1  21.6
#>  7     8  360    245  3.21  3.57  15.8     0     0     3     4  13.9
#>  8     4  147.    62  3.69  3.19  20       1     0     4     2  21.7
#>  9     4  141.    95  3.92  3.15  22.9     1     0     4     2  25.6
#> 10     6  168.   123  3.92  3.44  18.3     1     0     4     4  17.1

# ------------------------------------------------------------------------------

# Load the two-class example data; hold out its first 10 rows for testing
# and train on the rest.
data(two_class_dat, package = "modeldata")
cls_tst <- two_class_dat[1:10, ]
cls_trn <- two_class_dat[-(1:10), ]

# One logistic-regression specification (glm engine), reused for both fits so
# the formula and x/y interfaces can be compared on the same model.
glm_spec <-
  logistic_reg() %>%
  set_engine("glm")

# Formula interface: Class is the outcome; A and B are the predictors.
cls_form <-
  glm_spec %>%
  fit(Class ~ ., data = cls_trn)

# x/y interface: drop the outcome (column 3, Class) from the predictors and
# supply it separately.
cls_xy <-
  glm_spec %>%
  fit_xy(x = cls_trn[, -3], y = cls_trn$Class)

# Formula fit with the outcome (Class) present: the hard class prediction
# (.pred_class) and one probability column per level (.pred_Class1,
# .pred_Class2) are added; the original Class column is kept.
augment(cls_form, cls_tst)
#> # A tibble: 10 × 6
#>        A     B Class  .pred_class .pred_Class1 .pred_Class2
#>    <dbl> <dbl> <fct>  <fct>              <dbl>        <dbl>
#>  1  2.07 1.63  Class1 Class1             0.518      0.482  
#>  2  2.02 1.04  Class1 Class1             0.909      0.0913 
#>  3  1.69 1.37  Class2 Class1             0.648      0.352  
#>  4  3.43 1.98  Class2 Class1             0.610      0.390  
#>  5  2.88 1.98  Class1 Class2             0.443      0.557  
#>  6  3.31 2.41  Class2 Class2             0.206      0.794  
#>  7  2.50 1.56  Class2 Class1             0.708      0.292  
#>  8  1.98 1.55  Class2 Class1             0.567      0.433  
#>  9  2.88 0.580 Class1 Class1             0.994      0.00582
#> 10  3.74 2.74  Class2 Class2             0.108      0.892  
# Outcome column dropped (cls_tst[, -3] removes Class): the same prediction
# columns are added.
augment(cls_form, cls_tst[, -3])
#> # A tibble: 10 × 5
#>        A     B .pred_class .pred_Class1 .pred_Class2
#>    <dbl> <dbl> <fct>              <dbl>        <dbl>
#>  1  2.07 1.63  Class1             0.518      0.482  
#>  2  2.02 1.04  Class1             0.909      0.0913 
#>  3  1.69 1.37  Class1             0.648      0.352  
#>  4  3.43 1.98  Class1             0.610      0.390  
#>  5  2.88 1.98  Class2             0.443      0.557  
#>  6  3.31 2.41  Class2             0.206      0.794  
#>  7  2.50 1.56  Class1             0.708      0.292  
#>  8  1.98 1.55  Class1             0.567      0.433  
#>  9  2.88 0.580 Class1             0.994      0.00582
#> 10  3.74 2.74  Class2             0.108      0.892  

# x/y fit: the predictions shown match those of the formula fit above.
augment(cls_xy, cls_tst)
#> # A tibble: 10 × 6
#>        A     B Class  .pred_class .pred_Class1 .pred_Class2
#>    <dbl> <dbl> <fct>  <fct>              <dbl>        <dbl>
#>  1  2.07 1.63  Class1 Class1             0.518      0.482  
#>  2  2.02 1.04  Class1 Class1             0.909      0.0913 
#>  3  1.69 1.37  Class2 Class1             0.648      0.352  
#>  4  3.43 1.98  Class2 Class1             0.610      0.390  
#>  5  2.88 1.98  Class1 Class2             0.443      0.557  
#>  6  3.31 2.41  Class2 Class2             0.206      0.794  
#>  7  2.50 1.56  Class2 Class1             0.708      0.292  
#>  8  1.98 1.55  Class2 Class1             0.567      0.433  
#>  9  2.88 0.580 Class1 Class1             0.994      0.00582
#> 10  3.74 2.74  Class2 Class2             0.108      0.892  
# x/y fit without the outcome column.
augment(cls_xy, cls_tst[, -3])
#> # A tibble: 10 × 5
#>        A     B .pred_class .pred_Class1 .pred_Class2
#>    <dbl> <dbl> <fct>              <dbl>        <dbl>
#>  1  2.07 1.63  Class1             0.518      0.482  
#>  2  2.02 1.04  Class1             0.909      0.0913 
#>  3  1.69 1.37  Class1             0.648      0.352  
#>  4  3.43 1.98  Class1             0.610      0.390  
#>  5  2.88 1.98  Class2             0.443      0.557  
#>  6  3.31 2.41  Class2             0.206      0.794  
#>  7  2.50 1.56  Class1             0.708      0.292  
#>  8  1.98 1.55  Class1             0.567      0.433  
#>  9  2.88 0.580 Class1             0.994      0.00582
#> 10  3.74 2.74  Class2             0.108      0.892