Custom training a sub-classed model with custom callbacks and checkpoints using Tensorflow 2.0
In this tutorial we discuss how to create a sub-classed model by extending the keras.Model class and how to train it with a low-level training loop. In addition, we demonstrate how to define and apply custom callbacks and checkpoints to the custom training loop, and finally how to save and restore the model. When you use model.compile() and model.fit() to train a model, everything is straightforward; a custom training loop, however, is a little more complicated, which is why we decided to write this article. Here are the links to the source code on Colab and Github:
Layers
First of all, we implemented the custom layers of the network. The layers we designed differ slightly from those described in the original paper; this network was built purely for tutorial purposes, so you can change the layers and the network architecture as you wish. The network's custom layers are as follows:
All of the modules mentioned above inherit from the keras.layers.Layer class. Their code snippets are given below.
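All of the snippets assume the following imports (a minimal set; the original notebook may organize them differently):
import time
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import Model, Input
from tensorflow.keras.layers import (Layer, Conv2D, BatchNormalization,
                                     MaxPooling2D, AveragePooling2D,
                                     Flatten, Dense)
from tensorflow.keras.callbacks import Callback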
ConvModule
class ConvModule(Layer):
"""Convolution Module
Input Args:
num_channels: number of convolution layer channels.
kernel_size: the Convolution layer kernel size
strides: the step size the convolution filter moves.
padding: padding type (e.g `valid`, `same`).
"""
def __init__(self, num_channels, kernel_size, strides, padding='same', **kwargs):
"""Constructor method."""
super(ConvModule, self).__init__(**kwargs)
self.num_channels = num_channels
self.kernel_size = kernel_size
self.strides = strides
self.padding = padding
# Convolution layer
self.conv = Conv2D(
filters=self.num_channels, kernel_size=self.kernel_size,
strides=self.strides, padding=self.padding
)
# Batch Normalization layer
self.bn = BatchNormalization()
def call(self, input_tensor, training=False):
# Forward pass.
x = self.conv(input_tensor)
x = self.bn(x, training=training)
x = tf.nn.relu(x)
return x
def get_config(self):
config = super(ConvModule, self).get_config().copy()
config.update({
'num_channels': self.num_channels,
'kernel_size': self.kernel_size,
'strides': self.strides,
'padding': self.padding
})
return config
def build_graph(self, input_shape):
# Initialize custom layer
x = Input(shape=input_shape)
return Model(inputs=[x], outputs=self.call(x))
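With the imports above, you can sanity-check the module by unrolling it into a small functional model via build_graph (the input shape below is just an example):
# Build a tiny functional model around ConvModule to inspect its output shape.
cm = ConvModule(num_channels=32, kernel_size=(3, 3), strides=(1, 1))
cm.build_graph((32, 32, 3)).summary()  # output shape: (None, 32, 32, 32)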
Inception Module
class InceptionModule(Layer):
"""
Inception Module contains two convolution modules with 1x1 and 3x3
kernel sizes and merges their outputs.
Arguments:
kernel_size1x1: number of channels of the conv module with the 1x1 kernel
kernel_size3x3: number of channels of the conv module with the 3x3 kernel
"""
def __init__(self, kernel_size1x1, kernel_size3x3, **kwargs):
# Constructor method
super(InceptionModule, self).__init__(**kwargs)
self.kernel_size1x1 = kernel_size1x1
self.kernel_size3x3 = kernel_size3x3
# Two Conv Module, they will take same input tensor
self.conv1 = ConvModule(self.kernel_size1x1, kernel_size=(1, 1), strides=(1, 1))
self.conv2 = ConvModule(self.kernel_size3x3, kernel_size=(3, 3), strides=(1, 1))
self.concat = keras.layers.Concatenate()
def call(self, input_tensor, training=False):
# Forward pass
x_1x1 = self.conv1(input_tensor)
x_3x3 = self.conv2(input_tensor)
x = self.concat([x_1x1, x_3x3])
return x
def get_config(self):
# Take base config
config = super(InceptionModule, self).get_config().copy()
# update by subclass params
config.update({
'kernel_size1x1': self.kernel_size1x1,
'kernel_size3x3': self.kernel_size3x3
})
return config
def build_graph(self, input_shape):
# Initialize layer
x = Input(shape=input_shape)
return Model(inputs=[x], outputs=self.call(x))
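A quick shape check (illustrative) shows that the module keeps the spatial size and that the output channel count is the sum of the two branches:
# Both branches use stride 1 and 'same' padding, so only the channel count changes.
inc = InceptionModule(32, 48)
y = inc(tf.random.normal((1, 32, 32, 96)))
print(y.shape)  # (1, 32, 32, 80): 32 channels from the 1x1 branch + 48 from the 3x3 branch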
DownSampleModule
class DownSampleModule(Layer):
"""DownSample module includes a ConvModule with 3x3 kernel and
a MaxPooling layer with 3x3 pool size and 2x2 strides.
"""
def __init__(self, num_channels, **kwargs):
super(DownSampleModule, self).__init__(**kwargs)
self.channels = num_channels
# Convolution layer
self.conv = ConvModule(self.channels, kernel_size=(3, 3),
strides=(2, 2), padding='valid')
# Max pooling layer
self.pool = MaxPooling2D(pool_size=(3, 3), strides=(2, 2))
# Concatenation layer
self.concat = keras.layers.Concatenate()
def call(self, input_tensor, training=False):
# Forward pass
conv_y = self.conv(input_tensor, training=training)
pool_y = self.pool(input_tensor)
return self.concat([conv_y, pool_y])
def get_config(self):
config = super(DownSampleModule, self).get_config().copy()
config.update({'num_channels': self.channels})
return config
def build_graph(self, input_shape):
x = Input(shape=input_shape)
return Model(inputs=[x], outputs=self.call(x))
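A similar check (illustrative) shows why the two branches can be concatenated: with a 3x3 kernel, 2x2 strides and 'valid' padding, the convolution branch and the max-pooling branch produce feature maps of the same spatial size:
# Both branches reduce 32x32 to 15x15, so concatenation along the channel axis works.
ds = DownSampleModule(80)
y = ds(tf.random.normal((1, 32, 32, 80)), training=False)
print(y.shape)  # (1, 15, 15, 160): 80 conv channels + 80 pooled input channels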
Mini Inception model architecture
class MiniInception(Model):
"""Contains 8 Inception Modules, 2 DownSample Modules and 1 Dense layer."""
def __init__(self, num_classes=10, **kwargs):
"""Constructor"""
super(MiniInception, self).__init__(**kwargs)
self.num_classes = num_classes
# the first conv module
self.conv_block = ConvModule(96, (3, 3), (1, 1))
self.inception_1 = InceptionModule(32, 32)
self.inception_2 = InceptionModule(32, 48)
self.downsample_1 = DownSampleModule(80)
self.inception_3 = InceptionModule(112, 48)
self.inception_4 = InceptionModule(96, 64)
self.inception_5 = InceptionModule(80, 80)
self.inception_6 = InceptionModule(48, 96)
self.downsample_2 = DownSampleModule(96)
self.inception_7 = InceptionModule(176, 160)
self.inception_8 = InceptionModule(176, 160)
self.avg_pool = AveragePooling2D((7, 7))
self.flat = Flatten()
self.classifier = Dense(num_classes, activation='softmax')
def call(self, input_tensor, training=False):
# Forward pass
x = self.conv_block(input_tensor)
x = self.inception_1(x)
x = self.inception_2(x)
x = self.downsample_1(x)
x = self.inception_3(x)
x = self.inception_4(x)
x = self.inception_5(x)
x = self.inception_6(x)
x = self.downsample_2(x)
x = self.inception_7(x)
x = self.inception_8(x)
x = self.avg_pool(x)
x = self.flat(x)
return self.classifier(x)
def build_graph(self, input_shape):
x = Input(shape=input_shape)
return Model(inputs=[x], outputs=self.call(x))
def get_config(self):
config = super().get_config().copy()
config.update({
'num_classes': self.num_classes
})
return config
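To check the full architecture (assuming the classes defined above), you can instantiate the model and print the graph built for CIFAR-10-sized inputs:
# Unroll the subclassed model into a functional graph for a readable summary.
MiniInception(num_classes=10).build_graph((32, 32, 3)).summary()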
Load and pre-process CIFAR-10 dataset
import matplotlib.pyplot as plt
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
print(x_train.shape, y_train.shape)
print(x_test.shape, y_test.shape)
x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.
# target class name
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
'dog', 'frog', 'horse', 'ship', 'truck']
plt.figure(figsize=(10, 10))
for i in range(25):
plt.subplot(5, 5, i+1)
plt.imshow(x_train[i])
plt.title(class_names[y_train[i][0]])
plt.axis('off')
Encode the labels to one-hot representations and prepare datasets
We have an input shape of 32x32x3 and a total of 10 classes to classify.
# One-hot representation of the output
y_train = keras.utils.to_categorical(y_train, num_classes=10)
y_test = keras.utils.to_categorical(y_test, num_classes=10)
batch_size = 64
# Prepare datasets
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(buffer_size=1024).batch(batch_size)
test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(batch_size)
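For the training code in the next sections we also need a model instance together with a loss function, an optimizer and accuracy metrics. The exact choices below are illustrative (Adam with its default learning rate; the learning rate scheduler defined later overrides it), but the names match those used in the following snippets:
# Model, loss, optimizer and metrics used by the custom training loop below.
model = MiniInception(num_classes=10)
loss_object = keras.losses.CategoricalCrossentropy()
optimizer = keras.optimizers.Adam()
train_metric = keras.metrics.CategoricalAccuracy()
test_metric = keras.metrics.CategoricalAccuracy()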
Training mechanism
There are two ways to train and evaluate networks in tf.keras. The convenient way is to train and evaluate with the fit and evaluate methods; the other way is to take low-level control over the training and evaluation processes. In the latter case you have to implement custom training and evaluation loops from scratch. Here are the steps:
Inside the train step, as its name suggests, we implement the training algorithm for each batch of the dataset, i.e. it is the basis of our training function. The GradientTape API provides automatic and convenient differentiation in TensorFlow: it lets you compute the gradients of some outputs with respect to some inputs effortlessly. Inside the GradientTape() scope the outputs (essentially the model output and the loss) are computed, and outside the scope we retrieve the gradient of the loss with respect to the model's trainable weights. This is what we do in our train_step function. For more information you can see this page. Here are the train_step and test_step functions:
@tf.function
def train_step(x, y, model, loss_object, optimizer, train_metric):
"""train step for each dataset batch.
input: x, y typically batches
step: batch step
return: loss value for each batch
"""
# start the scope of gradient
with tf.GradientTape() as tape:
logits = model(x, training=True) # forward pass
loss_value = loss_object(y, logits) # compute loss
# Compute gradients
grads = tape.gradient(loss_value, model.trainable_weights)
# update weights
optimizer.apply_gradients(zip(grads, model.trainable_weights))
# update metric
if train_metric is not None:
train_metric.update_state(y, logits)
return loss_value
@tf.function
def test_step(x, y, model, loss_fn, test_metric):
"""Evaluate the model for each batch step
inputs: x, y typically batches
step: batch step
model: the NN model
loss_fn: the loss function
test_metric: test metric e.g. Accuracy
writer: tensorboard file writer
return: loss value
"""
val_logits = model(x, training=False)
# Compute loss value
loss_value = loss_fn(y, val_logits)
# update metrics
if test_metric is not None:
test_metric.update_state(y, val_logits)
return loss_value
Eager execution is turned on by default in TensorFlow 2. tf.function builds a Python-independent dataflow graph out of Python code, which helps you create performant models, but it has side effects on tf.Variable() too: it only supports singleton tf.Variables. When you define your train_step function with tf.function, take care not to call it with two or more different optimizers or networks, because the optimizer creates tf.Variables inside the tf.function and you will get an error. On this page you can find detailed information about tf.function.
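As a rough sketch of the pitfall (illustrative, not code from the article): reusing one traced train step with a second optimizer makes that optimizer create its slot variables inside the already-traced function, which tf.function does not allow.
@tf.function
def minimal_step(x, y, model, optimizer):
    # Minimal train step used only to illustrate the tf.Variable restriction.
    with tf.GradientTape() as tape:
        loss = tf.reduce_mean(tf.square(model(x, training=True) - y))
    grads = tape.gradient(loss, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
    return loss

opt_a = keras.optimizers.Adam()
opt_b = keras.optimizers.Adam()
# First call: opt_a creates its slot variables while the function is traced -- this is fine.
# minimal_step(x_batch, y_batch, model, opt_a)
# Calling the same tf.function with a different optimizer forces a retrace in which
# new tf.Variables are created, and TensorFlow raises a ValueError.
# minimal_step(x_batch, y_batch, model, opt_b)  # would fail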
In the next step you can define a checkpoint and callbacks for your training algorithm if needed; this step is optional.
Checkpoint
Tip: notice that we define the checkpoint with the tf.train.Checkpoint class, not with tf.keras.callbacks.ModelCheckpoint, since we are not training with the fit() method.
# Checkpoint
ckpt = tf.train.Checkpoint(step=tf.Variable(1), net=model,
optimizer=optimizer,
iterator=iter(train_dataset))
# Checkpoint manager
manager = tf.train.CheckpointManager(ckpt, './tf_ckpt', max_to_keep=1)
We use the following code snippet to check whether there is a checkpoint to restore before iterating over the epochs.
# restore the latest checkpoint, if any
checkpoint.restore(ckpt_manager.latest_checkpoint)
if ckpt_manager.latest_checkpoint:
print("Loaded from {}".format(ckpt_manager.latest_checkpoint))
else:
print("Initialize from scratch.")
If you would like to define a custom callback for your training process, you must define a subclass of the tf.keras.callbacks.Callback base class. Like any other class, the Callback class has an __init__() constructor; in addition, it offers a number of methods that you can override as needed. Some of the Callback methods are:
In the following we implement two custom callbacks: a custom learning rate scheduler and custom early stopping.
class CustomLRScheduler(Callback):
"""Learning rate scheduler which sets the learning rate according to a schedule.
Arguments:
schedule: a function that takes an epoch index and the current learning rate
as inputs and returns a new learning rate as output.
opt: the optimizer, because we cannot reach self.model.optimizer
when training the model with a custom training loop.
"""
def __init__(self, schedule, opt=None):
super(CustomLRScheduler, self).__init__()
self.schedule = schedule
self.optimizer = opt
def on_epoch_begin(self, epoch, logs=None):
if not hasattr(self.optimizer, 'lr'):
raise ValueError("Optimizer must have a `lr` attribute. ")
lr = float(keras.backend.get_value(self.optimizer.lr))
scheduled_lr = self.schedule(epoch, lr)
keras.backend.set_value(self.optimizer.lr, scheduled_lr)
print("\nEpoch: %05d learning rate is %6.4f" % (epoch, scheduled_lr))
LR_SCHEDULE = [
(1, 0.01),
(10, 0.005),
(15, 0.001),
(20, 5e-4),
]
def custom_schedule(epoch, lr):
"""Helper function to return scheduled learning rate based on epoch."""
for i in range(len(LR_SCHEDULE)):
if epoch == LR_SCHEDULE[i][0]:
return LR_SCHEDULE[i][1]
return lr
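A quick check of the schedule helper (illustrative):
# custom_schedule keeps the current rate unless the epoch hits a boundary in LR_SCHEDULE.
lr = 0.01
for epoch in range(1, 21):
    lr = custom_schedule(epoch, lr)
print(lr)  # 0.0005 -- the rate set at epoch 20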
class CustomEarlyStopping(Callback):
"""Stop training when monitor is at best.
Arguments:
model: the model under training, because self.model is not available
in a custom training loop
patience: the number of epochs to wait after the monitored value stops improving
monitor: the quantity to be monitored
"""
def __init__(self, model, patience=0, monitor='loss'):
super(CustomEarlyStopping, self).__init__()
self.model = model
self.patience = patience
# best weights to store the weights at minimum loss score
self.best_weights = None
# monitor the parameter to stop early
self.monitor = monitor
def on_train_begin(self, logs=None):
# the number of epoch it has waited when loss is no longer minimum
self.wait = 0
# the epoch the training stops at
self.stopped_epoch = 0
if 'loss' in self.monitor:
# Initialize the best as infinity
self.best = np.Inf
# set monitor operation
self.monitor_op = np.less
else:
self.best = -np.Inf
self.monitor_op = np.greater
def on_epoch_end(self, epoch, logs=None):
current = logs.get(self.monitor)
if self.monitor_op(current, self.best):
self.best = current
self.wait = 0
# Record the best weights if the current result is better
self.best_weights = self.model.get_weights()
else:
self.wait += 1
if self.wait >= self.patience:
self.stopped_epoch = epoch
self.model.stop_training = True
print('Restore model weights from the end of the best epoch.')
self.model.set_weights(self.best_weights)
def on_train_end(self, logs=None):
if self.stopped_epoch > 0:
print("Epoch %05d early stopping" % (self.stopped_epoch))
# Create callback objects
schedule_learning_rate = CustomLRScheduler(custom_schedule, optimizer)
stop = CustomEarlyStopping(model, patience=3, monitor='val_loss')
You can create a callback list with the code below and then invoke each callback method on all callbacks with a single call.
_callbacks = [schedule_learning_rate, stop]
callbacks = keras.callbacks.CallbackList(_callbacks, add_history=True, model=model)
The important point here is where to call which method. The method names tell you: on_train_begin must be called before training starts, in other words just before iterating over the epochs, and on_train_end must be called after the epoch loop terminates. on_epoch_begin should be called at the beginning of each epoch, right after entering the epoch loop, and so on. The code snippet below demonstrates this.
callbacks = tf.keras.callbacks.CallbackList([...])
callbacks.append(...)
callbacks.on_train_begin(...)
for epoch in range(EPOCHS):
callbacks.on_epoch_begin(epoch)
for i, data in dataset.enumerate():
callbacks.on_train_batch_begin(i)
batch_logs = model.train_step(data)
callbacks.on_train_batch_end(i, batch_logs)
epoch_logs = ...
callbacks.on_epoch_end(epoch, epoch_logs)
final_logs=...
callbacks.on_train_end(final_logs)
That's our main train function, implemented from scratch. You can see it all together as follows:
def train(train_dataset, max_epoch, model, loss_object,
optimizer, validation_dataset=None, train_metric=None,
test_metric=None, **kwargs):
"""Custom training loop from scratch."""
# Initialize
best = np.Inf
save_best_only = False
save_freq = None
# define train and validation loss mean
train_loss_mean = keras.metrics.Mean()
val_loss_mean = keras.metrics.Mean()
# print a '=' progress marker roughly every progress_step batches
if len(train_dataset) > 30:
progress_step = np.ceil(len(train_dataset) / 30)
else:
progress_step = 1
# Check if model checkpoint exists
if 'model_checkpoint' in kwargs:
model_checkpoint = kwargs.pop('model_checkpoint')
if 'checkpoint' in model_checkpoint:
checkpoint = model_checkpoint.pop('checkpoint')
else:
checkpoint = tf.train.Checkpoint(step=tf.Variable(1), net=model, optimizer=optimizer, iterator=iter(train_dataset))
if 'manager' in model_checkpoint:
ckpt_manager = model_checkpoint.pop('manager')
else:
ckpt_manager = tf.train.CheckpointManager(checkpoint, './tf_ckpt', max_to_keep=3)
if 'monitor' in model_checkpoint:
if 'loss' in model_checkpoint['monitor']:
best *= 1
else:
best *= -1
else:
model_checkpoint['monitor'] = 'val_loss'
if not 'mode' in model_checkpoint or model_checkpoint['mode'] == 'auto':
if 'loss' in model_checkpoint['monitor']:
model_checkpoint.update({'mode': 'min', 'monitor_op': np.less})
else:
model_checkpoint.update({'mode': 'max', 'monitor_op': np.greater})
if not 'save_best_only' in model_checkpoint:
model_checkpoint.update({'save_best_only': False})
save_best_only = model_checkpoint.pop('save_best_only')
if not 'save_freq' in model_checkpoint:
model_checkpoint['save_freq'] = 'epoch'
save_freq = model_checkpoint.get('save_freq')
# restore latest checkpoint if any
checkpoint.restore(ckpt_manager.latest_checkpoint)
if ckpt_manager.latest_checkpoint:
print("Loaded from {}".format(ckpt_manager.latest_checkpoint))
else:
print("Initialize from scratch.")
# check if TensorBoard file writers exist
if 'train_writer' in kwargs:
train_writer = kwargs.pop('train_writer')
else:
train_writer = None
if 'test_writer' in kwargs:
test_writer = kwargs.pop('test_writer')
else:
test_writer = None
if 'callbacks' in kwargs:
_callbacks = kwargs.pop('callbacks')
callbacks = keras.callbacks.CallbackList(_callbacks, add_history=True, model=model)
else:
callbacks = keras.callbacks.CallbackList([])
# Call callbacks when train begins
log = {}
callbacks.on_train_begin(logs=log)
# Main train loop
for epoch in range(1, max_epoch+1):
# Call callbacks when epoch begin
callbacks.on_epoch_begin(epoch, logs=log)
start = time.time()
print(f"Epoch:{epoch} [>", end='')
# Iterate over the batches of the train dataset
for step, (x_train_batch, y_train_batch) in enumerate(train_dataset):
# step = tf.convert_to_tensor(step, dtype=tf.int64)
train_loss_value = train_step(x_train_batch, y_train_batch,
model, loss_object, optimizer,
train_metric)
train_loss_mean.update_state(train_loss_value)
if save_freq is not None and isinstance(save_freq, int):
checkpoint.step.assign_add(1)
if int(checkpoint.step) % save_freq == 0:
save_path = ckpt_manager.save()
print("Saved checkpoint at step {}: {}".format(checkpoint.step, save_path))
elif (step + 1) % progress_step == 0:
print('\b=>', end='')
# write train loss and accuracy to the tensorboard
if train_writer is not None:
with train_writer.as_default():
tf.summary.scalar('loss', train_loss_mean.result(), step=epoch)
if train_metric is not None:
tf.summary.scalar('accuracy', train_metric.result(), step=epoch)
# evaluation
for x_val_batch, y_val_batch in validation_dataset:
val_loss_value = test_step(x_val_batch, y_val_batch, model,
loss_object, test_metric)
val_loss_mean(val_loss_value)
# write test loss and accuracy to the tensorboard
if test_writer is not None:
with test_writer.as_default():
tf.summary.scalar('val_loss', val_loss_mean.result(), step=epoch)
if test_metric is not None:
tf.summary.scalar('val_accuracy', test_metric.result(), step=epoch)
template = "\b] ETA: {:.2f}, loss: {:.4f}, acc: {:.4f} | val loss: {:.4f}, val acc: {:.4f}"
print(template.format(
(time.time() - start)/60, train_loss_mean.result(),
train_metric.result(), val_loss_mean.result(),
test_metric.result()
))
if model_checkpoint['monitor'] == 'loss':
current = train_loss_mean.result()
elif model_checkpoint['monitor'] == 'val_loss':
current = val_loss_mean.result()
elif model_checkpoint['monitor'] == 'acc' or model_checkpoint['monitor'] == 'accuracy':
current = train_metric.result()
elif model_checkpoint['monitor'] == 'val_acc' or model_checkpoint['monitor'] == 'val_accuracy':
current = test_metric.result()
if save_freq is not None and save_freq == 'epoch':
checkpoint.step.assign_add(1)
# save just best model
if save_best_only:
if model_checkpoint.get('monitor_op')(current, best):
save_path = ckpt_manager.save()
print("Improved {} from {} to {}".format(model_checkpoint.get('monitor'), best, current))
print("\t\tSaved checkpoint: {}".format(checkpoint.step, save_path))
best = current
# model.save("best-{}".format(model.__class__.__name__),
# save_format='tf')
# best_loss = val_loss_mean.result()
# print("model saved")
else:
save_path = ckpt_manager.save()
print("Saved checkpoint: {}\n".format(save_path))
# Call callbacks when epoch ends
log = dict(
loss=train_loss_mean.result(),
val_loss=val_loss_mean.result(),
acc=train_metric.result(),
val_acc=test_metric.result()
)
callbacks.on_epoch_end(epoch, logs=log)
# Reset metrics at the end of each epoch
train_metric.reset_states()
test_metric.reset_states()
train_loss_mean.reset_states()
val_loss_mean.reset_states()
# Call callbacks when train end
callbacks.on_train_end(logs={})
Finally, call the train function with the defined checkpoint and callbacks:
model_checkpoint = {
'checkpoint': ckpt, 'manager': manager, 'monitor': 'val_loss',
'mode': 'min', 'save_freq': 'epoch', 'save_best_only': False
}
epochs = 30
train(train_dataset, epochs, model, loss_object, optimizer, test_dataset, train_metric,
test_metric, train_writer=train_writer, test_writer=test_writer,
model_checkpoint=model_checkpoint,
callbacks=[schedule_learning_rate, stop]
)
If you rerun the train function, it will start from the last saved checkpoint instead of from scratch.
We define file writers using tf.summary in order to save the loss and accuracy values and display them later with TensorBoard.
# TensorBoard file writers
train_writer = tf.summary.create_file_writer('logs/train/')
test_writer = tf.summary.create_file_writer('logs/test/')
We save the loss and accuracy values for training and evaluation with the code snippet below:
for step, (x_train_batch, y_train_batch) in enumerate(train_dataset):
# step = tf.convert_to_tensor(step, dtype=tf.int64)
train_loss_value = train_step(x_train_batch, y_train_batch,
model, loss_object, optimizer,
train_metric)
train_loss_mean.update_state(train_loss_value)
# write train loss and accuracy to the tensorboard
with train_writer.as_default():
tf.summary.scalar('loss', train_loss_mean.result(), step=epoch)
tf.summary.scalar('accuracy', train_metric.result(), step=epoch)
for x_val_batch, y_val_batch in test_dataset:
val_loss_value = test_step(x_val_batch, y_val_batch, model,
loss_object, test_metric)
val_loss_mean(val_loss_value)
# write test loss and accuracy to the tensorboard
with test_writer.as_default():
tf.summary.scalar('val_loss', val_loss_mean.result(), step=epoch)
tf.summary.scalar('val_accuracy', test_metric.result(), step=epoch)
Afterwards, we can use Tensorboard to display loss and accuracy values during training and validation process.
%load_ext tensorboard
%tensorboard --logdir logs
Last but not least, if your model is implemented with the sub-classing API you cannot save the whole model, neither in SavedModel nor in HDF5 format. You can save the entire model only if it is defined with the Sequential or Functional API; otherwise we can save just the model weights, and therefore we can reload the model only if the code that defines it is available.
# save just model weights
model.save_weights(f"./model/{model.__class__.__name__}-weights")
Reloading a model designed with the sub-classing API is more complicated. The solution comes from a great article by François Chollet on saving and serializing models in tf.keras, which is a must-read. It has three stages, as follows:
# Restore the model, i.e. the model weights
loaded = MiniInception()
loaded.compile(
loss=keras.losses.CategoricalCrossentropy(),
optimizer=keras.optimizers.Adam(),
metrics=[keras.metrics.CategoricalAccuracy()]
)
x_batch, y_batch = next(iter(train_dataset))
print(x_batch.shape, y_batch.shape)
loaded.train_on_batch(x_batch, y_batch)
loaded.load_weights("./model/{}-weights".format(loaded.__class__.__name__))
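As a quick check (illustrative), you can evaluate the restored model on the prepared test dataset:
# The test dataset already holds one-hot labels, so evaluate() can be used directly.
results = loaded.evaluate(test_dataset)
print("Restored model [loss, accuracy]:", results)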
Conclusion
In this tutorial we explained how to implement a sub-classed model and train it with a low-level custom training loop. Moreover, we illustrated how to implement custom checkpoints and callbacks for such a training mechanism. We tried to put all the custom pieces together and to save and reload this fully customized model, excluding the loss function. To customize the loss you can define a custom loss function or a sub-class extending the keras.losses.Loss class; I hope to cover that in future parts. Keep in mind that you should never try to save a whole model that is defined as a sub-class: Sequential and Functional models are data structures, but a sub-classed model is a piece of code, and to reload it you need the code that created it.
REFERENCES:
https://towardsdatascience.com/model-sub-classing-and-custom-training-loop-from-scratch-in-tensorflow-2-cc1d4f10fb4
https://www.tensorflow.org/guide/checkpoint
https://www.tensorflow.org/guide/keras/custom_callback
https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/Callback
https://medium.com/smart-iot/custom-training-with-custom-callbacks-3bcd117a8f7e
https://colab.research.google.com/drive/172D4jishSgE3N7AO6U2OKAA_0wNnrMOq#scrollTo=KOKNBojtsl0F