`fit()` with TensorFlow

When you’re doing supervised learning, you can use `fit()` and everything works smoothly.

When you need to take control of every little detail, you can write your own training loop entirely from scratch.

But what if you need a custom training algorithm and still want to benefit from the convenient features of `fit()`, such as callbacks, built-in distribution support, or step fusing?

A core principle of Keras is **progressive disclosure of
complexity**. You should always be able to get into lower-level
workflows in a gradual way. You shouldn’t fall off a cliff if the
high-level functionality doesn’t exactly match your use case. You should
be able to gain more control over the small details while retaining a
commensurate amount of high-level convenience.

When you need to customize what `fit()` does, you should **override the training step function of the `Model` class**. This is the function that is called by `fit()` for every batch of data. You will then be able to call `fit()` as usual – and it will be running your own learning algorithm.

Note that this pattern does not prevent you from building models with the Functional API. You can do this whether you’re building `Sequential` models, Functional API models, or subclassed models.

Let’s see how that works.

Let’s start from a simple example:

- We create a new class that subclasses `Model`.
- We just override the method `train_step(data)`.
- We return a named list mapping metric names (including the loss) to their current value.

The input argument `data` is what gets passed to `fit()` as training data:

- If you pass arrays, by calling `fit(x, y, ...)`, then `data` will be the list `(x, y)`.
- If you pass a `tf_dataset`, by calling `fit(dataset, ...)`, then `data` will be what gets yielded by `dataset` at each batch.
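To make the shape of `data` concrete, here is a minimal base-R sketch of what unpacking a batch amounts to. The helper `unpack_batch()` is hypothetical and exists only for illustration; the examples in this guide use the `%<-%` destructuring operator with `NULL` defaults instead.

```r
# Hypothetical helper: split a batch into x, y, and sample_weight,
# filling in NULL for any element that was not supplied.
unpack_batch <- function(data) {
  get_or_null <- function(i) if (length(data) >= i) data[[i]] else NULL
  list(
    x = get_or_null(1),
    y = get_or_null(2),
    sample_weight = get_or_null(3)
  )
}

# A batch as produced by `fit(x, y, ...)`: inputs and targets only
batch <- list(matrix(0, nrow = 4, ncol = 2), matrix(1, nrow = 4, ncol = 1))
parts <- unpack_batch(batch)

dim(parts$x)                  # 4 rows, 2 columns
is.null(parts$sample_weight)  # TRUE: no weights were passed
```

Writing `train_step()` so that it tolerates missing targets or weights lets the same method serve calls to `fit()` with or without `sample_weight`.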

In the body of the `train_step()` method, we implement a regular training update, similar to what you are already familiar with. Importantly, **we compute the loss via `self$compute_loss()`**, which wraps the loss function(s) that were passed to `compile()`.

Similarly, we call `metric$update_state(y, y_pred)` on metrics from `self$metrics`, to update the state of the metrics that were passed in `compile()`, and we query results from `self$metrics` at the end to retrieve their current value.

```
library(tensorflow)  # provides `tf`, used for `tf$GradientTape()`
library(keras3)

CustomModel <- new_model_class(
  "CustomModel",
  train_step = function(data) {
    c(x, y = NULL, sample_weight = NULL) %<-% data
    with(tf$GradientTape() %as% tape, {
      y_pred <- self(x, training = TRUE)
      loss <- self$compute_loss(y = y, y_pred = y_pred,
                                sample_weight = sample_weight)
    })
    # Compute gradients
    trainable_vars <- self$trainable_variables
    gradients <- tape$gradient(loss, trainable_vars)
    # Update weights
    self$optimizer$apply(gradients, trainable_vars)
    # Update metrics (includes the metric that tracks the loss)
    for (metric in self$metrics) {
      if (metric$name == "loss") {
        metric$update_state(loss)
      } else {
        metric$update_state(y, y_pred)
      }
    }
    # Return a named list mapping metric names to current value
    metrics <- lapply(self$metrics, function(m) m$result())
    metrics <- setNames(metrics, sapply(self$metrics, function(m) m$name))
    metrics
  }
)
```
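The last three lines of `train_step()` build the named list that `fit()` reports. The following base-R sketch shows the same `lapply()` / `setNames()` pattern in isolation. The objects in `fake_metrics` are hypothetical stand-ins (plain lists with a `name` and a `result()` closure), not real Keras metrics, and the values are made up for illustration.

```r
# Hypothetical stand-ins for metric objects: each has a name and a result()
fake_metrics <- list(
  list(name = "loss", result = function() 0.25),
  list(name = "mae",  result = function() 0.40)
)

# Query every metric for its current value ...
values <- lapply(fake_metrics, function(m) m$result())
# ... then label each value with the metric's name
logs <- setNames(values, sapply(fake_metrics, function(m) m$name))

names(logs)  # "loss" "mae"
logs$mae     # 0.4
```

Because the list is built from whatever `self$metrics` contains, the same code keeps working when you add or remove metrics in `compile()`.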

Let’s try this out:

```
# Construct and compile an instance of CustomModel
inputs <- keras_input(shape = 32)
outputs <- layer_dense(inputs, 1)
model <- CustomModel(inputs, outputs)
model |> compile(optimizer = "adam", loss = "mse", metrics = "mae")
# Just use `fit` as usual
x <- random_normal(c(1000, 32))
y <- random_normal(c(1000, 1))
model |> fit(x, y, epochs = 3)
```

```
## Epoch 1/3
## 32/32 - 1s - 19ms/step - loss: 2.9118 - mae: 1.3597
## Epoch 2/3
## 32/32 - 0s - 1ms/step - loss: 2.6026 - mae: 1.2856
## Epoch 3/3
## 32/32 - 0s - 1ms/step - loss: 2.3378 - mae: 1.2193
```

Naturally, you could just skip passing a loss function in `compile()`, and instead do everything *manually* in `train_step`. Likewise for metrics.

Here’s a lower-level example that only uses `compile()` to configure the optimizer:

- We start by creating `Metric` instances to track our loss and a MAE score (in `initialize()`).
- We implement a custom `train_step()` that updates the state of these metrics (by calling `update_state()` on them), then queries them (via `result()`) to return their current average value, to be displayed by the progress bar and to be passed to any callback.
- Note that we would need to call `reset_state()` on our metrics between each epoch! Otherwise calling `result()` would return an average since the start of training, whereas we usually work with per-epoch averages. Thankfully, the framework can do that for us: just list any metric you want to reset in the `metrics` property of the model. The model will call `reset_state()` on any object listed here at the beginning of each `fit()` epoch or at the beginning of a call to `evaluate()`.
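To see why resetting between epochs matters, here is a small base-R sketch of a running-mean tracker. `new_mean_tracker()` is a hypothetical stand-in written for this illustration; it only mimics the update / result / reset cycle described above and is not the Keras implementation.

```r
# Hypothetical running-mean tracker built from closures
new_mean_tracker <- function() {
  total <- 0
  count <- 0
  list(
    update_state = function(value) {
      total <<- total + value
      count <<- count + 1
      invisible(NULL)
    },
    result = function() if (count == 0) 0 else total / count,
    reset_state = function() {
      total <<- 0
      count <<- 0
      invisible(NULL)
    }
  )
}

tracker <- new_mean_tracker()

# Epoch 1: two batches with losses 4 and 2
for (loss in c(4, 2)) tracker$update_state(loss)
epoch1 <- tracker$result()       # mean of 4, 2

# Epoch 2 WITHOUT a reset: the average still includes epoch 1
for (loss in c(1, 1)) tracker$update_state(loss)
cumulative <- tracker$result()   # mean of 4, 2, 1, 1

# Epoch 2 WITH a reset: a true per-epoch average
tracker$reset_state()
for (loss in c(1, 1)) tracker$update_state(loss)
epoch2 <- tracker$result()       # mean of 1, 1

c(epoch1 = epoch1, cumulative = cumulative, epoch2 = epoch2)
```

Without the reset, the second epoch reports 2 instead of 1, which would hide how much the loss actually improved.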

```
CustomModel <- new_model_class(
  "CustomModel",
  initialize = function(...) {
    super$initialize(...)
    self$loss_tracker <- metric_mean(name = "loss")
    self$mae_metric <- metric_mean_absolute_error(name = "mae")
    self$loss_fn <- loss_mean_squared_error()
  },
  train_step = function(data) {
    c(x, y = NULL, sample_weight = NULL) %<-% data
    with(tf$GradientTape() %as% tape, {
      y_pred <- self(x, training = TRUE)
      loss <- self$loss_fn(y, y_pred, sample_weight = sample_weight)
    })
    # Compute gradients
    trainable_vars <- self$trainable_variables
    gradients <- tape$gradient(loss, trainable_vars)
    # Update weights
    self$optimizer$apply(gradients, trainable_vars)
    # Compute our own metrics
    self$loss_tracker$update_state(loss)
    self$mae_metric$update_state(y, y_pred)
    # Return a named list mapping metric names to current value
    list(
      loss = self$loss_tracker$result(),
      mae = self$mae_metric$result()
    )
  },
  metrics = active_property(function() {
    # We list our `Metric` objects here so that `reset_state()` can be
    # called automatically at the start of each epoch
    # or at the start of `evaluate()`.
    list(self$loss_tracker, self$mae_metric)
  })
)
# Construct and compile an instance of CustomModel
inputs <- keras_input(shape = 32)
outputs <- layer_dense(inputs, 1)
model <- CustomModel(inputs, outputs)
# We don't pass a loss or metrics here.
model |> compile(optimizer = "adam")
# Just use `fit` as usual
x <- random_normal(c(1000, 32))
y <- random_normal(c(1000, 1))
model |> fit(x, y, epochs = 3)

```
## Epoch 1/3
## 32/32 - 1s - 16ms/step - loss: 2.6540 - mae: 1.2901
## Epoch 2/3
## 32/32 - 0s - 1ms/step - loss: 2.4139 - mae: 1.2303
## Epoch 3/3
## 32/32 - 0s - 1ms/step - loss: 2.2080 - mae: 1.1761
```

`sample_weight` & `class_weight`
You may have noticed that our first basic example didn’t pass sample weights on to its metrics. If you want to support the `fit()` arguments `sample_weight` and `class_weight`, you’d simply do the following:

- Unpack `sample_weight` from the `data` argument.
- Pass it to `compute_loss` & `update_state` (of course, you could also just apply it manually if you don’t rely on `compile()` for losses & metrics).
- That’s it.
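As a rough illustration of "applying it manually", the base-R sketch below weights per-sample errors by hand. It rests on two assumptions that are not stated in the text above: that the loss uses a sum-over-batch-size reduction (weighted errors summed, then divided by the number of samples), and that a mean-style metric divides by the sum of the weights instead. The numbers are made up.

```r
y      <- c(1, 2, 3, 4)
y_pred <- c(1.5, 2, 2, 6)
sw     <- c(1, 0, 2, 2)   # a weight of 0 removes a sample's contribution

sq_err  <- (y - y_pred)^2    # per-sample squared error: 0.25 0 1 4
abs_err <- abs(y - y_pred)   # per-sample absolute error: 0.5 0 1 2

# Assumed loss reduction: weighted sum divided by the batch size
weighted_loss <- sum(sq_err * sw) / length(sq_err)

# Assumed metric behaviour: weighted sum divided by the total weight
weighted_mae <- sum(abs_err * sw) / sum(sw)

c(loss = weighted_loss, mae = weighted_mae)
```

Under these assumptions the two quantities are normalized differently, so a weighted loss and a weighted metric computed on the same batch are not directly comparable in scale.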

```
CustomModel <- new_model_class(
  "CustomModel",
  train_step = function(data) {
    c(x, y = NULL, sample_weight = NULL) %<-% data
    with(tf$GradientTape() %as% tape, {
      y_pred <- self(x, training = TRUE)
      loss <- self$compute_loss(y = y, y_pred = y_pred,
                                sample_weight = sample_weight)
    })
    # Compute gradients
    trainable_vars <- self$trainable_variables
    gradients <- tape$gradient(loss, trainable_vars)
    # Update weights
    self$optimizer$apply_gradients(zip_lists(gradients, trainable_vars))
    # Update metrics (includes the metric that tracks the loss)
    for (metric in self$metrics) {
      if (metric$name == "loss") {
        metric$update_state(loss)
      } else {
        metric$update_state(y, y_pred, sample_weight = sample_weight)
      }
    }
    # Return a named list mapping metric names to current value
    metrics <- lapply(self$metrics, function(m) m$result())
    metrics <- setNames(metrics, sapply(self$metrics, function(m) m$name))
    metrics
  }
)
# Construct and compile an instance of CustomModel
inputs <- keras_input(shape = 32)
outputs <- layer_dense(inputs, units = 1)
model <- CustomModel(inputs, outputs)
model |> compile(optimizer = "adam", loss = "mse", metrics = "mae")
# You can now use sample_weight argument
x <- random_normal(c(1000, 32))
y <- random_normal(c(1000, 1))
sw <- random_normal(c(1000, 1))
model |> fit(x, y, sample_weight = sw, epochs = 3)
```

```
## Epoch 1/3
## 32/32 - 1s - 18ms/step - loss: 0.1607 - mae: 1.3018
## Epoch 2/3
## 32/32 - 0s - 1ms/step - loss: 0.1452 - mae: 1.2999
## Epoch 3/3
## 32/32 - 0s - 1ms/step - loss: 0.1335 - mae: 1.2986
```

What if you want to do the same for calls to `evaluate()`? Then you would override `test_step` in exactly the same way. Here’s what it looks like:

```
CustomModel <- new_model_class(
  "CustomModel",
  test_step = function(data) {
    # Unpack the data
    c(x, y = NULL, sw = NULL) %<-% data
    # Compute predictions
    y_pred <- self(x, training = FALSE)
    # Updates the metrics tracking the loss
    self$compute_loss(y = y, y_pred = y_pred, sample_weight = sw)
    # Update the metrics.
    for (metric in self$metrics) {
      if (metric$name != "loss") {
        metric$update_state(y, y_pred, sample_weight = sw)
      }
    }
    # Return a named list mapping metric names to current value.
    # Note that it will include the loss (tracked in self$metrics).
    metrics <- lapply(self$metrics, function(m) m$result())
    metrics <- setNames(metrics, sapply(self$metrics, function(m) m$name))
    metrics
  }
)
# Construct an instance of CustomModel
inputs <- keras_input(shape = 32)
outputs <- layer_dense(inputs, 1)
model <- CustomModel(inputs, outputs)
model |> compile(loss = "mse", metrics = "mae")
# Evaluate with our custom test_step
x <- random_normal(c(1000, 32))
y <- random_normal(c(1000, 1))
model |> evaluate(x, y)
```

```
## 32/32 - 0s - 8ms/step - loss: 0.0000e+00 - mae: 1.3947
## $loss
## [1] 0
##
## $mae
## [1] 1.394695
```

Let’s walk through an end-to-end example that leverages everything you just learned.

Let’s consider:

- A generator network meant to generate 28x28x1 images.
- A discriminator network meant to classify 28x28x1 images into two classes (“fake” and “real”).
- One optimizer for each.
- A loss function to train the discriminator.

```
# Create the discriminator
discriminator <-
  keras_model_sequential(name = "discriminator", input_shape = c(28, 28, 1)) |>
  layer_conv_2d(filters = 64, kernel_size = c(3, 3),
                strides = c(2, 2), padding = "same") |>
  layer_activation_leaky_relu(negative_slope = 0.2) |>
  layer_conv_2d(filters = 128, kernel_size = c(3, 3),
                strides = c(2, 2), padding = "same") |>
  layer_activation_leaky_relu(negative_slope = 0.2) |>
  layer_global_max_pooling_2d() |>
  layer_dense(units = 1)
# Create the generator
latent_dim <- 128
generator <-
  keras_model_sequential(name = "generator", input_shape = latent_dim) |>
  layer_dense(7 * 7 * 128) |>
  layer_activation_leaky_relu(negative_slope = 0.2) |>
  layer_reshape(target_shape = c(7, 7, 128)) |>
  layer_conv_2d_transpose(filters = 128, kernel_size = c(4, 4),
                          strides = c(2, 2), padding = "same") |>
  layer_activation_leaky_relu(negative_slope = 0.2) |>
  layer_conv_2d_transpose(filters = 128, kernel_size = c(4, 4),
                          strides = c(2, 2), padding = "same") |>
  layer_activation_leaky_relu(negative_slope = 0.2) |>
  layer_conv_2d(filters = 1, kernel_size = c(7, 7), padding = "same",
                activation = "sigmoid")
```

Here’s a feature-complete GAN class, overriding `compile()` to use its own signature, and implementing the entire GAN algorithm in 17 lines in `train_step`:

```
GAN <- Model(
  classname = "GAN",
  initialize = function(discriminator, generator, latent_dim, ...) {
    super$initialize(...)
    self$discriminator <- discriminator
    self$generator <- generator
    self$latent_dim <- as.integer(latent_dim)
    self$d_loss_tracker <- metric_mean(name = "d_loss")
    self$g_loss_tracker <- metric_mean(name = "g_loss")
  },
  compile = function(d_optimizer, g_optimizer, loss_fn, ...) {
    super$compile(...)
    self$d_optimizer <- d_optimizer
    self$g_optimizer <- g_optimizer
    self$loss_fn <- loss_fn
  },
  metrics = active_property(function() {
    list(self$d_loss_tracker, self$g_loss_tracker)
  }),
  train_step = function(real_images) {
    # Sample random points in the latent space
    batch_size <- shape(real_images)[[1]]
    random_latent_vectors <-
      tf$random$normal(shape(batch_size, self$latent_dim))
    # Decode them to fake images
    generated_images <- self$generator(random_latent_vectors)
    # Combine them with real images
    combined_images <- op_concatenate(list(generated_images,
                                           real_images))
    # Assemble labels discriminating real from fake images
    labels <- op_concatenate(list(op_ones(c(batch_size, 1)),
                                  op_zeros(c(batch_size, 1))))
    # Add random noise to the labels - important trick!
    labels %<>% `+`(tf$random$uniform(shape(.), maxval = 0.05))
    # Train the discriminator
    with(tf$GradientTape() %as% tape, {
      predictions <- self$discriminator(combined_images)
      d_loss <- self$loss_fn(labels, predictions)
    })
    grads <- tape$gradient(d_loss, self$discriminator$trainable_weights)
    self$d_optimizer$apply_gradients(
      zip_lists(grads, self$discriminator$trainable_weights))
    # Sample random points in the latent space
    random_latent_vectors <-
      tf$random$normal(shape(batch_size, self$latent_dim))
    # Assemble labels that say "all real images"
    misleading_labels <- op_zeros(c(batch_size, 1))
    # Train the generator (note that we should *not* update the weights
    # of the discriminator)!
    with(tf$GradientTape() %as% tape, {
      predictions <- self$discriminator(self$generator(random_latent_vectors))
      g_loss <- self$loss_fn(misleading_labels, predictions)
    })
    grads <- tape$gradient(g_loss, self$generator$trainable_weights)
    self$g_optimizer$apply_gradients(
      zip_lists(grads, self$generator$trainable_weights))
    list(d_loss = d_loss, g_loss = g_loss)
  }
)
```
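The label handling in `train_step` is easy to misread, so here is a base-R sketch of just that part, using plain matrices in place of tensors. It follows the convention in the code above (1 for generated images, 0 for real ones); `runif()` stands in for `tf$random$uniform()`, and the batch size of 4 is arbitrary.

```r
set.seed(1)
batch_size <- 4

# Generated images come first in the combined batch, so their labels (1)
# come first too, followed by the labels for real images (0).
labels <- rbind(matrix(1, batch_size, 1), matrix(0, batch_size, 1))

# Add a small amount of uniform noise to every label
noise <- matrix(runif(length(labels), min = 0, max = 0.05), ncol = 1)
noisy_labels <- labels + noise

# For the generator step, every generated image is labelled "real" (0),
# so the generator is rewarded when the discriminator is fooled.
misleading_labels <- matrix(0, batch_size, 1)

dim(noisy_labels)          # 8 rows, 1 column
range(noisy_labels - labels)
```

The order of the label blocks has to match the order in which generated and real images are concatenated, otherwise the discriminator would be trained on swapped targets.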

Let’s test-drive it:

```
batch_size <- 64
c(c(x_train, .), c(x_test, .)) %<-% dataset_mnist()
all_digits <- op_concatenate(list(x_train, x_test))
all_digits <- op_reshape(all_digits, c(-1, 28, 28, 1))
dataset <- all_digits |>
  tfdatasets::tensor_slices_dataset() |>
  tfdatasets::dataset_map(\(x) op_cast(x, "float32") / 255) |>
  tfdatasets::dataset_shuffle(buffer_size = 1024) |>
  tfdatasets::dataset_batch(batch_size = batch_size)
gan <- GAN(discriminator = discriminator,
           generator = generator,
           latent_dim = latent_dim)
gan |> compile(
  d_optimizer = optimizer_adam(learning_rate = 0.0003),
  g_optimizer = optimizer_adam(learning_rate = 0.0003),
  loss_fn = loss_binary_crossentropy(from_logits = TRUE)
)
# To limit the execution time, we only train on 100 batches. You can train on
# the entire dataset. You will need about 20 epochs to get nice results.
gan |> fit(
  tfdatasets::dataset_take(dataset, 100),
  epochs = 1
)
```

`## 100/100 - 5s - 47ms/step - d_loss: 0.0000e+00 - g_loss: 0.0000e+00`
