Chapter 4 Model Usage
Since the transformer model constructed here conforms to the Keras API guidelines, we can naturally use the built-in APIs for training and inference.
4.1 Training
4.1.1 Loss
One of the trickiest aspects of the implementation was getting the loss right. This is one place where the disadvantages of using the higher-level Keras API show up: less control and less clarity about what is going on behind the scenes. It took some time to notice that losses compiled into the Keras model use the propagated mask to modify the per-element loss calculation as desired, but do not perform the expected reduction afterwards. We expected that simply compiling the built-in SparseCategoricalCrossentropy loss into the model would give the correct loss. The compiled loss does use the mask on the model output to mask out the losses for irrelevant sequence members, i.e. it zeros the losses corresponding to sequence padding; however, the average is then computed over the entire sequence, padding included.
For example, if a batch has dimension (64, 37), then while the 64 * 37 loss matrix will have zeros where there is padding, the final loss is calculated by summing the loss matrix and then dividing by 64 * 37. However, to calculate the reduced loss correctly we want to divide by the number of non-masked elements in the batch. While the transformer still learns reasonably well with this built-in loss calculation, it does significantly better with the correct loss.
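The difference is easy to see on toy values. The following sketch (the numbers are made up purely for illustration and do not come from the actual code) contrasts the reduction Keras performs with the masked mean we actually want:

import tensorflow as tf

# Per-element losses for a batch of 2 sequences of length 4; the mask has
# already zeroed the two padded positions at the end of the second sequence.
loss = tf.constant([[0.7, 0.3, 0.5, 0.2],
                    [0.6, 0.4, 0.0, 0.0]])
mask = tf.constant([[1., 1., 1., 1.],
                    [1., 1., 0., 0.]])

# Reduction performed by the compiled loss: divide by all 8 positions -> 0.3375
keras_mean = tf.reduce_sum(loss) / tf.cast(tf.size(loss), tf.float32)
# Desired reduction: divide by the 6 non-masked positions -> 0.45
masked_mean = tf.reduce_sum(loss) / tf.reduce_sum(mask)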
We could not see any way to opt out of this behaviour, short of removing the mask from the final output, which is a hack and causes the built-in metrics to give incorrect results. To overcome this we instead added the following “correction factor” to a custom loss, which is admittedly also a hack. From transformer/loss.py:
import tensorflow as tf
from tensorflow.keras.losses import Loss, sparse_categorical_crossentropy

class MaskedSparseCategoricalCrossentropy(Loss):
    def __init__(self, name='masked_sparse_categorical_cross_entropy'):
        super().__init__(name=name)

    def call(self, y_true, y_pred):
        loss = sparse_categorical_crossentropy(y_true, y_pred,
                                               from_logits=True)
        # mask propagated by the Keras layers onto the model output
        mask = getattr(y_pred, '_keras_mask')
        sw = tf.cast(mask, y_pred.dtype)
        # desired loss value: mean over the non-masked elements only
        reduced_loss = tf.reduce_sum(loss * sw) / tf.reduce_sum(sw)
        # cannot opt out of mask corrections in the API; this factor cancels
        # out the averaging over the full element count performed afterwards
        correction_factor = tf.reduce_sum(tf.ones(shape=tf.shape(y_true))) / \
            tf.reduce_sum(sw)
        return reduced_loss * correction_factor

4.1.2 Optimization
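The learning rate follows the warmup schedule from Vaswani et al. (2017): it increases linearly for the first warmup_steps training steps and then decays in proportion to the inverse square root of the step number,

lrate = d_model^(-0.5) * min(step^(-0.5), step * warmup_steps^(-1.5))

which is what the CustomSchedule below computes.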
From transformer/schedule.py:
import tensorflow as tf
from tensorflow.keras.optimizers.schedules import LearningRateSchedule

class CustomSchedule(LearningRateSchedule):
    def __init__(self, d_model, warmup_steps=4000):
        super(CustomSchedule, self).__init__()
        self.d_model = tf.cast(d_model, tf.float32)
        self.warmup_steps = warmup_steps

    def __call__(self, step):
        arg1 = tf.math.rsqrt(step)
        arg2 = step * (self.warmup_steps ** -1.5)
        return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)

The Adam optimizer is used with the same settings as in Vaswani et al. (2017):
from tensorflow.keras.optimizers import Adam

D_MODEL = 128
learning_rate = CustomSchedule(d_model=D_MODEL)
optimizer = Adam(learning_rate, beta_1=0.9, beta_2=0.98, epsilon=1e-9)

4.1.3 Learning
The actual code is in program.py. However, the following sequence illustrates how the Keras training API
is called.
import tensorflow as tf
from transformer.transformer import transformer

model = transformer(num_layers=4, d_model=D_MODEL,
                    num_heads=8, dff=512,
                    input_vocab_size=input_vocab_size,
                    target_vocab_size=target_vocab_size,
                    pe_input_max=MAX_LEN,
                    pe_target_max=MAX_LEN,
                    dropout_rate=0.1)

model.compile(optimizer=optimizer,
              loss=MaskedSparseCategoricalCrossentropy(),
              metrics=['accuracy'])

model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    TRAIN_DIR + '/checkpoint.{epoch}.ckpt',
    save_weights_only=True)

model.fit(data_train, epochs=1, validation_data=data_eval,
          callbacks=model_checkpoint_callback)

4.2 Inference
Inference with the transformer, or any auto-regressive model, is not simply a matter of plugging a testing pipeline into the model and calling predict. The training process uses teacher forcing as previously discussed, which means the next symbol is predicted based on the given ground truth up to that point in the sequence. In contrast, during inference the sequence of predicted symbols is used to recursively predict the next symbol. The code for doing this is in transformer/autoregression.py:
def autoregress(model, input, delimiters, max_length):
    # delimiters[0] is the target start token, delimiters[1] the end token
    delimiters = delimiters[0]
    # the decoder input starts with the start token only
    decoder_input = [delimiters[0]]
    output = tf.expand_dims(decoder_input, 0)
    done = False
    while not done:
        preds = model({'encoder_input': tf.expand_dims(input, 0),
                       'decoder_input': output})
        # only the prediction for the last position is of interest
        prediction = preds[:, -1, :]
        # take the most probable token, or force the end token once the
        # maximum length is reached
        pred_id = tf.argmax(prediction, axis=-1) \
            if tf.shape(output)[1] < max_length - 1 else tf.expand_dims(delimiters[1], 0)
        done = pred_id == delimiters[1]
        # append the predicted token and feed the sequence back into the model
        output = tf.concat([output, tf.expand_dims(pred_id, 0)], axis=-1)
    return tf.squeeze(output, axis=0)
def translate(model, input, tokenizers, max_length):
    """
    Translate an input sentence to a target sentence using a model
    """
    input_encoded = tokenizers.inputs.tokenize([input])[0]
    if len(input_encoded) > max_length:
        return None
    prediction = autoregress(model,
                             input_encoded,
                             delimiters=tokenizers.targets.tokenize(['']),
                             max_length=max_length)
    prediction_decoded = tokenizers.targets.detokenize(
        [prediction]).numpy()[0][0].decode('utf-8')
    return prediction_decoded
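Note that autoregress performs greedy decoding: at each step the single most probable token is taken with tf.argmax, and no beam search is attempted. As a sketch of typical use, assuming the tokenizers object and the MAX_LEN and TRAIN_DIR constants are the same ones used during training (the checkpoint file name below is only an example), a trained checkpoint can be restored and used for translation as follows:

from transformer.autoregression import translate

# restore weights saved by the ModelCheckpoint callback during training
model.load_weights(TRAIN_DIR + '/checkpoint.1.ckpt')

translation = translate(model, 'an example source sentence', tokenizers, MAX_LEN)
print(translation)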