def adversarial_loss(embedded, loss, loss_fn, perturb_norm_length=None):
    """Adds gradient to embedding and recomputes classification loss.

    Computes the gradient of ``loss`` with respect to ``embedded``, turns it
    into a perturbation via ``_scale_l2``, and re-evaluates ``loss_fn`` on the
    perturbed embedding.

    Args:
      embedded: Tensor that ``loss`` was computed from; this is the value
        that gets perturbed.
      loss: Loss tensor whose gradient w.r.t. ``embedded`` defines the
        perturbation direction (presumably a scalar -- confirm with callers).
      loss_fn: Callable that takes the perturbed embedding tensor and returns
        the loss to use.
      perturb_norm_length: Value passed to ``_scale_l2`` as its second
        argument. When None (the default), ``FLAGS.perturb_norm_length`` is
        read at call time, matching the previous behavior.

    Returns:
      The result of ``loss_fn(embedded + perturb)``.

    Raises:
      ValueError: If ``loss`` does not depend on ``embedded``, in which case
        ``tf.gradients`` yields None and no perturbation can be built.
    """
    # A single tensor is passed as `xs`, so tf.gradients returns a
    # one-element list; the trailing-comma unpack extracts it.
    grad, = tf.gradients(
        loss,
        embedded,
        aggregation_method=tf.AggregationMethod.EXPERIMENTAL_ACCUMULATE_N)
    if grad is None:
        # Fail with a clear message instead of an opaque tensor-conversion
        # error from tf.stop_gradient(None).
        raise ValueError(
            'loss does not depend on embedded; cannot compute an '
            'adversarial perturbation.')
    # Treat the perturbation as a constant so the adversarial loss does not
    # backpropagate through the gradient computation itself.
    grad = tf.stop_gradient(grad)
    if perturb_norm_length is None:
        perturb_norm_length = FLAGS.perturb_norm_length
    perturb = _scale_l2(grad, perturb_norm_length)
    return loss_fn(embedded + perturb)