The goal of this vignette is to explain how to use ResamplingVariableSizeTrainCV, which can be used to determine how many train data are necessary to provide accurate predictions on a given test set.

Simulated regression problems

The code below creates data for simulated regression problems. First, we define a vector of input values:

N <- 300
abs.x <- 10
set.seed(1)
x.vec <- runif(N, -abs.x, abs.x)
str(x.vec)
#>  num [1:300] -4.69 -2.56 1.46 8.16 -5.97 ...

Below we define a list of two true regression functions for our simulated data, each of which will be used to define a task (in mlr3 terminology):

reg.pattern.list <- list(
  sin=sin,
  constant=function(x)0)

The constant function represents a regression problem which can be solved by always predicting the mean value of outputs (featureless is the best possible learning algorithm). The sin function will be used to generate data with a non-linear pattern that will need to be learned. Below we use a for loop over these two functions/tasks, to simulate the data which will be used as input to the learning algorithms:

library(data.table)
reg.task.list <- list()
reg.data.list <- list()
for(task_id in names(reg.pattern.list)){
  f <- reg.pattern.list[[task_id]]
  task.dt <- data.table(
    x=x.vec,
    y = f(x.vec)+rnorm(N,sd=0.5))
  reg.data.list[[task_id]] <- data.table(task_id, task.dt)
  reg.task.list[[task_id]] <- mlr3::TaskRegr$new(
    task_id, task.dt, target="y"
  )
}
(reg.data <- rbindlist(reg.data.list))
#>       task_id         x          y
#>        <char>     <num>      <num>
#>   1:      sin -4.689827  1.2248390
#>   2:      sin -2.557522 -0.5607042
#>   3:      sin  1.457067  0.8345056
#>   4:      sin  8.164156  0.4875994
#>   5:      sin -5.966361 -0.4321800
#>  ---                              
#> 596: constant  3.628850 -0.6728968
#> 597: constant -8.016618  0.5168327
#> 598: constant -7.621949 -0.4058882
#> 599: constant -8.991207  0.9008627
#> 600: constant  8.585078  0.8857710

In the table above, the input is x, and the output is y. Below we visualize these data, with one task in each facet/panel:

if(require(animint2)){
  ggplot()+
    geom_point(aes(
      x, y),
      data=reg.data)+
    facet_grid(task_id ~ ., labeller=label_both)
}

In the plot above we can see two different simulated data sets (constant and sin). Note that the code above uses the animint2 package, which provides interactive extensions to the static graphics of the ggplot2 package (see the Interactive data viz section below).

Visualizing instance table

In the code below, we define a K-fold cross-validation experiment with K=3 folds (the default value of the folds parameter), and 6 train set sizes.

reg_size_cv <- mlr3resampling::ResamplingVariableSizeTrainCV$new()
reg_size_cv$param_set$values$train_sizes <- 6
reg_size_cv
#> <ResamplingVariableSizeTrainCV> : Cross-Validation with variable size train sets
#> * Iterations:
#> * Instantiated: FALSE
#> * Parameters:
#> List of 4
#>  $ folds         : int 3
#>  $ min_train_data: int 10
#>  $ random_seeds  : int 3
#>  $ train_sizes   : int 6

In the output above we can see the parameters of the resampling object, all of which should be integer scalars:

  • folds is the number of cross-validation folds.
  • min_train_data is the minimum number of train data to consider.
  • random_seeds is the number of random seeds, each of which determines a different random ordering of the train data. The random ordering determines which data are included in small train set sizes.
  • train_sizes is the number of train set sizes, evenly spaced on a log scale, from min_train_data to the max number of train data (determined by folds; with 300 data and 3 folds, there are at most 200 train data).
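As a rough illustration of the train_sizes parameter, the sketch below computes sizes evenly spaced on a log scale between min_train_data and the max number of train data. It assumes that the sizes are obtained by rounding exp() of an evenly spaced sequence of log values; this is an assumption about the behavior, not the package's actual code, and the helper log_spaced_sizes is hypothetical. Under that assumption, it gives the same six train sizes as the instance shown below.

```r
# Hypothetical helper, not the mlr3resampling implementation: train set
# sizes evenly spaced on a log scale, rounded to the nearest integer.
log_spaced_sizes <- function(min_train_data, max_train_data, train_sizes){
  log.grid <- seq(
    log(min_train_data), log(max_train_data), length.out=train_sizes)
  as.integer(round(exp(log.grid)))
}
n.rows <- 300
n.folds <- 3
# Each test fold holds n.rows/n.folds data, the rest can be used for training.
max.train <- n.rows - n.rows/n.folds
train.sizes <- log_spaced_sizes(10, max.train, 6)
train.sizes
# 10 18 33 60 110 200
```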

Below we instantiate the resampling on one of the tasks:

reg_size_cv$instantiate(reg.task.list[["sin"]])
reg_size_cv$instance
#> $iteration.dt
#>     test.fold  seed small_stratum_size train_size_i train_size
#>         <int> <int>              <int>        <int>      <int>
#>  1:         1     1                 10            1         10
#>  2:         1     1                 18            2         18
#>  3:         1     1                 33            3         33
#>  4:         1     1                 60            4         60
#>  5:         1     1                110            5        110
#>  6:         1     1                200            6        200
#>  7:         1     2                 10            1         10
#>  8:         1     2                 18            2         18
#>  9:         1     2                 33            3         33
#> 10:         1     2                 60            4         60
#> 11:         1     2                110            5        110
#> 12:         1     2                200            6        200
#> 13:         1     3                 10            1         10
#> 14:         1     3                 18            2         18
#> 15:         1     3                 33            3         33
#> 16:         1     3                 60            4         60
#> 17:         1     3                110            5        110
#> 18:         1     3                200            6        200
#> 19:         2     1                 10            1         10
#> 20:         2     1                 18            2         18
#> 21:         2     1                 33            3         33
#> 22:         2     1                 60            4         60
#> 23:         2     1                110            5        110
#> 24:         2     1                200            6        200
#> 25:         2     2                 10            1         10
#> 26:         2     2                 18            2         18
#> 27:         2     2                 33            3         33
#> 28:         2     2                 60            4         60
#> 29:         2     2                110            5        110
#> 30:         2     2                200            6        200
#> 31:         2     3                 10            1         10
#> 32:         2     3                 18            2         18
#> 33:         2     3                 33            3         33
#> 34:         2     3                 60            4         60
#> 35:         2     3                110            5        110
#> 36:         2     3                200            6        200
#> 37:         3     1                 10            1         10
#> 38:         3     1                 18            2         18
#> 39:         3     1                 33            3         33
#> 40:         3     1                 60            4         60
#> 41:         3     1                110            5        110
#> 42:         3     1                200            6        200
#> 43:         3     2                 10            1         10
#> 44:         3     2                 18            2         18
#> 45:         3     2                 33            3         33
#> 46:         3     2                 60            4         60
#> 47:         3     2                110            5        110
#> 48:         3     2                200            6        200
#> 49:         3     3                 10            1         10
#> 50:         3     3                 18            2         18
#> 51:         3     3                 33            3         33
#> 52:         3     3                 60            4         60
#> 53:         3     3                110            5        110
#> 54:         3     3                200            6        200
#>     test.fold  seed small_stratum_size train_size_i train_size
#>                           train                  test iteration train_min_size
#>                          <list>                <list>     <int>          <int>
#>  1: 216,197, 81,171,143, 36,...  1, 7,11,13,15,19,...         1             10
#>  2: 216,197, 81,171,143, 36,...  1, 7,11,13,15,19,...         2             18
#>  3: 216,197, 81,171,143, 36,...  1, 7,11,13,15,19,...         3             33
#>  4: 216,197, 81,171,143, 36,...  1, 7,11,13,15,19,...         4             60
#>  5: 216,197, 81,171,143, 36,...  1, 7,11,13,15,19,...         5            110
#>  6: 216,197, 81,171,143, 36,...  1, 7,11,13,15,19,...         6            200
#>  7: 260,291, 16,164,109, 45,...  1, 7,11,13,15,19,...         7             10
#>  8: 260,291, 16,164,109, 45,...  1, 7,11,13,15,19,...         8             18
#>  9: 260,291, 16,164,109, 45,...  1, 7,11,13,15,19,...         9             33
#> 10: 260,291, 16,164,109, 45,...  1, 7,11,13,15,19,...        10             60
#> 11: 260,291, 16,164,109, 45,...  1, 7,11,13,15,19,...        11            110
#> 12: 260,291, 16,164,109, 45,...  1, 7,11,13,15,19,...        12            200
#> 13:  14,253,115,102,293, 18,...  1, 7,11,13,15,19,...        13             10
#> 14:  14,253,115,102,293, 18,...  1, 7,11,13,15,19,...        14             18
#> 15:  14,253,115,102,293, 18,...  1, 7,11,13,15,19,...        15             33
#> 16:  14,253,115,102,293, 18,...  1, 7,11,13,15,19,...        16             60
#> 17:  14,253,115,102,293, 18,...  1, 7,11,13,15,19,...        17            110
#> 18:  14,253,115,102,293, 18,...  1, 7,11,13,15,19,...        18            200
#> 19: 203,197, 81,171,130, 43,...  4, 6, 9,12,14,16,...        19             10
#> 20: 203,197, 81,171,130, 43,...  4, 6, 9,12,14,16,...        20             18
#> 21: 203,197, 81,171,130, 43,...  4, 6, 9,12,14,16,...        21             33
#> 22: 203,197, 81,171,130, 43,...  4, 6, 9,12,14,16,...        22             60
#> 23: 203,197, 81,171,130, 43,...  4, 6, 9,12,14,16,...        23            110
#> 24: 203,197, 81,171,130, 43,...  4, 6, 9,12,14,16,...        24            200
#> 25: 251,291, 19,164,109, 55,...  4, 6, 9,12,14,16,...        25             10
#> 26: 251,291, 19,164,109, 55,...  4, 6, 9,12,14,16,...        26             18
#> 27: 251,291, 19,164,109, 55,...  4, 6, 9,12,14,16,...        27             33
#> 28: 251,291, 19,164,109, 55,...  4, 6, 9,12,14,16,...        28             60
#> 29: 251,291, 19,164,109, 55,...  4, 6, 9,12,14,16,...        29            110
#> 30: 251,291, 19,164,109, 55,...  4, 6, 9,12,14,16,...        30            200
#> 31:  15,253,115,110,293, 18,...  4, 6, 9,12,14,16,...        31             10
#> 32:  15,253,115,110,293, 18,...  4, 6, 9,12,14,16,...        32             18
#> 33:  15,253,115,110,293, 18,...  4, 6, 9,12,14,16,...        33             33
#> 34:  15,253,115,110,293, 18,...  4, 6, 9,12,14,16,...        34             60
#> 35:  15,253,115,110,293, 18,...  4, 6, 9,12,14,16,...        35            110
#> 36:  15,253,115,110,293, 18,...  4, 6, 9,12,14,16,...        36            200
#> 37: 203,211, 82,194,130, 43,...  2, 3, 5, 8,10,17,...        37             10
#> 38: 203,211, 82,194,130, 43,...  2, 3, 5, 8,10,17,...        38             18
#> 39: 203,211, 82,194,130, 43,...  2, 3, 5, 8,10,17,...        39             33
#> 40: 203,211, 82,194,130, 43,...  2, 3, 5, 8,10,17,...        40             60
#> 41: 203,211, 82,194,130, 43,...  2, 3, 5, 8,10,17,...        41            110
#> 42: 203,211, 82,194,130, 43,...  2, 3, 5, 8,10,17,...        42            200
#> 43: 251,295, 19,189,102, 55,...  2, 3, 5, 8,10,17,...        43             10
#> 44: 251,295, 19,189,102, 55,...  2, 3, 5, 8,10,17,...        44             18
#> 45: 251,295, 19,189,102, 55,...  2, 3, 5, 8,10,17,...        45             33
#> 46: 251,295, 19,189,102, 55,...  2, 3, 5, 8,10,17,...        46             60
#> 47: 251,295, 19,189,102, 55,...  2, 3, 5, 8,10,17,...        47            110
#> 48: 251,295, 19,189,102, 55,...  2, 3, 5, 8,10,17,...        48            200
#> 49:  15,263,135,110,296, 25,...  2, 3, 5, 8,10,17,...        49             10
#> 50:  15,263,135,110,296, 25,...  2, 3, 5, 8,10,17,...        50             18
#> 51:  15,263,135,110,296, 25,...  2, 3, 5, 8,10,17,...        51             33
#> 52:  15,263,135,110,296, 25,...  2, 3, 5, 8,10,17,...        52             60
#> 53:  15,263,135,110,296, 25,...  2, 3, 5, 8,10,17,...        53            110
#> 54:  15,263,135,110,296, 25,...  2, 3, 5, 8,10,17,...        54            200
#>                           train                  test iteration train_min_size
#> 
#> $id.dt
#>      row_id  fold
#>       <int> <int>
#>   1:      1     1
#>   2:      2     3
#>   3:      3     3
#>   4:      4     2
#>   5:      5     3
#>  ---             
#> 296:    296     2
#> 297:    297     1
#> 298:    298     1
#> 299:    299     3
#> 300:    300     2

The instance shown above does not need to be examined by the user, but for informational purposes, it contains the following data:

  • iteration.dt has one row for each train/test split.
  • id.dt has one row for each data point, with its assigned cross-validation fold.
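To make the structure of iteration.dt more concrete, the base R sketch below builds a similar set of splits: one for each combination of test fold, seed, and train set size, so 3 × 3 × 6 = 54 splits (and the benchmark below has 4 × 54 = 216 rows, for 2 tasks × 2 learners). It assumes, as suggested by the repeated leading row IDs in the train column above, that for a given test fold and seed, each train set is a prefix of one random ordering of the train data, so smaller train sets are subsets of larger ones. The fold assignment and orderings here are illustrative only, and do not reproduce the random numbers used by the package.

```r
set.seed(1)
n.rows <- 300
n.folds <- 3
n.seeds <- 3
train.sizes <- c(10L, 18L, 33L, 60L, 110L, 200L)
# Balanced random fold assignment (illustrative, not the package's).
fold.vec <- sample(rep(seq_len(n.folds), length.out=n.rows))
# One row per combination of test fold, seed, and train size index.
iteration.df <- expand.grid(
  train_size_i=seq_along(train.sizes),
  seed=seq_len(n.seeds),
  test.fold=seq_len(n.folds))
nrow(iteration.df)  # 54
# For each (test fold, seed), randomly order the train rows once, then take
# the first train_size rows, so small train sets are nested in larger ones.
train.list <- list()
for(test.fold in seq_len(n.folds)){
  train.ids <- which(fold.vec != test.fold)
  for(seed in seq_len(n.seeds)){
    set.seed(seed)
    ordered.ids <- train.ids[sample.int(length(train.ids))]
    for(size in train.sizes){
      train.list[[paste(test.fold, seed, size)]] <- ordered.ids[seq_len(size)]
    }
  }
}
length(train.list)  # 54
```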

Benchmark: computing test error

In the code below, we define two learners to compare:

(reg.learner.list <- list(
  if(requireNamespace("rpart"))mlr3::LearnerRegrRpart$new(),
  mlr3::LearnerRegrFeatureless$new()))
#> [[1]]
#> <LearnerRegrRpart:regr.rpart>: Regression Tree
#> * Model: -
#> * Parameters: xval=0
#> * Packages: mlr3, rpart
#> * Predict Types:  [response]
#> * Feature Types: logical, integer, numeric, factor, ordered
#> * Properties: importance, missings, selected_features, weights
#> 
#> [[2]]
#> <LearnerRegrFeatureless:regr.featureless>: Featureless Regression Learner
#> * Model: -
#> * Parameters: robust=FALSE
#> * Packages: mlr3, stats
#> * Predict Types:  [response], se
#> * Feature Types: logical, integer, numeric, character, factor, ordered,
#>   POSIXct
#> * Properties: featureless, importance, missings, selected_features

The code above defines

  • regr.rpart: Regression Tree learning algorithm, which should be able to learn the non-linear pattern in the sin data (if there are enough data in the train set).
  • regr.featureless: Featureless Regression learning algorithm, which should be optimal for the constant data, and can be used as a baseline in the sin data. When the rpart learner gets smaller prediction error rates than featureless, then we know that it has learned some non-trivial relationship between inputs and outputs.
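A quick base R check of why featureless is the right baseline for the constant task: it predicts the mean of the train outputs for every input, so on data simulated like the constant task above (zero plus Gaussian noise with sd=0.5), its test mean squared error should be close to the noise variance, 0.5^2 = 0.25, which is the smallest achievable expected error. The sketch uses a much larger test sample than the vignette (an arbitrary choice) so that the estimate is stable.

```r
set.seed(2)
sd.noise <- 0.5
# Constant task: output is 0 plus Gaussian noise, as in the simulation above.
y.train <- 0 + rnorm(200, sd=sd.noise)
y.test <- 0 + rnorm(1e5, sd=sd.noise)
# Featureless prediction: the mean of the train outputs, for every test input.
featureless.pred <- mean(y.train)
featureless.mse <- mean((y.test - featureless.pred)^2)
featureless.mse  # close to sd.noise^2 = 0.25
```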

In the code below, we define the benchmark grid, which contains all combinations of tasks (constant and sin), learners (rpart and featureless), and the one resampling method.

(reg.bench.grid <- mlr3::benchmark_grid(
  reg.task.list,
  reg.learner.list,
  reg_size_cv))
#>        task          learner             resampling
#>      <char>           <char>                 <char>
#> 1:      sin       regr.rpart variable_size_train_cv
#> 2:      sin regr.featureless variable_size_train_cv
#> 3: constant       regr.rpart variable_size_train_cv
#> 4: constant regr.featureless variable_size_train_cv

In the code below, we execute the benchmark experiment. It can optionally be run in parallel using the multisession future plan, by changing if(FALSE) to if(TRUE).

if(FALSE){
  if(require(future))plan("multisession")
}
if(require(lgr))get_logger("mlr3")$set_threshold("warn")
(reg.bench.result <- mlr3::benchmark(
  reg.bench.grid, store_models = TRUE))
#> <BenchmarkResult> of 216 rows with 4 resampling runs
#>  nr  task_id       learner_id          resampling_id iters warnings errors
#>   1      sin       regr.rpart variable_size_train_cv    54        0      0
#>   2      sin regr.featureless variable_size_train_cv    54        0      0
#>   3 constant       regr.rpart variable_size_train_cv    54        0      0
#>   4 constant regr.featureless variable_size_train_cv    54        0      0

The code below computes the test error for each split, and displays the first row of the result:

reg.bench.score <- mlr3resampling::score(reg.bench.result)
reg.bench.score[1]
#>    test.fold  seed small_stratum_size train_size_i train_size
#>        <int> <int>              <int>        <int>      <int>
#> 1:         1     1                 10            1         10
#>                          train                  test iteration train_min_size
#>                         <list>                <list>     <int>          <int>
#> 1: 216,197, 81,171,143, 36,...  1, 7,11,13,15,19,...         1             10
#>                                   uhash    nr           task task_id
#>                                  <char> <int>         <list>  <char>
#> 1: 132b693f-6201-4977-8de0-9f4e7ea70736     1 <TaskRegr:sin>     sin
#>                          learner learner_id                      resampling
#>                           <list>     <char>                          <list>
#> 1: <LearnerRegrRpart:regr.rpart> regr.rpart <ResamplingVariableSizeTrainCV>
#>             resampling_id       prediction  regr.mse algorithm
#>                    <char>           <list>     <num>    <char>
#> 1: variable_size_train_cv <PredictionRegr> 0.8008255     rpart

The output above contains all of the results related to a particular train/test split. For our purposes, the most relevant columns are:

  • test.fold is the cross-validation fold ID.
  • seed is the random seed used to determine the train set order.
  • train_size is the number of data in the train set.
  • train and test are vectors of row numbers assigned to each set.
  • iteration is an ID for the train/test split, for a particular learning algorithm and task. It is the row number of iteration.dt (see instance above), which has one row for each unique combination of test.fold, seed, and train_size.
  • learner is the mlr3 learner object, which can be used to compute predictions on new data (including a grid of inputs, to show predictions in the visualization below).
  • regr.mse is the mean squared error on the test set.
  • algorithm is the name of the learning algorithm (same as learner_id but without the regr. prefix).
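The regr.mse column is the average squared difference between the true test outputs and the predictions. The sketch below computes it by hand for a made-up vector of test outputs and predictions (hypothetical values, not taken from the benchmark above); in the vignette, the corresponding quantities are stored in the prediction column, which holds PredictionRegr objects with truth and response fields.

```r
# Hypothetical test outputs and predictions, for illustration only.
truth <- c(1.2, -0.6, 0.8, 0.5)
response <- c(0.9, -0.1, 0.3, 0.5)
# Mean squared error: average of squared differences.
regr.mse <- mean((truth - response)^2)
regr.mse  # (0.09 + 0.25 + 0.25 + 0) / 4 = 0.1475
```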

The code below visualizes the resulting test error values.

train_size_vec <- unique(reg.bench.score$train_size)
if(require(animint2)){
  ggplot()+
    scale_x_log10(
      breaks=train_size_vec)+
    scale_y_log10()+
    geom_line(aes(
      train_size, regr.mse,
      group=paste(algorithm, seed),
      color=algorithm),
      data=reg.bench.score)+
    geom_point(aes(
      train_size, regr.mse, color=algorithm),
      shape=1,
      data=reg.bench.score)+
    facet_grid(
      test.fold~task_id,
      labeller=label_both,
      scales="free")
}

Above we plot the test error for each fold and train set size. There is a different panel for each task and test fold. Each line represents a random seed (ordering of data in train set), and each dot represents a specific train set size. So the plot above shows that some variation in test error, for a given test fold, is due to the random ordering of the train data.

Below we summarize each train set size by taking the mean and standard deviation over the random seeds.

reg.mean.dt <- dcast(
  reg.bench.score,
  task_id + train_size + test.fold + algorithm ~ .,
  list(mean, sd),
  value.var="regr.mse")
if(require(animint2)){
  ggplot()+
    scale_x_log10(
      breaks=train_size_vec)+
    scale_y_log10()+
    geom_ribbon(aes(
      train_size,
      ymin=regr.mse_mean-regr.mse_sd,
      ymax=regr.mse_mean+regr.mse_sd,
      fill=algorithm),
      alpha=0.5,
      data=reg.mean.dt)+
    geom_line(aes(
      train_size, regr.mse_mean, color=algorithm),
      data=reg.mean.dt)+
    facet_grid(
      test.fold~task_id,
      labeller=label_both,
      scales="free")
}

The plot above shows a line for the mean, and a ribbon for the standard deviation, over the three random seeds. It is clear from the plot above that

  • in the constant task, featureless always has prediction error rates smaller than or equal to those of rpart, which indicates that rpart sometimes overfits for large sample sizes.
  • in the sin task, more than 30 samples are required for rpart to be more accurate than featureless, which indicates that it has learned a non-trivial relationship between input and output.

Interactive data viz

The code below can be used to create an interactive data visualization which allows exploring how different functions are learned during different splits.

grid.dt <- data.table(x=seq(-abs.x, abs.x, l=101), y=0)
grid.task <- mlr3::TaskRegr$new("grid", grid.dt, target="y")
pred.dt.list <- list()
point.dt.list <- list()
for(score.i in 1:nrow(reg.bench.score)){
  reg.bench.row <- reg.bench.score[score.i]
  task.dt <- data.table(
    reg.bench.row$task[[1]]$data(),
    reg.bench.row$resampling[[1]]$instance$id.dt)
  set.ids <- data.table(
    set.name=c("test","train")
  )[
  , data.table(row_id=reg.bench.row[[set.name]][[1]])
  , by=set.name]
  i.points <- set.ids[
    task.dt, on="row_id"
  ][
    is.na(set.name), set.name := "unused"
  ]
  point.dt.list[[score.i]] <- data.table(
    reg.bench.row[, .(task_id, iteration)],
    i.points)
  i.learner <- reg.bench.row$learner[[1]]
  pred.dt.list[[score.i]] <- data.table(
    reg.bench.row[, .(
      task_id, iteration, algorithm
    )],
    as.data.table(
      i.learner$predict(grid.task)
    )[, .(x=grid.dt$x, y=response)]
  )
}
(pred.dt <- rbindlist(pred.dt.list))
#>         task_id iteration   algorithm     x           y
#>          <char>     <int>      <char> <num>       <num>
#>     1:      sin         1       rpart -10.0  0.25011658
#>     2:      sin         1       rpart  -9.8  0.25011658
#>     3:      sin         1       rpart  -9.6  0.25011658
#>     4:      sin         1       rpart  -9.4  0.25011658
#>     5:      sin         1       rpart  -9.2  0.25011658
#>    ---                                                 
#> 21812: constant        54 featureless   9.2 -0.03385654
#> 21813: constant        54 featureless   9.4 -0.03385654
#> 21814: constant        54 featureless   9.6 -0.03385654
#> 21815: constant        54 featureless   9.8 -0.03385654
#> 21816: constant        54 featureless  10.0 -0.03385654
(point.dt <- rbindlist(point.dt.list))
#>         task_id iteration set.name row_id          y         x  fold
#>          <char>     <int>   <char>  <int>      <num>     <num> <int>
#>     1:      sin         1     test      1  1.2248390 -4.689827     1
#>     2:      sin         1   unused      2 -0.5607042 -2.557522     3
#>     3:      sin         1   unused      3  0.8345056  1.457067     3
#>     4:      sin         1   unused      4  0.4875994  8.164156     2
#>     5:      sin         1   unused      5 -0.4321800 -5.966361     3
#>    ---                                                              
#> 64796: constant        54    train    296 -0.6728968  3.628850     2
#> 64797: constant        54    train    297  0.5168327 -8.016618     1
#> 64798: constant        54    train    298 -0.4058882 -7.621949     1
#> 64799: constant        54     test    299  0.9008627 -8.991207     3
#> 64800: constant        54    train    300  0.8857710  8.585078     2
set.colors <- c(
  train="#1B9E77",
  test="#D95F02",
  unused="white")
algo.colors <- c(
  featureless="blue",
  rpart="red")
if(require(animint2)){
  viz <- animint(
    title="Variable size train set, regression",
    pred=ggplot()+
      ggtitle("Predictions for selected train/test split")+
      theme_animint(height=400)+
      scale_fill_manual(values=set.colors)+
      geom_point(aes(
        x, y, fill=set.name),
        showSelected="iteration",
        size=3,
        shape=21,
        data=point.dt)+
      scale_size_manual(values=c(
        featureless=3,
        rpart=2))+
      scale_color_manual(values=algo.colors)+
      geom_line(aes(
        x, y,
        color=algorithm,
        size=algorithm,
        group=paste(algorithm, iteration)),
        showSelected="iteration",
        data=pred.dt)+
      facet_grid(
        task_id ~ .,
        labeller=label_both),
    err=ggplot()+
      ggtitle("Test error for each split")+
      theme_animint(width=500)+
      theme(
        panel.margin=grid::unit(1, "lines"),
        legend.position="none")+
      scale_y_log10(
        "Mean squared error on test set")+
      scale_color_manual(values=algo.colors)+
      scale_x_log10(
        "Train set size",
        breaks=train_size_vec)+
      geom_line(aes(
        train_size, regr.mse,
        group=paste(algorithm, seed),
        color=algorithm),
        clickSelects="seed",
        alpha_off=0.2,
        showSelected="algorithm",
        size=4,
        data=reg.bench.score)+
      facet_grid(
        test.fold~task_id,
        labeller=label_both,
        scales="free")+
      geom_point(aes(
        train_size, regr.mse,
        color=algorithm),
        size=5,
        stroke=3,
        fill="black",
        fill_off=NA,
        showSelected=c("algorithm","seed"),
        clickSelects="iteration",
        data=reg.bench.score),
    source="https://github.com/tdhock/mlr3resampling/blob/main/vignettes/Simulations.Rmd")
  viz
}

if(FALSE){
  animint2pages(viz, "2023-12-26-train-sizes-regression")
}

If you are viewing this in an installed package or on CRAN, then there will be no data viz on this page, but you can view it on: https://tdhock.github.io/2023-12-26-train-sizes-regression/

The interactive data viz consists of two plots:

  • The first plot shows the data, with each point colored according to the set it was assigned, in the currently selected split/iteration. The red/blue lines additionally show the learned prediction functions for the currently selected split/iteration.
  • The second plot shows the test error rates, as a function of train set size. Clicking a line selects the corresponding random seed, which makes the corresponding points on that line appear. Clicking a point selects the corresponding iteration (seed, test fold, and train set size).
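The train/test/unused coloring of points in the first plot comes from the join in the loop above, where rows that appear in neither the train nor the test set of a split are labeled "unused" (this happens for small train set sizes). The sketch below shows the same labeling with plain base R indexing, on a hypothetical split of 10 rows.

```r
# Hypothetical split of 10 rows into a few train rows and a few test rows.
n.rows <- 10
train.ids <- c(2, 5, 9)
test.ids <- c(1, 4, 7, 10)
# Rows in neither set are "unused", as with small train sizes above.
set.name <- rep("unused", n.rows)
set.name[train.ids] <- "train"
set.name[test.ids] <- "test"
table(set.name)  # test 4, train 3, unused 3
```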

Simulated classification problems

Whereas in the section above we focused on regression (output is a real number), in this section we simulate binary classification problems (output is a factor with two levels).

class.N <- 300
class.abs.x <- 1
rclass <- function(){
  runif(class.N, -class.abs.x, class.abs.x)
}
library(data.table)
set.seed(1)
class.x.dt <- data.table(x1=rclass(), x2=rclass())
class.fun.list <- list(
  constant=function(...)0.5,
  xor=function(x1, x2)xor(x1>0, x2>0))
class.data.list <- list()
class.task.list <- list()
for(task_id in names(class.fun.list)){
  class.fun <- class.fun.list[[task_id]]
  y <- factor(ifelse(
    class.x.dt[, class.fun(x1, x2)+rnorm(class.N, sd=0.5)]>0.5,
    "spam", "not"))
  task.dt <- data.table(class.x.dt, y)
  this.task <- mlr3::TaskClassif$new(
    task_id, task.dt, target="y")
  this.task$col_roles$stratum <- "y"
  class.task.list[[task_id]] <- this.task
  class.data.list[[task_id]] <- data.table(task_id, task.dt)
}
(class.data <- rbindlist(class.data.list))
#>       task_id         x1           x2      y
#>        <char>      <num>        <num> <fctr>
#>   1: constant -0.4689827  0.347424466   spam
#>   2: constant -0.2557522 -0.810284289    not
#>   3: constant  0.1457067 -0.014807758   spam
#>   4: constant  0.8164156 -0.076896319    not
#>   5: constant -0.5966361 -0.249566938   spam
#>  ---                                        
#> 596:      xor  0.3628850  0.297101895    not
#> 597:      xor -0.8016618 -0.040328411    not
#> 598:      xor -0.7621949 -0.009871789   spam
#> 599:      xor -0.8991207 -0.240254817    not
#> 600:      xor  0.8585078 -0.099029126   spam

The simulated data table above consists of two input features (x1 and x2) along with an output/label to predict (y). Below we count the number of times each label appears in each task:

class.data[, .(count=.N), by=.(task_id, y)]
#>     task_id      y count
#>      <char> <fctr> <int>
#> 1: constant   spam   143
#> 2: constant    not   157
#> 3:      xor   spam   145
#> 4:      xor    not   155

The table above shows that spam is the minority class in both tasks (not is the majority class, so it will be the prediction of the featureless baseline). Below we visualize the data in the feature space:

if(require(animint2)){
  ggplot()+
    geom_point(aes(
      x1, x2, color=y),
      shape=1,
      data=class.data)+
    facet_grid(. ~ task_id, labeller=label_both)+
    coord_equal()
}

The plot above shows how the output y is related to the two inputs x1 and x2, for the two tasks.
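Because the labels were simulated by thresholding a noisy function at 0.5, we can work out the smallest achievable test error (Bayes error) for each task. With Gaussian noise of sd=0.5, in the xor task a label differs from the most likely label for its input with probability pnorm(-0.5/0.5) = pnorm(-1), about 0.16, whereas in the constant task both labels are equally likely everywhere, so the Bayes error is 0.5. The sketch below checks this by simulation, reusing the labeling rule from the loop above; it is an illustration, not part of the benchmark.

```r
set.seed(3)
n.sim <- 1e5
x1 <- runif(n.sim, -1, 1)
x2 <- runif(n.sim, -1, 1)
noise <- rnorm(n.sim, sd=0.5)
# Same labeling rule as the simulation above: spam if f(x1, x2) + noise > 0.5.
f.xor <- xor(x1 > 0, x2 > 0)
y.xor <- ifelse(f.xor + noise > 0.5, "spam", "not")
# The best possible classifier predicts spam exactly when f.xor is TRUE.
bayes.pred <- ifelse(f.xor, "spam", "not")
xor.error <- mean(y.xor != bayes.pred)
c(simulated=xor.error, exact=pnorm(-1))  # both close to 0.16
# Constant task: spam if 0.5 + noise > 0.5, i.e. noise > 0, probability 0.5.
y.const <- ifelse(0.5 + noise > 0.5, "spam", "not")
mean(y.const == "spam")  # close to 0.5
```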

In the mlr3 code below, we define a list of learners, our resampling method, and a benchmark grid:

class.learner.list <- list(
  if(requireNamespace("rpart"))mlr3::LearnerClassifRpart$new(),
  mlr3::LearnerClassifFeatureless$new())
size_cv <- mlr3resampling::ResamplingVariableSizeTrainCV$new()
(class.bench.grid <- mlr3::benchmark_grid(
  class.task.list,
  class.learner.list,
  size_cv))
#>        task             learner             resampling
#>      <char>              <char>                 <char>
#> 1: constant       classif.rpart variable_size_train_cv
#> 2: constant classif.featureless variable_size_train_cv
#> 3:      xor       classif.rpart variable_size_train_cv
#> 4:      xor classif.featureless variable_size_train_cv

Below we run the learning algorithms for each of the train/test splits defined by our benchmark grid:

if(FALSE){
  if(require(future))plan("multisession")
}
if(require(lgr))get_logger("mlr3")$set_threshold("warn")
(class.bench.result <- mlr3::benchmark(
  class.bench.grid, store_models = TRUE))
#> <BenchmarkResult> of 180 rows with 4 resampling runs
#>  nr  task_id          learner_id          resampling_id iters warnings errors
#>   1 constant       classif.rpart variable_size_train_cv    45        0      0
#>   2 constant classif.featureless variable_size_train_cv    45        0      0
#>   3      xor       classif.rpart variable_size_train_cv    45        0      0
#>   4      xor classif.featureless variable_size_train_cv    45        0      0

Below we compute scores (test error) for each resampling iteration, and show the first row of the result.

class.bench.score <- mlr3resampling::score(class.bench.result)
class.bench.score[1]
#>    test.fold  seed small_stratum_size train_size_i train_size
#>        <int> <int>              <int>        <int>      <int>
#> 1:         1     1                 10            1         21
#>                          train                  test iteration train_min_size
#>                         <list>                <list>     <int>          <int>
#> 1: 132,239, 10,216,245,276,...  5, 6, 8,21,23,28,...         1             21
#>                                   uhash    nr                   task  task_id
#>                                  <char> <int>                 <list>   <char>
#> 1: 936db87d-1a46-427e-abb6-d67b1e877686     1 <TaskClassif:constant> constant
#>                                learner    learner_id
#>                                 <list>        <char>
#> 1: <LearnerClassifRpart:classif.rpart> classif.rpart
#>                         resampling          resampling_id          prediction
#>                             <list>                 <char>              <list>
#> 1: <ResamplingVariableSizeTrainCV> variable_size_train_cv <PredictionClassif>
#>    classif.ce algorithm
#>         <num>    <char>
#> 1:  0.4257426     rpart

The output above has columns which are very similar to the regression example in the previous section. The main difference is the classif.ce column, which is the classification error on the test set.
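As with regr.mse, the classif.ce column can be computed by hand: it is the proportion of test data whose predicted label differs from the true label. The labels in the sketch below are made up for illustration.

```r
# Hypothetical true and predicted labels for five test data.
truth <- factor(c("spam", "not", "not", "spam", "not"), c("not", "spam"))
response <- factor(c("not", "not", "not", "spam", "spam"), c("not", "spam"))
# Classification error: proportion of mismatched labels.
classif.ce <- mean(truth != response)
classif.ce  # 2 errors out of 5 = 0.4
```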

Finally we plot the test error values below.

if(require(animint2)){
  ggplot()+
    geom_line(aes(
      train_size, classif.ce,
      group=paste(algorithm, seed),
      color=algorithm),
      data=class.bench.score)+
    geom_point(aes(
      train_size, classif.ce, color=algorithm),
      shape=1,
      data=class.bench.score)+
    facet_grid(
      task_id ~ test.fold,
      labeller=label_both,
      scales="free")+
    scale_x_log10()
}

In the plot above, we can compare the test error of rpart and featureless for each task, test fold, and train set size.

Exercise for the reader: compute and plot mean and SD for these classification tasks, similar to the plot for the regression tasks in the previous section.
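One possible starting point for the exercise is sketched below. It computes the mean and standard deviation of the test error over random seeds using base R aggregate(), on a small made-up table with the same column names as class.bench.score (the classif.ce values are hypothetical). Applied to class.bench.score itself, the same formula should give a table analogous to reg.mean.dt; the dcast() call used for regression, with value.var="classif.ce", is another option.

```r
# Hypothetical scores: 2 algorithms x 2 train sizes x 3 seeds, one task/fold.
score.df <- expand.grid(
  seed=1:3,
  train_size=c(21L, 50L),
  algorithm=c("rpart", "featureless"),
  task_id="xor",
  test.fold=1L,
  stringsAsFactors=FALSE)
score.df$classif.ce <- c(
  0.40, 0.44, 0.38,  # rpart, train_size 21
  0.20, 0.22, 0.24,  # rpart, train_size 50
  0.48, 0.48, 0.48,  # featureless, train_size 21
  0.48, 0.48, 0.48)  # featureless, train_size 50
# Mean and SD over seeds, for each task, train size, fold, and algorithm.
summary.fun <- function(v)c(mean=mean(v), sd=sd(v))
class.mean.df <- do.call(data.frame, aggregate(
  classif.ce ~ task_id + train_size + test.fold + algorithm,
  data=score.df,
  FUN=summary.fun))
class.mean.df  # columns include classif.ce.mean and classif.ce.sd
```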

Interactive visualization of data, test error, and splits

The code below can be used to create an interactive data visualization which allows exploring how different functions are learned during different splits.

# Grid of points covering the input space, used to visualize the
# decision boundary of each learned model.
class.grid.vec <- seq(-class.abs.x, class.abs.x, length.out=21)
class.grid.dt <- CJ(x1=class.grid.vec, x2=class.grid.vec)
class.pred.dt.list <- list()
class.point.dt.list <- list()
for(score.i in 1:nrow(class.bench.score)){
  class.bench.row <- class.bench.score[score.i]
  # Task data, along with the row_id and fold of each observation.
  task.dt <- data.table(
    class.bench.row$task[[1]]$data(),
    class.bench.row$resampling[[1]]$instance$id.dt)
  # Row ids in the test and train sets of this iteration.
  set.ids <- data.table(
    set.name=c("test","train")
  )[
  , data.table(row_id=class.bench.row[[set.name]][[1]])
  , by=set.name]
  # Join with the task data; rows in neither set are "unused".
  i.points <- set.ids[
    task.dt, on="row_id"
  ][
    is.na(set.name), set.name := "unused"
  ][]
  class.point.dt.list[[score.i]] <- data.table(
    class.bench.row[, .(task_id, iteration)],
    i.points)
  if(class.bench.row$algorithm!="featureless"){
    # Predict class probabilities at every grid point (the grid task
    # has all labels missing, since only predictions are needed).
    i.learner <- class.bench.row$learner[[1]]
    i.learner$predict_type <- "prob"
    grid.class.task <- mlr3::TaskClassif$new(
      "grid",
      class.grid.dt[, label:=factor(NA, levels(task.dt$y))],
      target="label")
    pred.grid <- as.data.table(
      i.learner$predict(grid.class.task)
    )[, data.table(class.grid.dt, prob.spam)]
    # Matrix of spam probabilities, rows for x1 and columns for x2.
    pred.wide <- dcast(pred.grid, x1 ~ x2, value.var="prob.spam")
    prob.mat <- as.matrix(pred.wide[,-1])
    # The decision boundary is the 0.5 contour of the probability;
    # skip it when the predictions are constant (no boundary).
    if(length(table(prob.mat))>1){
      contour.list <- contourLines(
        class.grid.vec, class.grid.vec, prob.mat, levels=0.5)
      class.pred.dt.list[[score.i]] <- data.table(
        class.bench.row[, .(
          task_id, iteration, algorithm
        )],
        data.table(contour.i=seq_along(contour.list))[, {
          do.call(data.table, contour.list[[contour.i]])[, .(level, x1=x, x2=y)]
        }, by=contour.i]
      )
    }
  }
}
(class.pred.dt <- rbindlist(class.pred.dt.list))
#>        task_id iteration algorithm contour.i level     x1          x2
#>         <char>     <int>    <char>     <int> <num>  <num>       <num>
#>    1: constant         1     rpart         1   0.5 0.0375 -1.00000000
#>    2: constant         1     rpart         1   0.5 0.0375 -0.90000000
#>    3: constant         1     rpart         1   0.5 0.0375 -0.80000000
#>    4: constant         1     rpart         1   0.5 0.0375 -0.70000000
#>    5: constant         1     rpart         1   0.5 0.0375 -0.60000000
#>   ---                                                                
#> 5190:      xor        45     rpart         2   0.5 0.6000  0.04888889
#> 5191:      xor        45     rpart         2   0.5 0.7000  0.04888889
#> 5192:      xor        45     rpart         2   0.5 0.8000  0.04888889
#> 5193:      xor        45     rpart         2   0.5 0.9000  0.04888889
#> 5194:      xor        45     rpart         2   0.5 1.0000  0.04888889
(class.point.dt <- rbindlist(class.point.dt.list))
#>         task_id iteration set.name row_id      y         x1           x2  fold
#>          <char>     <int>   <char>  <int> <fctr>      <num>        <num> <int>
#>     1: constant         1   unused      1   spam -0.4689827  0.347424466     3
#>     2: constant         1   unused      2    not -0.2557522 -0.810284289     2
#>     3: constant         1   unused      3   spam  0.1457067 -0.014807758     3
#>     4: constant         1    train      4    not  0.8164156 -0.076896319     3
#>     5: constant         1     test      5   spam -0.5966361 -0.249566938     1
#>    ---                                                                        
#> 53996:      xor        45    train    296    not  0.3628850  0.297101895     2
#> 53997:      xor        45    train    297    not -0.8016618 -0.040328411     2
#> 53998:      xor        45     test    298   spam -0.7621949 -0.009871789     3
#> 53999:      xor        45     test    299    not -0.8991207 -0.240254817     3
#> 54000:      xor        45    train    300   spam  0.8585078 -0.099029126     2
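In the code above, each decision boundary is extracted as the 0.5 contour of the matrix of predicted probabilities, using `contourLines` from base R's grDevices package. The self-contained sketch below shows that step alone on a hypothetical probability surface (a logistic function of x1 only, not a fitted model), for which the boundary should be a vertical line at x1 = 0.05; the grid range of [-1, 1] is also an assumption.

```r
# Grid similar to the one above, on a hypothetical range of [-1, 1].
grid.vec <- seq(-1, 1, length.out=21)
# Hypothetical probability of spam, depending only on x1. Rows of
# prob.mat correspond to x1 and columns to x2, as with dcast(x1 ~ x2).
prob.mat <- outer(grid.vec, grid.vec, function(x1, x2) plogis(5*(x1 - 0.05)))
# Only look for a boundary if the predictions are not constant.
if(length(unique(as.vector(prob.mat))) > 1){
  # Each contour is a list with elements level, x, and y.
  contour.list <- contourLines(grid.vec, grid.vec, prob.mat, levels=0.5)
}
length(contour.list)
range(contour.list[[1]]$x)
```

Since `contourLines` interpolates linearly between grid points, the boundary found on a coarse grid is only an approximation of the true one; a finer grid gives a more accurate boundary at the cost of more predictions.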

set.colors <- c(
  train="#1B9E77",
  test="#D95F02",
  unused="white")
algo.colors <- c(
  featureless="blue",
  rpart="red")
if(require(animint2)){
  viz <- animint(
    title="Variable size train sets, classification",
    pred=ggplot()+
      ggtitle("Predictions for selected train/test split")+
      theme(panel.margin=grid::unit(1, "lines"))+
      theme_animint(width=600)+
      coord_equal()+
      scale_fill_manual(values=set.colors)+
      scale_color_manual(values=c(spam="black", not="white"))+
      geom_point(aes(
        x1, x2, color=y, fill=set.name),
        showSelected="iteration",
        size=3,
        stroke=2,
        shape=21,
        data=class.point.dt)+
      geom_path(aes(
        x1, x2, 
        group=paste(algorithm, iteration, contour.i)),
        showSelected=c("iteration","algorithm"),
        color=algo.colors[["rpart"]],
        data=class.pred.dt)+
      facet_grid(
        . ~ task_id,
        labeller=label_both,
        space="free",
        scales="free"),
    err=ggplot()+
      ggtitle("Test error for each split")+
      theme_animint(height=400)+
      theme(panel.margin=grid::unit(1, "lines"))+
      scale_y_continuous(
        "Classification error on test set")+
      scale_color_manual(values=algo.colors)+
      scale_x_log10(
        "Train set size")+
      geom_line(aes(
        train_size, classif.ce,
        group=paste(algorithm, seed),
        color=algorithm),
        clickSelects="seed",
        alpha_off=0.2,
        showSelected="algorithm",
        size=4,
        data=class.bench.score)+
      facet_grid(
        test.fold~task_id,
        labeller=label_both,
        scales="free")+
      geom_point(aes(
        train_size, classif.ce,
        color=algorithm),
        size=5,
        stroke=3,
        fill="black",
        fill_off=NA,
        showSelected=c("algorithm","seed"),
        clickSelects="iteration",
        data=class.bench.score),
    source="https://github.com/tdhock/mlr3resampling/blob/main/vignettes/ResamplingVariableSizeTrainCV.Rmd")
  viz
}

if(FALSE){
  animint2pages(viz, "2023-12-27-train-sizes-classification")
}

If you are viewing this in an installed package or on CRAN, then there will be no data viz on this page, but you can view it on: https://tdhock.github.io/2023-12-27-train-sizes-classification/

The interactive data viz consists of two plots:

  • The first plot shows the data, with each point colored according to its label/y value (black outline for spam, white outline for not), and the set it was assigned (fill color) in the currently selected split/iteration. The red lines additionally show the learned decision boundary for rpart, given the currently selected split/iteration. For constant, the ideal decision boundary is none (always predict the most frequent class), and for xor, the ideal decision boundary looks like a plus sign.
  • The second plot shows the test error rates, as a function of train set size. Clicking a line selects the corresponding random seed, which makes the corresponding points on that line appear. Clicking a point selects the corresponding iteration (seed, test fold, and train set size).
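To make the notion of an iteration more concrete, the base R sketch below shows one plausible way to form a single split from a seed, a test fold, and a train set size: rows in the test fold form the test set, a random subset of the other rows of the requested size forms the train set, and all remaining rows are unused (white fill in the first plot). This is a hypothetical illustration (the function `sketch_split` and its arguments are made up), not the implementation used by ResamplingVariableSizeTrainCV.

```r
# Hypothetical sketch of one variable size train/test split.
sketch_split <- function(fold, test.fold, train.size, seed){
  set.seed(seed)
  # Rows in the test fold form the test set.
  test <- which(fold == test.fold)
  # Candidate train rows are all rows in the other folds.
  other <- which(fold != test.fold)
  stopifnot(train.size <= length(other))
  # Random subset of the requested size.
  train <- other[sample.int(length(other), train.size)]
  set.name <- rep("unused", length(fold))
  set.name[test] <- "test"
  set.name[train] <- "train"
  set.name
}
fold <- rep(1:3, length.out=12)
(set.name <- sketch_split(fold, test.fold=1, train.size=5, seed=1))
table(set.name)
```

With 12 rows in 3 folds, the test set has 4 rows and the train set 5, leaving 3 unused rows; smaller train sizes leave more rows unused.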

Conclusion

In this vignette we have shown how to use mlr3resampling for comparing test error of models trained on different sized train sets.

Session info

sessionInfo()
#> R Under development (unstable) (2024-01-23 r85822 ucrt)
#> Platform: x86_64-w64-mingw32/x64
#> Running under: Windows 10 x64 (build 19045)
#> 
#> Matrix products: default
#> 
#> 
#> locale:
#> [1] LC_COLLATE=C                          
#> [2] LC_CTYPE=English_United States.utf8   
#> [3] LC_MONETARY=English_United States.utf8
#> [4] LC_NUMERIC=C                          
#> [5] LC_TIME=English_United States.utf8    
#> 
#> time zone: America/Phoenix
#> tzcode source: internal
#> 
#> attached base packages:
#> [1] stats     graphics  grDevices utils     datasets  methods   base     
#> 
#> other attached packages:
#> [1] mlr3_0.18.0        lgr_0.4.4          animint2_2024.1.24 data.table_1.15.99
#> 
#> loaded via a namespace (and not attached):
#>  [1] future.apply_1.11.2      gtable_0.3.4             jsonlite_1.8.8          
#>  [4] highr_0.10               compiler_4.4.0           crayon_1.5.2            
#>  [7] rpart_4.1.23             Rcpp_1.0.12              stringr_1.5.1           
#> [10] parallel_4.4.0           jquerylib_0.1.4          globals_0.16.3          
#> [13] scales_1.3.0             uuid_1.2-0               RhpcBLASctl_0.23-42     
#> [16] yaml_2.3.8               fastmap_1.1.1            R6_2.5.1                
#> [19] plyr_1.8.9               mlr3tuning_0.19.2        labeling_0.4.3          
#> [22] knitr_1.46               palmerpenguins_0.1.1     backports_1.4.1         
#> [25] checkmate_2.3.1          future_1.33.2            munsell_0.5.1           
#> [28] paradox_0.11.1           bslib_0.7.0              mlr3measures_0.5.0      
#> [31] rlang_1.1.3              stringi_1.8.3            cachem_1.0.8            
#> [34] xfun_0.43                mlr3misc_0.15.0          sass_0.4.9              
#> [37] RJSONIO_1.3-1.9          cli_3.6.2                magrittr_2.0.3          
#> [40] digest_0.6.34            grid_4.4.0               bbotk_0.7.3             
#> [43] nc_2024.2.21             lifecycle_1.0.4          evaluate_0.23           
#> [46] glue_1.7.0               farver_2.1.1             listenv_0.9.1           
#> [49] codetools_0.2-19         parallelly_1.37.1        colorspace_2.1-0        
#> [52] reshape2_1.4.4           rmarkdown_2.26           mlr3resampling_2024.4.14
#> [55] tools_4.4.0              htmltools_0.5.8.1