import tensorflow as tf
from tensorflow.python.keras.utils import losses_utils

# Kullback-Leibler divergence KL(p1 || p2) between the categorical
# distributions described by two logit tensors.
#
# `tf.keras.losses.KLDivergence` expects probability distributions, not raw
# logits: it clips both inputs to [epsilon, 1] and sums
# y_true * log(y_true / y_pred) over the last axis. Raw logits (which may be
# negative or > 1) would be silently clipped, producing a meaningless value,
# so both inputs are normalised with a softmax over the last axis first.
# NOTE(review): assumes logit1/logit2 are unnormalised logits with the class
# axis last, as their names suggest -- confirm against the code producing them.
kl = tf.keras.losses.KLDivergence(
    # 'none' is the value of the internal losses_utils.ReductionV2.NONE; the
    # plain string avoids the private tensorflow.python.keras module and is
    # accepted by both Keras 2 and Keras 3. Per-example losses are returned
    # so the averaging below is explicit.
    reduction='none',
    name='kullback_leibler_divergence',
)

prob1 = tf.nn.softmax(logit1, axis=-1)
prob2 = tf.nn.softmax(logit2, axis=-1)

# kl(y_true, y_pred) computes KL(y_true || y_pred) per example; average the
# per-example values into a single scalar loss.
kl_loss = tf.reduce_mean(kl(prob1, prob2))