scaler = GradScaler()
for features, target in data:
    # Forward pass with mixed precision
    with torch.cuda.amp.autocast(): # autocast as a context manager
        output = model(features)
        loss = criterion(output, target)    
    
    # Backward pass without mixed precision
    # It's not recommended to use mixed precision for backward pass
    # Because we need more precise loss
    scaler.scale(loss).backward()    
    
    # scaler.step() first unscales the gradients .
    # If these gradients contain infs or NaNs, 
    # optimizer.step() is skipped.
    scaler.step(optimizer)     
    
    # If optimizer.step() was skipped,
    # scaling factor is reduced by the backoff_factor in GradScaler()
    scaler.update()

autocast自动应用精度到不同的操作。因为损失和梯度是按照float16精度计算的,当它们太小时,梯度可能会“下溢”并变成零。
GradScaler通过将损失乘以一个比例因子来防止下溢,根据比例损失计算梯度,然后在优化器更新权重之前取消梯度的比例。如果缩放因子太大或太小,并导致inf或NaN,则缩放因子将在下一个迭代中更新缩放因子。