Custom training a sub-classed model with custom callbacks and checkpoints using Tensorflow 2.0

In this tutorial, we discuss how to create a sub-classed model by extending the keras.Model class and how to train it with a low-level custom training loop. We also demonstrate how to define custom callbacks and checkpoints, how to apply them to the custom training loop, and finally how to save and restore the model. Training with model.compile() and model.fit() is straightforward; a custom training loop is more involved, which is why we decided to write this article. Here are the links to the source code on Colab and GitHub:

Colab code | GitHub code

Layers

First, we implemented the custom layers of the network. Our layers differ slightly from those described in the original paper: this network was designed for tutorial purposes only, so feel free to change the layers and the architecture as you wish. The network's custom layers are as follows:

  • ConvModule: a Convolution layer, a Batch Normalization layer and a ReLU activation.
  • InceptionModule: two ConvModules with 1x1 and 3x3 kernels applied to the same input, followed by a Concatenate layer.
  • DownSampleModule: a ConvModule and a MaxPooling layer, both with strides of 2, whose outputs are concatenated.

All of the modules above inherit from the keras.layers.Layer class. Their code is as follows:

ConvModule

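import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import (Layer, Conv2D, BatchNormalization,
                                     MaxPooling2D, AveragePooling2D,
                                     Flatten, Dense)


class ConvModule(Layer):
    """Convolution Module: Conv2D -> BatchNormalization -> ReLU.
    Input Args:
        num_channels: number of convolution layer channels (filters).
        kernel_size: the Convolution layer kernel size.
        strides: the step size the convolution filter moves.
        padding: padding type (e.g. `valid`, `same`).
    """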
    def __init__(self, num_channels, kernel_size, strides, padding='same', **kwargs):
        """Constructor method."""
        super(ConvModule, self).__init__(**kwargs)
        self.num_channels = num_channels
        self.kernel_size = kernel_size
        self.strides = strides
        self.padding = padding
        # Convolution layer
        self.conv = Conv2D(
            filters=self.num_channels, kernel_size=self.kernel_size, 
            strides=self.strides, padding=self.padding
        )
        # Batch Normalization layer
        self.bn = BatchNormalization()


    def call(self, input_tensor, training=False):
        # Forward pass.
        x = self.conv(input_tensor)
        x = self.bn(x, training=training)
        x = tf.nn.relu(x)
        return x
    
    def get_config(self):
        config = super(ConvModule, self).get_config().copy()
        config.update({
            'num_channels': self.num_channels, 
            'kernel_size': self.kernel_size,
            'strides': self.strides,
            'padding': self.padding
        })
        return config


    def build_graph(self, input_shape):
        # Initialize custom layer
        x = Input(shape=input_shape)
        return Model(inputs=[x], outputs=self.call(x))        

Inception Module

class InceptionModule(Layer):
    """
    Inception Module: two convolution modules with 1x1 and 3x3
    kernels applied to the same input, with their outputs concatenated.
    Arguments:
        kernel_size1x1: number of channels of the conv module with the 1x1 kernel
        kernel_size3x3: number of channels of the conv module with the 3x3 kernel
    """
    def __init__(self, kernel_size1x1, kernel_size3x3, **kwargs):
        # Constructor method
        super(InceptionModule, self).__init__(**kwargs)
        self.kernel_size1x1 = kernel_size1x1
        self.kernel_size3x3 = kernel_size3x3


        # Two ConvModules that take the same input tensor
        self.conv1 = ConvModule(self.kernel_size1x1, kernel_size=(1, 1), strides=(1, 1))
        self.conv2 = ConvModule(self.kernel_size3x3, kernel_size=(3, 3), strides=(1, 1))
        self.concat = keras.layers.Concatenate()


    def call(self, input_tensor, training=False):
        # Forward pass; pass `training` on so BatchNormalization behaves correctly
        x_1x1 = self.conv1(input_tensor, training=training)
        x_3x3 = self.conv2(input_tensor, training=training)
        x = self.concat([x_1x1, x_3x3])
        return x


    def get_config(self):
        # Take base config 
        config = super(InceptionModule, self).get_config().copy()
        # update by subclass params
        config.update({
            'kernel_size1x1': self.kernel_size1x1,
            'kernel_size3x3': self.kernel_size3x3
        })
        return config
    
    def build_graph(self, input_shape):
        # Initialize layer
        x = Input(shape=input_shape)
        return Model(inputs=[x], outputs=self.call(x))        

DownSampleModule


class DownSampleModule(Layer):
    """DownSample module: a ConvModule with a 3x3 kernel and 2x2 strides and
    a MaxPooling layer with 3x3 pool size and 2x2 strides, both applied to
    the same input, with their outputs concatenated.
    """
    
    def __init__(self, num_channels, **kwargs):
        super(DownSampleModule, self).__init__(**kwargs)
        self.channels = num_channels
        # Convolution layer
        self.conv = ConvModule(self.channels, kernel_size=(3, 3), 
                               strides=(2, 2), padding='valid')
        # Max pooling layer
        self.pool = MaxPooling2D(pool_size=(3, 3), strides=(2, 2))
        # Concatenation layer
        self.concat = keras.layers.Concatenate()


    def call(self, input_tensor, training=False):
        # Forward pass
        conv_y = self.conv(input_tensor, training=training)
        pool_y = self.pool(input_tensor)


        return self.concat([conv_y, pool_y])


    def get_config(self):
        config = super(DownSampleModule, self).get_config().copy()
        config.update({'num_channels': self.channels})
        return config


    def build_graph(self, input_shape):
        x = Input(shape=input_shape)
        return Model(inputs=[x], outputs=self.call(x))        

Mini Inception model architecture


class MiniInception(Model):
    """Contains a ConvModule, 8 Inception modules, 2 DownSample modules and 1 Dense layer."""
    def __init__(self, num_classes=10, **kwargs):
        """Constructor"""
        super(MiniInception, self).__init__(**kwargs)


        self.num_classes = num_classes
        # the first conv module
        self.conv_block = ConvModule(96, (3, 3), (1, 1))


        self.inception_1 = InceptionModule(32, 32)
        self.inception_2 = InceptionModule(32, 48)
        self.downsample_1 = DownSampleModule(80)


        self.inception_3 = InceptionModule(112, 48)
        self.inception_4 = InceptionModule(96, 64)
        self.inception_5 = InceptionModule(80, 80)
        self.inception_6 = InceptionModule(48, 96)
        self.downsample_2 = DownSampleModule(96)


        self.inception_7 = InceptionModule(176, 160)
        self.inception_8 = InceptionModule(176, 160)


        self.avg_pool = AveragePooling2D((7, 7))


        self.flat = Flatten()
        self.classifier = Dense(num_classes, activation='softmax')


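    def call(self, input_tensor, training=False):
        # Forward pass; `training` is passed on so that the BatchNormalization
        # layers inside every module use batch statistics during training
        x = self.conv_block(input_tensor, training=training)

        x = self.inception_1(x, training=training)
        x = self.inception_2(x, training=training)
        x = self.downsample_1(x, training=training)

        x = self.inception_3(x, training=training)
        x = self.inception_4(x, training=training)
        x = self.inception_5(x, training=training)
        x = self.inception_6(x, training=training)
        x = self.downsample_2(x, training=training)

        x = self.inception_7(x, training=training)
        x = self.inception_8(x, training=training)
        x = self.avg_pool(x)

        x = self.flat(x)
        return self.classifier(x)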


    def build_graph(self, input_shape):
        x = Input(shape=input_shape)
        return Model(inputs=[x], outputs=self.call(x))
    
    def get_config(self):
        config = super().get_config().copy()
        config.update({
            'num_classes': self.num_classes
        })
        return config        
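To see why AveragePooling2D((7, 7)) reduces the last feature map to a single position, the short pure-Python sketch below traces the spatial size and channel count of a 32x32x3 CIFAR-10 image through MiniInception. It is not part of the original code: valid_out, STAGES and trace_shapes are hypothetical helpers, and the sketch assumes the rules used above ('same' padding keeps the size in ConvModule and InceptionModule, 'valid' padding shrinks it in DownSampleModule, and a pooling layer's stride defaults to its pool size).

```python
# Hypothetical helpers, not part of the original code.

def valid_out(size, kernel, stride):
    """Output size of a 'valid' conv/pool: floor((size - kernel) / stride) + 1."""
    return (size - kernel) // stride + 1


# Same module order and arguments as MiniInception.__init__
STAGES = [
    ("inception", 32, 32), ("inception", 32, 48), ("down", 80),
    ("inception", 112, 48), ("inception", 96, 64),
    ("inception", 80, 80), ("inception", 48, 96), ("down", 96),
    ("inception", 176, 160), ("inception", 176, 160),
]


def trace_shapes(size=32, first_channels=96):
    """Return (height, width, channels) after conv_block, each stage and avg_pool."""
    channels = first_channels  # conv_block: 3x3 conv, stride 1, 'same' padding
    shapes = [(size, size, channels)]
    for kind, *params in STAGES:
        if kind == "inception":
            # 'same' padding keeps the size; the two branches are concatenated
            channels = params[0] + params[1]
        else:
            # conv (3x3, stride 2, 'valid') and max-pool (3x3, stride 2, 'valid')
            # shrink the size identically; conv channels + input channels
            size = valid_out(size, 3, 2)
            channels = params[0] + channels
        shapes.append((size, size, channels))
    size = valid_out(size, 7, 7)  # AveragePooling2D((7, 7))
    shapes.append((size, size, channels))
    return shapes


for shape in trace_shapes():
    print(shape)
```

The final 1x1x336 map is what Flatten hands to the 10-way Dense classifier.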

Load and pre-process CIFAR-10 dataset


import matplotlib.pyplot as plt
from tensorflow import keras


(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()


print(x_train.shape, y_train.shape)
print(x_test.shape, y_test.shape)


x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.


# target class name
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']


plt.figure(figsize=(10, 10))
for i in range(25):
    plt.subplot(5, 5, i+1)
    plt.imshow(x_train[i])
    plt.title(class_names[y_train[i][0]])
    plt.axis('off')        

Encode the labels to one-hot representations and prepare datasets

The input shape is 32x32x3, and there are 10 classes in total.

# One-hot representation of the output
y_train = keras.utils.to_categorical(y_train, num_classes=10)
y_test = keras.utils.to_categorical(y_test, num_classes=10)
batch_size = 64
# Prepare datasets
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(buffer_size=1024).batch(batch_size)
test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(batch_size)        
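The snippets that follow use model, loss_object, optimizer, train_metric and test_metric without showing where they are created. A minimal setup, assuming categorical cross-entropy, Adam and categorical accuracy (the same choices as the compile() call used when the model is reloaded at the end) and an assumed starting learning rate of 0.01, the first value of the learning-rate schedule defined later; the model itself would be created with MiniInception(num_classes=10).

```python
import tensorflow as tf
from tensorflow import keras

# Assumption: the article does not show these objects.
# model = MiniInception(num_classes=10)  # the sub-classed model defined above

loss_object = keras.losses.CategoricalCrossentropy()   # one-hot labels, softmax outputs
optimizer = keras.optimizers.Adam(learning_rate=0.01)  # assumed initial learning rate
train_metric = keras.metrics.CategoricalAccuracy()
test_metric = keras.metrics.CategoricalAccuracy()

# Sanity check: a uniform prediction over two classes costs ln(2) ~= 0.6931
y_true = tf.constant([[0.0, 1.0]])
y_pred = tf.constant([[0.5, 0.5]])
print(float(loss_object(y_true, y_pred)))
```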

Training mechanism

There are two ways to train and evaluate a network in tf.keras. The convenient way is to use the built-in fit() and evaluate() methods; the other is to take low-level control of the training and evaluation processes, which requires implementing custom training and evaluation loops from scratch. Here are the steps:

  1. iterate over the epochs with a for loop
  2. within each epoch, iterate over all batches of the dataset with another for loop
  3. call a training step for each training batch, then a validation step for each validation batch
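The three steps above can be illustrated without TensorFlow. The sketch below fits a toy linear model y = 3x + 2 with mini-batch gradient descent in NumPy; the data and hyperparameters are made up for illustration, and a hand-derived mean-squared-error gradient stands in for what GradientTape computes in the real train_step.

```python
import numpy as np

# Made-up toy data: y = 3x + 2 plus a little noise
rng = np.random.default_rng(0)
x = rng.uniform(-1.0, 1.0, size=(256, 1))
y = 3.0 * x + 2.0 + 0.01 * rng.normal(size=(256, 1))

w, b = 0.0, 0.0
learning_rate, batch_size, epochs = 0.1, 32, 50

for epoch in range(epochs):                       # 1. loop over epochs
    order = rng.permutation(len(x))               # reshuffle every epoch
    for start in range(0, len(x), batch_size):    # 2. loop over batches
        idx = order[start:start + batch_size]
        xb, yb = x[idx], y[idx]
        # 3. "train step": forward pass, MSE gradient, parameter update
        err = (w * xb + b) - yb
        grad_w = 2.0 * np.mean(err * xb)
        grad_b = 2.0 * np.mean(err)
        w -= learning_rate * grad_w
        b -= learning_rate * grad_b

print(round(w, 2), round(b, 2))
```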

The train step, as its name suggests, implements the training algorithm for a single batch of the dataset; it is the core of our training function. The GradientTape API provides automatic differentiation in TensorFlow: it lets you compute the gradient of a function's output with respect to its inputs with little effort. Operations executed inside the tf.GradientTape() context (here, the forward pass and the loss computation) are recorded, and outside the context we ask the tape for the gradient of the loss with respect to the model's trainable weights. This is what our train_step function does. For more information, see this page. Here are the train_step and test_step functions:
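A minimal GradientTape sketch on a single scalar variable, with made-up values: the loss (w * x - y)^2 is recorded inside the tape's scope, and its gradient with respect to w is requested outside it. The final assign_sub line is a hand-written plain SGD update, shown only to illustrate what optimizer.apply_gradients does for SGD.

```python
import tensorflow as tf

# Made-up values: one trainable scalar and one (x, y) pair
w = tf.Variable(2.0)
x = tf.constant(3.0)
y = tf.constant(9.0)

with tf.GradientTape() as tape:
    # Operations involving trainable variables are recorded here
    loss = (w * x - y) ** 2  # (2*3 - 9)^2 = 9

# Outside the scope: d(loss)/dw = 2 * (w*x - y) * x = 2 * (-3) * 3 = -18
grad = tape.gradient(loss, w)

# Plain SGD update by hand: w <- w - lr * grad = 2 - 0.01 * (-18) = 2.18
w.assign_sub(0.01 * grad)
print(float(grad), float(w))
```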

@tf.function
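def train_step(x, y, model, loss_object, optimizer, train_metric):
    """Train step for one dataset batch.
    x, y: a batch of inputs and one-hot labels
    model: the model being trained
    loss_object: the loss function
    optimizer: the optimizer that applies the gradients
    train_metric: training metric, e.g. accuracy (may be None)
    return: loss value for the batch
    """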


    # start the scope of gradient
    with tf.GradientTape() as tape:
        logits = model(x, training=True)    # forward pass
        loss_value = loss_object(y, logits)     # compute loss
    
    # Compute gradients
    grads = tape.gradient(loss_value, model.trainable_weights)
    # update weights
    optimizer.apply_gradients(zip(grads, model.trainable_weights))


    # update metric
    if train_metric is not None:
        train_metric.update_state(y, logits)
    
    return loss_value



@tf.function
def test_step(x, y, model, loss_fn, test_metric):
    """Evaluate the model on one batch.
    x, y: a batch of inputs and one-hot labels
    model: the NN model
    loss_fn: the loss function
    test_metric: test metric, e.g. accuracy (may be None)
    return: loss value
    """
    val_logits = model(x, training=False)


    # Compute loss value
    loss_value = loss_fn(y, val_logits)


    # update metrics
    if test_metric is not None:
        test_metric.update_state(y, val_logits)


    return loss_value
        

Eager execution is enabled by default in TensorFlow 2. tf.function turns Python code into a Python-independent dataflow graph, which helps you build performant models, but it also constrains how tf.Variable objects are created: a tf.function may only create variables on its first call. The optimizer creates its own variables (for example, Adam's moment estimates) the first time it applies gradients. So if you call the same tf.function-decorated train_step with two or more different optimizers or networks, new variables are created on a later call and TensorFlow raises an error such as "ValueError: tf.function only supports singleton tf.Variables created on the first call". You can find detailed information about tf.function on this page.
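A common way to avoid this error is to build one tf.function per model/optimizer pair, for example with a factory that closes over them. The sketch below assumes this pattern; make_train_step and tiny_model are hypothetical names, and tiny_model is a small Dense network standing in for MiniInception.

```python
import tensorflow as tf
from tensorflow import keras


def make_train_step(model, loss_fn, optimizer):
    """Build a separate tf.function for one (model, optimizer) pair (hypothetical helper)."""
    @tf.function
    def train_step(x, y):
        with tf.GradientTape() as tape:
            loss = loss_fn(y, model(x, training=True))
        grads = tape.gradient(loss, model.trainable_weights)
        # The optimizer's variables are created here, on this function's first call
        optimizer.apply_gradients(zip(grads, model.trainable_weights))
        return loss
    return train_step


def tiny_model():
    # Hypothetical stand-in for MiniInception, just to exercise the pattern
    return keras.Sequential([keras.Input(shape=(3,)), keras.layers.Dense(1)])


x = tf.random.normal((16, 3), seed=1)
y = tf.reduce_sum(x, axis=1, keepdims=True)
loss_fn = keras.losses.MeanSquaredError()

# Two models with two optimizers -> two independent tf.functions
step_a = make_train_step(tiny_model(), loss_fn, keras.optimizers.SGD(0.1))
step_b = make_train_step(tiny_model(), loss_fn, keras.optimizers.SGD(0.1))

losses_a = [float(step_a(x, y)) for _ in range(20)]
losses_b = [float(step_b(x, y)) for _ in range(20)]
print(losses_a[0], losses_a[-1])
```

Because the model and optimizer are captured from the enclosing scope instead of being passed as arguments, each returned function traces once and creates its optimizer variables only on its own first call.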


Next, if needed, you can define a checkpoint and callbacks for your training algorithm; this step is optional.

Checkpoint

Tip: note that we use the tf.train.Checkpoint class rather than tf.keras.callbacks.ModelCheckpoint, because we do not train with the fit() method.

# Checkpoint
ckpt = tf.train.Checkpoint(step=tf.Variable(1), net=model, 
                           optimizer=optimizer, 
                           iterator=iter(train_dataset))
# Checkpoint manager
manager = tf.train.CheckpointManager(ckpt, './tf_ckpt', max_to_keep=1)
        

We use the following code snippet to check whether there is a checkpoint to restore before iterating over the epochs.

# restore the latest checkpoint, if any
ckpt.restore(manager.latest_checkpoint)
if manager.latest_checkpoint:
    print("Loaded from {}".format(manager.latest_checkpoint))
else:
    print("Initialize from scratch.")
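The save/restore cycle can be tried in isolation. In this sketch, assuming a temporary directory and two plain tf.Variable objects standing in for the model and optimizer, a checkpoint is saved once and then restored into freshly created objects, as happens when the training script is rerun.

```python
import tempfile

import tensorflow as tf

ckpt_dir = tempfile.mkdtemp()  # assumption: throwaway directory instead of './tf_ckpt'

# First run: a step counter plus a variable standing in for the model weights
step = tf.Variable(1)
weight = tf.Variable(0.0)
ckpt = tf.train.Checkpoint(step=step, weight=weight)
manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=1)

weight.assign(42.0)
step.assign_add(1)
save_path = manager.save()
print("Saved:", save_path)

# Second run: fresh objects with the same structure (same attribute names)
new_step = tf.Variable(1)
new_weight = tf.Variable(0.0)
new_ckpt = tf.train.Checkpoint(step=new_step, weight=new_weight)
new_manager = tf.train.CheckpointManager(new_ckpt, ckpt_dir, max_to_keep=1)

# With no checkpoint, latest_checkpoint is None and the variables keep their initial values
new_ckpt.restore(new_manager.latest_checkpoint)
if new_manager.latest_checkpoint:
    print("Loaded from {}".format(new_manager.latest_checkpoint))
else:
    print("Initialize from scratch.")
print(int(new_step), float(new_weight))
```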

To define a custom callback for your training process, subclass the base class tf.keras.callbacks.Callback. Like any other class, Callback has an __init__() constructor; in addition, it has many hook methods that you can override as needed. Some of the Callback methods are:

  • on_train_begin(self, logs=None)
  • on_train_end(self, logs=None)
  • on_epoch_begin(self, epoch, logs=None)
  • on_epoch_end(self, epoch, logs=None)
  • on_train_batch_begin(self, batch, logs=None)
  • ...

Below, we implement two custom callbacks: a custom learning-rate scheduler and a custom early-stopping callback.


import numpy as np
from tensorflow import keras
from tensorflow.keras.callbacks import Callback


class CustomLRScheduler(Callback):
    """Learning rate scheduler which sets the learning rate according to a schedule.
    Arguments:
        schedule: a function that takes an epoch index and the current learning rate
        as inputs and returns a new learning rate as output.
        opt: the optimizer; it is passed explicitly because self.model.optimizer
        is not available when the model is trained with a custom training loop.
    """


    def __init__(self,  schedule, opt=None):
        super(CustomLRScheduler, self).__init__()
        self.schedule = schedule
        self.optimizer = opt


    def on_epoch_begin(self, epoch, logs=None):
        if not hasattr(self.optimizer, 'lr'):
            raise ValueError("Optimizer must have a `lr` attribute. ")
        lr = float(keras.backend.get_value(self.optimizer.lr))
        scheduled_lr = self.schedule(epoch, lr)
        keras.backend.set_value(self.optimizer.lr, scheduled_lr)
        print("\nEpoch: %05d learning rate is %6.4f" % (epoch, scheduled_lr))


LR_SCHEDULE = [
    (1, 0.01),
    (10, 0.005),
    (15, 0.001),
    (20, 5e-4),
]


def custom_schedule(epoch, lr):
    """Helper function to return scheduled learning rate based on epoch."""
    
    for i in range(len(LR_SCHEDULE)):
        if epoch == LR_SCHEDULE[i][0]:
            return LR_SCHEDULE[i][1]
    return lr


class CustomEarlyStopping(Callback):
    """Stop training when the monitored quantity stops improving.

    Arguments:
        model: the model under training; self.model is only set automatically
        when the callback is registered through a CallbackList with model=model
        patience: number of epochs with no improvement after which training stops
        monitor: the quantity to monitor (e.g. 'val_loss')
    """


    def __init__(self, model, patience=0, monitor='loss'):
        super(CustomEarlyStopping, self).__init__()
        self.model = model
        self.patience = patience
        # best weights to store the weights at minimum loss score
        self.best_weights = None
        # monitor the parameter to stop early
        self.monitor = monitor


    def on_train_begin(self, logs=None):
        # number of epochs waited since the last improvement
        self.wait = 0
        # the epoch at which training stops
        self.stopped_epoch = 0
        if 'loss' in self.monitor:
            # Initialize the best as infinity
            self.best = np.inf
            # set monitor operation
            self.monitor_op = np.less
        else:
            self.best = -np.inf
            self.monitor_op = np.greater
    
    def on_epoch_end(self, epoch, logs=None):
        current = logs.get(self.monitor)
        if self.monitor_op(current, self.best):
            self.best = current
            self.wait = 0
            # Record the best weights if the current result is better
            self.best_weights = self.model.get_weights()
        
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.stopped_epoch = epoch
                self.model.stop_training = True
                print('Restore model weights from the end of the best epoch.')
                self.model.set_weights(self.best_weights)


    def on_train_end(self, logs=None):
        if self.stopped_epoch > 0:
            print("Epoch %05d early stopping" % (self.stopped_epoch))


# Create callback objects
schedule_learning_rate = CustomLRScheduler(custom_schedule, optimizer)
stop = CustomEarlyStopping(model, patience=3, monitor='val_loss')
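The decision logic of both callbacks can be checked without training a network. The pure-Python sketch below contains two hypothetical helpers: stateless_schedule, an alternative to custom_schedule that looks up the most recent entry of LR_SCHEDULE with bisect instead of relying on the previous learning rate, and simulate_early_stopping, which replays CustomEarlyStopping's wait/patience bookkeeping on a list of monitored values, assuming the training loop stops as soon as model.stop_training is set.

```python
import bisect

import numpy as np

# Same table as LR_SCHEDULE above: (first epoch, learning rate)
LR_SCHEDULE = [(1, 0.01), (10, 0.005), (15, 0.001), (20, 5e-4)]


def stateless_schedule(epoch, lr):
    """Piecewise-constant lookup; returns lr unchanged before the first entry."""
    starts = [start for start, _ in LR_SCHEDULE]
    idx = bisect.bisect_right(starts, epoch) - 1
    return lr if idx < 0 else LR_SCHEDULE[idx][1]


def simulate_early_stopping(values, patience, mode="min"):
    """Replay CustomEarlyStopping's bookkeeping on a list of monitored values.

    Returns (stopped_epoch, best_epoch) with 1-based epochs, as in train();
    stopped_epoch is 0 if training never stops early.
    """
    monitor_op = np.less if mode == "min" else np.greater
    best = np.inf if mode == "min" else -np.inf
    wait, best_epoch = 0, 0
    for epoch, current in enumerate(values, start=1):
        if monitor_op(current, best):
            best, best_epoch, wait = current, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch
    return 0, best_epoch


print([stateless_schedule(e, 0.1) for e in (0, 1, 9, 10, 16, 25)])
print(simulate_early_stopping([1.0, 0.8, 0.85, 0.9, 0.95, 0.7], patience=3))
```

Visiting epochs in order from 1, stateless_schedule yields the same sequence as custom_schedule; unlike custom_schedule, it also returns the right rate when called for an arbitrary epoch.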
        


You can create a callback list with the code below and then call a hook on all the callbacks with a single line of code.

callbacks = keras.callbacks.CallbackList([schedule_learning_rate, stop], add_history=True, model=model)

The important point here is where to call each method, and the method names tell you: call on_train_begin before training starts, i.e. right before iterating over the epochs; call on_train_end after the epoch loop terminates; call on_epoch_begin at the start of each epoch, right inside the epoch loop; and so on. The code snippet below demonstrates this.


callbacks = tf.keras.callbacks.CallbackList([...])
callbacks.append(...)

callbacks.on_train_begin(...)
for epoch in range(EPOCHS):
  callbacks.on_epoch_begin(epoch)
  for i, data in dataset.enumerate():
    callbacks.on_train_batch_begin(i)
    batch_logs = model.train_step(data)
    callbacks.on_train_batch_end(i, batch_logs)
  epoch_logs = ...
  callbacks.on_epoch_end(epoch, epoch_logs)
final_logs = ...
callbacks.on_train_end(final_logs)
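To see the order in which the hooks fire, the sketch below drives a CallbackList by hand with a hypothetical HookRecorder callback that only logs its own calls; the two epochs of two batches and the dummy logs are placeholders for the real training loop.

```python
import tensorflow as tf


class HookRecorder(tf.keras.callbacks.Callback):
    """Hypothetical callback that only records which hook ran, and in what order."""

    def __init__(self):
        super().__init__()
        self.calls = []

    def on_train_begin(self, logs=None):
        self.calls.append("train_begin")

    def on_epoch_begin(self, epoch, logs=None):
        self.calls.append(f"epoch_begin {epoch}")

    def on_train_batch_begin(self, batch, logs=None):
        self.calls.append(f"batch_begin {batch}")

    def on_train_batch_end(self, batch, logs=None):
        self.calls.append(f"batch_end {batch}")

    def on_epoch_end(self, epoch, logs=None):
        self.calls.append(f"epoch_end {epoch}")

    def on_train_end(self, logs=None):
        self.calls.append("train_end")


recorder = HookRecorder()
callbacks = tf.keras.callbacks.CallbackList([recorder])

callbacks.on_train_begin()
for epoch in range(2):
    callbacks.on_epoch_begin(epoch)
    for batch in range(2):
        callbacks.on_train_batch_begin(batch)
        # ... the real train_step(...) call would go here ...
        callbacks.on_train_batch_end(batch, logs={"loss": 0.0})
    callbacks.on_epoch_end(epoch, logs={"loss": 0.0})
callbacks.on_train_end()
print(recorder.calls)
```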


That is our main train function, implemented from scratch. Here it is in full:

import time
import numpy as np


def train(train_dataset, max_epoch, model, loss_object,
          optimizer, validation_dataset=None, train_metric=None,
          test_metric=None, **kwargs):
    """Custom training loop from scratch."""


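    # Initialize
    best = np.inf
    save_best_only = False
    save_freq = None
    model_checkpoint = None
    # Flag set by CustomEarlyStopping to request an early stop
    model.stop_training = False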


    # define train and validation loss mean
    train_loss_mean = keras.metrics.Mean()
    val_loss_mean = keras.metrics.Mean()
    # print one progress marker every `progress_step` batches
    if len(train_dataset) > 30:
        progress_step = np.ceil(len(train_dataset) / 30)
    else:
        progress_step = 1


    # Check if model checkpoint exists
    if 'model_checkpoint' in kwargs:
        model_checkpoint = kwargs.pop('model_checkpoint')
        if 'checkpoint' in model_checkpoint:
            checkpoint = model_checkpoint.pop('checkpoint')
        else:
            checkpoint = tf.train.Checkpoint(step=tf.Variable(1), net=model, optimizer=optimizer, iterator=iter(train_dataset))
        if 'manager' in model_checkpoint:
            ckpt_manager = model_checkpoint.pop('manager')
        else:
            ckpt_manager = tf.train.CheckpointManager(checkpoint, './tf_ckpt', max_to_keep=3)
        # Default to monitoring the validation loss
        if 'monitor' not in model_checkpoint:
            model_checkpoint['monitor'] = 'val_loss'

        # Resolve the direction of improvement ('auto' infers it from the name)
        mode = model_checkpoint.get('mode', 'auto')
        if mode == 'auto':
            mode = 'min' if 'loss' in model_checkpoint['monitor'] else 'max'
        # Always set monitor_op, also when 'mode' is given explicitly
        model_checkpoint['mode'] = mode
        model_checkpoint['monitor_op'] = np.less if mode == 'min' else np.greater
        best = np.inf if mode == 'min' else -np.inf


        if not 'save_best_only' in model_checkpoint:
            model_checkpoint.update({'save_best_only': False})


        save_best_only = model_checkpoint.pop('save_best_only')


        if not 'save_freq' in model_checkpoint:
            model_checkpoint['save_freq'] = 'epoch'
        save_freq = model_checkpoint.get('save_freq')


        # restore latest checkpoint if any
        checkpoint.restore(ckpt_manager.latest_checkpoint)
        if ckpt_manager.latest_checkpoint:
            print("Loaded from {}".format(ckpt_manager.latest_checkpoint))
        else:
            print("Initialize from scratch.")
    
    # check if tensorboard file writer exist
    if 'train_writer' in kwargs:
        train_writer = kwargs.pop('train_writer')
    else:
        train_writer = None


    if 'test_writer' in kwargs:
        test_writer = kwargs.pop('test_writer')
    else:
        test_writer = None


    if 'callbacks' in kwargs:
        _callbacks = kwargs.pop('callbacks')
        callbacks = keras.callbacks.CallbackList(_callbacks, add_history=True, model=model)
    else:
        callbacks = keras.callbacks.CallbackList([])
    # Call callbacks when train begins
    log = {}
    callbacks.on_train_begin(logs=log)


    # Main train loop
    for epoch in range(1, max_epoch+1):
        # Call callbacks when epoch begin
        callbacks.on_epoch_begin(epoch, logs=log)
        start = time.time()
        
        print(f"Epoch:{epoch} [>", end='')
        # Iterate over the batchs of train dataset
        for  step, (x_train_batch, y_train_batch) in enumerate(train_dataset):
            # step = tf.convert_to_tensor(step, dtype=tf.int64)
            train_loss_value = train_step(x_train_batch, y_train_batch, 
                                          model, loss_object, optimizer, 
                                          train_metric)
            train_loss_mean.update_state(train_loss_value)
            
            if save_freq is not None and isinstance(save_freq, int):
                checkpoint.step.assign_add(1)
                if int(checkpoint.step) % save_freq == 0:
                    save_path = ckpt_manager.save()
                    print("Saved checkpoint at step {}: {}".format(checkpoint.step, save_path))


            elif (step + 1) % progress_step == 0:
                print('\b=>', end='')


        # write train loss and accuracy to the tensorboard
        if train_writer is not None:
            with train_writer.as_default():
                tf.summary.scalar('loss', train_loss_mean.result(), step=epoch)
                if train_metric is not None:
                    tf.summary.scalar('accuracy', train_metric.result(), step=epoch)
                
        # evaluation on the validation dataset passed to train()
        if validation_dataset is not None:
            for x_val_batch, y_val_batch in validation_dataset:
                val_loss_value = test_step(x_val_batch, y_val_batch, model,
                                           loss_object, test_metric)
                val_loss_mean.update_state(val_loss_value)


        # write test loss and accuracy to the tensorboard
        if test_writer is not None:
            with test_writer.as_default():
                tf.summary.scalar('val_loss', val_loss_mean.result(), step=epoch)
                if test_metric is not None:
                    tf.summary.scalar('val_accuracy', test_metric.result(), step=epoch)
        
        template = "\b] time: {:.2f} min, loss: {:.4f}, acc: {:.4f} | val loss: {:.4f}, val acc: {:.4f}"


        print(template.format(
            (time.time() - start)/60, train_loss_mean.result(), 
            train_metric.result(), val_loss_mean.result(), 
            test_metric.result()
        ))


        # Pick the monitored value (only needed when checkpointing is enabled)
        monitor = model_checkpoint['monitor'] if model_checkpoint is not None else None
        if monitor == 'loss':
            current = train_loss_mean.result()
        elif monitor == 'val_loss':
            current = val_loss_mean.result()
        elif monitor in ('acc', 'accuracy'):
            current = train_metric.result()
        elif monitor in ('val_acc', 'val_accuracy'):
            current = test_metric.result()
        
        if save_freq is not None and save_freq == 'epoch':
            checkpoint.step.assign_add(1)
            # save just best model
            if save_best_only:
                if model_checkpoint.get('monitor_op')(current, best):
                    save_path = ckpt_manager.save()
                    print("Improved {} from {} to {}".format(model_checkpoint.get('monitor'), best, current))
                    print("\t\tSaved checkpoint for step {}: {}".format(int(checkpoint.step), save_path))
                    best = current
                # model.save("best-{}".format(model.__class__.__name__), 
                #         save_format='tf')
                # best_loss = val_loss_mean.result()
                # print("model saved")
            else:
                save_path = ckpt_manager.save()
                print("Saved checkpoint: {}\n".format(save_path))


        # Call callbacks when epoch ends
        log = dict(
            loss=train_loss_mean.result(),
            val_loss=val_loss_mean.result(),
            acc=train_metric.result(),
            val_acc=test_metric.result()
        )
        callbacks.on_epoch_end(epoch, logs=log)


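        # Reset metrics at the end of each epoch
        train_metric.reset_states()
        test_metric.reset_states()
        train_loss_mean.reset_states()
        val_loss_mean.reset_states()

        # Stop if a callback (e.g. CustomEarlyStopping) requested it
        if getattr(model, 'stop_training', False):
            break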
    
    # Call callbacks when train end
    callbacks.on_train_end(logs={})          
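The checkpointing branch of train() with save_freq='epoch' can be summarized by a small pure-Python sketch. epochs_saved is a hypothetical helper that replays the same comparison on a list of monitored values and returns the epochs at which a checkpoint would be written; the values are made up.

```python
import numpy as np


def epochs_saved(values, mode="min", save_best_only=True):
    """Return the 1-based epochs at which train() would write a checkpoint.

    Mirrors the save_freq='epoch' branch: with save_best_only=True a
    checkpoint is written only when the monitored value improves.
    """
    monitor_op = np.less if mode == "min" else np.greater
    best = np.inf if mode == "min" else -np.inf
    saved = []
    for epoch, current in enumerate(values, start=1):
        if not save_best_only:
            saved.append(epoch)          # save every epoch
        elif monitor_op(current, best):
            saved.append(epoch)          # save only on improvement
            best = current
    return saved


val_losses = [0.9, 0.7, 0.8, 0.6, 0.65]
print(epochs_saved(val_losses))                          # best only
print(epochs_saved(val_losses, save_best_only=False))    # every epoch
print(epochs_saved([0.5, 0.6, 0.55, 0.7], mode="max"))   # e.g. val_accuracy
```

With save_best_only=True every save is an improvement, so with max_to_keep=1 the single retained checkpoint is the best one so far.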

Finally, call the train function with the checkpoint and callbacks defined above:


model_checkpoint = {
    'checkpoint': ckpt, 'manager': manager, 'monitor': 'val_loss',
    'mode': 'min', 'save_freq': 'epoch', 'save_best_only': False
}


epochs = 30
# train_writer and test_writer are the TensorBoard file writers shown below;
# create them before running this call
train(train_dataset, epochs, model, loss_object, optimizer, test_dataset, train_metric,
      test_metric, train_writer=train_writer, test_writer=test_writer,
      model_checkpoint=model_checkpoint,
      callbacks=[schedule_learning_rate, stop]
      )

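If you rerun the train function, it resumes from the weights and optimizer state stored in the last saved checkpoint rather than from scratch (the epoch counter inside train() still starts at 1).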

We define TensorBoard file writers with tf.summary so that the loss and accuracy values can be visualized later in TensorBoard.


# tensorboard write
train_writer = tf.summary.create_file_writer('logs/train/')
test_writer = tf.summary.create_file_writer('logs/test/')
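A self-contained sketch of a writer in use, assuming a temporary log directory and made-up loss values: scalars written inside writer.as_default() end up in an events.out.tfevents.* file that TensorBoard reads.

```python
import os
import tempfile

import tensorflow as tf

# Assumption: a throwaway log directory instead of 'logs/train/'
logdir = os.path.join(tempfile.mkdtemp(), "train")
writer = tf.summary.create_file_writer(logdir)

# Made-up per-epoch losses; in train() they come from train_loss_mean.result()
for epoch, loss in enumerate([0.9, 0.7, 0.6], start=1):
    with writer.as_default():
        tf.summary.scalar("loss", loss, step=epoch)
writer.flush()

event_files = os.listdir(logdir)
print(event_files)
```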
        

The loss and accuracy values for training and evaluation are written with the following code:


for step, (x_train_batch, y_train_batch) in enumerate(train_dataset):
    train_loss_value = train_step(x_train_batch, y_train_batch,
                                  model, loss_object, optimizer,
                                  train_metric)
    train_loss_mean.update_state(train_loss_value)

# write train loss and accuracy to TensorBoard
with train_writer.as_default():
    tf.summary.scalar('loss', train_loss_mean.result(), step=epoch)
    tf.summary.scalar('accuracy', train_metric.result(), step=epoch)

for x_val_batch, y_val_batch in test_dataset:
    val_loss_value = test_step(x_val_batch, y_val_batch, model,
                               loss_object, test_metric)
    val_loss_mean.update_state(val_loss_value)

# write test loss and accuracy to TensorBoard
with test_writer.as_default():
    tf.summary.scalar('val_loss', val_loss_mean.result(), step=epoch)
    tf.summary.scalar('val_accuracy', test_metric.result(), step=epoch)
                

Afterwards, we can use TensorBoard to display the loss and accuracy values recorded during training and validation.


%load_ext tensorboard
%tensorboard --logdir logs        

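Last but not least, saving a whole sub-classed model is limited: the HDF5 format does not support sub-classed models, and although the SavedModel format can store the model's computation graph, recreating the original Python class still requires the code that defined it. Whole-model saving is straightforward for models defined with the Sequential or Functional APIs; for a sub-classed model, the usual approach is to save only the model weights and to rebuild the model from its code when reloading.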


# save just model weights
model.save_weights(f"./model/{model.__class__.__name__}-weights")        

Reloading a model built with the sub-classing API is more involved. The solution comes from an excellent article by François Chollet on saving and serializing models in tf.keras, which is a must-read. It has three steps:

  • Create an instance of the sub-classed model
  • Compile the model with exactly the same arguments as before, so that the optimizer state and any stateful metrics can be restored
  • Call the model on some data before calling load_weights


# Restore model i.e model weights
loaded = MiniInception()
loaded.compile(
    loss=keras.losses.CategoricalCrossentropy(),
    optimizer=keras.optimizers.Adam(),
    metrics=[keras.metrics.CategoricalAccuracy()]
)
x_batch, y_batch = next(iter(train_dataset))
print(x_batch.shape, y_batch.shape)
loaded.train_on_batch(x_batch, y_batch)
loaded.load_weights("./model/{}-weights".format(loaded.__class__.__name__))        

Conclusion

In this tutorial, we explained how to implement a sub-classed model and train it with a low-level custom training loop. We also showed how to implement custom checkpoints and callbacks for such a training mechanism. We put all the custom pieces together and saved and reloaded this fully customized model; the only component we did not customize is the loss function. To customize the loss, you can define a custom loss function or a subclass of the keras.losses.Loss class; I hope to cover that in a future part. Keep in mind that whole-model saving is not a reliable way to persist a sub-classed model: Sequential and Functional models are data structures, whereas a sub-classed model is a piece of code, and to reload it you need the code that created it.

REFERENCES:

https://towardsdatascience.com/model-sub-classing-and-custom-training-loop-from-scratch-in-tensorflow-2-cc1d4f10fb4

https://www.tensorflow.org/guide/checkpoint

https://www.tensorflow.org/guide/keras/custom_callback

https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/Callback

https://medium.com/smart-iot/custom-training-with-custom-callbacks-3bcd117a8f7e

https://colab.research.google.com/drive/172D4jishSgE3N7AO6U2OKAA_0wNnrMOq#scrollTo=KOKNBojtsl0F

More articles by Arash Dehghniyan Serej
