Chapter 4 Model Usage

Since the transformer model constructed here conforms to the Keras API guidelines, we can naturally use the built-in APIs for training and inference.

4.1 Training

4.1.1 Loss

One of the trickiest aspects of the implementation was getting the loss right. This is one place where the disadvantages of using the higher-level Keras API show up: less control and less clarity about what is going on behind the scenes. It took some time to notice that losses compiled into the Keras model use the propagated mask to modify the loss calculation as intended, but do not perform the expected reduction afterwards. We expected that simply compiling the built-in SparseCategoricalCrossentropy loss into the model would give the correct loss. The compiled loss does use the mask on the model output to zero out the losses for irrelevant sequence positions, i.e. those corresponding to sequence padding; however, the average is then computed over the entire sequence. For example, if a batch has dimension (64, 37), the 64 * 37 loss matrix will have zeros where there is padding, but the final loss is computed by summing the loss matrix and dividing by 64 * 37. To summarize the loss correctly we instead want to divide by the number of non-masked elements in the batch. While the transformer still learns reasonably well with the built-in loss calculation, it does significantly better with the correct loss.
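
To make the difference concrete, the following toy example (the loss values are invented) compares the two reductions for a batch of two sequences of length four, where the last two positions of the second sequence are padding:

import tensorflow as tf

loss = tf.constant([[0.5, 0.2, 0.3, 0.1],
                    [0.4, 0.6, 0.0, 0.0]])
mask = tf.constant([[1., 1., 1., 1.],
                    [1., 1., 0., 0.]])

# what the compiled Keras loss effectively does: divide by all 8 positions
mean_over_all = tf.reduce_sum(loss * mask) / tf.cast(tf.size(loss), tf.float32)
# what we actually want: divide by the 6 non-masked positions
mean_over_valid = tf.reduce_sum(loss * mask) / tf.reduce_sum(mask)

print(mean_over_all.numpy(), mean_over_valid.numpy())  # 0.2625 vs 0.35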

We could not see any way to opt out of this behaviour, short of removing the mask from the final output, which is a hack and causes the built-in metrics to give incorrect results. To overcome this we added the following “correction factor” to a custom loss, which is also a hack. From transformer/loss.py:

import tensorflow as tf
from tensorflow.keras.losses import Loss, sparse_categorical_crossentropy


class MaskedSparseCategoricalCrossentropy(Loss):
    def __init__(self, name='masked_sparse_categorical_cross_entropy'):
        super().__init__(name=name)

    def call(self, y_true, y_pred):
        loss = sparse_categorical_crossentropy(y_true, y_pred,
                                               from_logits=True)
        # per-token mask propagated to the model output by the Keras masking machinery
        mask = getattr(y_pred, '_keras_mask')
        sw = tf.cast(mask, y_pred.dtype)
        # desired loss value: mean over the non-masked positions only
        reduced_loss = tf.reduce_sum(loss * sw) / tf.reduce_sum(sw)
        # cannot opt out of mask corrections in the API
        correction_factor = tf.reduce_sum(tf.ones(shape=tf.shape(y_true))) / \
                            tf.reduce_sum(sw)

        return reduced_loss * correction_factor

4.1.2 Optimization

The learning rate schedule from Vaswani et al. (2017) is implemented in transformer/schedule.py:

import tensorflow as tf
from tensorflow.keras.optimizers.schedules import LearningRateSchedule


class CustomSchedule(LearningRateSchedule):
    def __init__(self, d_model, warmup_steps=4000):
        super().__init__()
        self.d_model = tf.cast(d_model, tf.float32)
        self.warmup_steps = warmup_steps

    def __call__(self, step):
        # the optimizer may pass the step as an integer tensor
        step = tf.cast(step, tf.float32)
        arg1 = tf.math.rsqrt(step)
        arg2 = step * (self.warmup_steps ** -1.5)

        return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)
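
This implements the schedule from Vaswani et al. (2017): the learning rate is d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5), which increases linearly over the first warmup_steps steps and then decays with the inverse square root of the step number. A quick sanity check (the step values are arbitrary):

schedule = CustomSchedule(d_model=128)

for step in [100.0, 4000.0, 40000.0]:
    print(step, schedule(step).numpy())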

The Adam optimizer is used with the same settings as in Vaswani et al. (2017).

from tensorflow.keras.optimizers import Adam

D_MODEL = 128

learning_rate = CustomSchedule(d_model=D_MODEL)
optimizer = Adam(learning_rate, beta_1=0.9, beta_2=0.98, epsilon=1e-9)

4.1.3 Learning

The actual code is in program.py. However, the following sequence illustrates how the Keras training API is called.

from transformer.transformer import transformer

model = transformer(num_layers=4, d_model=D_MODEL, 
                    num_heads=8, dff=512,
                    input_vocab_size=input_vocab_size,
                    target_vocab_size=target_vocab_size,
                    pe_input_max=MAX_LEN,
                    pe_target_max=MAX_LEN,
                    dropout_rate=0.1)
                    
model.compile(optimizer=optimizer,
              loss=MaskedSparseCategoricalCrossentropy(),
              metrics=['accuracy'])
              
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    TRAIN_DIR + '/checkpoint.{epoch}.ckpt',
    save_weights_only=True)

model.fit(data_train, epochs=1, validation_data=data_eval,
          callbacks=[model_checkpoint_callback])
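
Because the callback saves weights only, resuming training or loading the model for inference means rebuilding the model as above and restoring its weights. A minimal sketch, assuming checkpoints were written to TRAIN_DIR as configured above:

latest = tf.train.latest_checkpoint(TRAIN_DIR)
if latest is not None:
    model.load_weights(latest)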

4.2 Inference

Inference with the transformer, or any auto-regressive model, is not simply a matter of plugging a test pipeline into the model and calling predict. Training uses teacher forcing, as previously discussed: the next symbol is predicted from the ground-truth sequence up to that point. During inference, in contrast, the symbols predicted so far are fed back into the model to predict the next one, here greedily by taking the argmax at each step, until the end delimiter is produced or the maximum length is reached. The code for doing this is in transformer/autoregression.py:

def autoregress(model, input, delimiters, max_length):
    # delimiters holds the [start, end] tokens produced by tokenizing an
    # empty target string
    delimiters = delimiters[0]
    decoder_input = [delimiters[0]]

    # shape (1, current_output_length)
    output = tf.expand_dims(decoder_input, 0)

    done = False
    while not done:
        preds = model({'encoder_input': tf.expand_dims(input, 0),
                       'decoder_input': output})
        # logits for the last position in the sequence
        prediction = preds[:, -1, :]
        # greedy decoding; force the end delimiter once max_length is reached
        pred_id = tf.argmax(prediction, axis=-1) \
            if tf.shape(output)[1] < max_length - 1 else tf.expand_dims(delimiters[1], 0)

        done = pred_id == delimiters[1]
        output = tf.concat([output, tf.expand_dims(pred_id, 0)], axis=-1)

    return tf.squeeze(output, axis=0)


def translate(model, input, tokenizers, max_length):
    """
    Translate an input sentence to a target sentence using a model
    """
    input_encoded = tokenizers.inputs.tokenize([input])[0]

    if len(input_encoded) > max_length:
        return None

    prediction = autoregress(model, 
                             input_encoded, 
                             delimiters=tokenizers.targets.tokenize(['']),
                             max_length=max_length)
    prediction_decoded = tokenizers.targets.detokenize([prediction]).numpy()[0][0].decode('utf-8')

    return prediction_decoded
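
A hypothetical call, assuming model, tokenizers and MAX_LEN are set up as in program.py:

sentence = 'this is a sentence to translate.'
translated = translate(model, sentence, tokenizers, max_length=MAX_LEN)
print(translated)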

References

Vaswani, Ashish, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, and Illia Polosukhin. 2017. “Attention Is All You Need.” CoRR abs/1706.03762. http://arxiv.org/abs/1706.03762.