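1. apex

What it is: NVIDIA's PyTorch extension for mixed-precision training

What it is for: speeding up training on GPUs

GitHub: https://github.com/NVIDIA/apex

API documentation: https://nvidia.github.io/apex

Requirements: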

Python 3

CUDA 9 or newer

PyTorch 0.4 or newer. The CUDA and C++ extensions require pytorch 1.0 or newer.

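The latest stable release, available from https://pytorch.org/, is recommended.

Apex is also tested against the latest master branch, obtainable from https://github.com/pytorch/pytorch.

It is often convenient to use Apex in a Docker container. Compatible options include: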

NVIDIA Pytorch containers from NGC, which come with Apex preinstalled. To use the latest Amp API, you may need to pip uninstall apex then reinstall Apex using the Quick Start commands below.
official Pytorch -devel Dockerfiles, e.g. docker pull pytorch/pytorch:nightly-devel-cuda10.0-cudnn7, in which you can install Apex using the Quick Start commands.

How to install:
Linux:

For performance and full functionality, installing Apex with the CUDA and C++ extensions is recommended:

$ git clone https://github.com/NVIDIA/apex
$ cd apex
$ pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./

Apex also supports a Python-only build (required with PyTorch 0.4) via

$ pip install -v --no-cache-dir ./

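Windows:

Windows support is experimental; Linux is recommended.

If you were able to build PyTorch from source on your system, pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" . may work.

pip install -v --no-cache-dir . (without the CUDA/C++ extensions) is more likely to work.

If you installed PyTorch in a Conda environment, make sure to install Apex in that same environment.

Related link: https://github.com/NVIDIA/apex/tree/master/examples/docker

How to use it after installation: see the documentation at https://nvidia.github.io/apex/amp.html

Example: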

import torch
from apex import amp

# Declare model and optimizer as usual, with default (FP32) precision
model = torch.nn.Linear(D_in, D_out).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# Allow Amp to perform casts as required by the opt_level
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
...
# loss.backward() becomes:
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()

 

2. My installation process:

1. $ git clone https://github.com/NVIDIA/apex (done)
2. $ cd apex (done)
3. $ pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./

Step 3 failed with the error below. Many people have asked about this problem in the GitHub issues.

Cleaning up...
  Removing source in /tmp/pip-req-build-v0deounv
Removed build tracker '/tmp/pip-req-tracker-3n3fyj4o'
ERROR: Command errored out with exit status 1: /users4/zsun/anaconda3/bin/python -u -c 'import sys, setuptools, tokenize; sys.argv[0] = '"'"'/tmp/pip-req-build-v0deounv/setup.py'"'"'; __file__='"'"'/tmp/pip-req-build-v0deounv/setup.py'"'"';f=getattr(tokenize, '"'"'open'"'"', open)(__file__);code=f.read().replace('"'"'\r\n'"'"', '"'"'\n'"'"');f.close();exec(compile(code, __file__, '"'"'exec'"'"'))' --cpp_ext --cuda_ext install --record /tmp/pip-record-rce1cb4d/install-record.txt --single-version-externally-managed --compile Check the logs for full command output.
Exception information:
Traceback (most recent call last):
  File "/users4/zsun/anaconda3/lib/python3.6/site-packages/pip/_internal/cli/base_command.py", line 153, in _main
    status = self.run(options, args)
  File "/users4/zsun/anaconda3/lib/python3.6/site-packages/pip/_internal/commands/install.py", line 455, in run
    use_user_site=options.use_user_site,
  File "/users4/zsun/anaconda3/lib/python3.6/site-packages/pip/_internal/req/__init__.py", line 62, in install_given_reqs
    **kwargs
  File "/users4/zsun/anaconda3/lib/python3.6/site-packages/pip/_internal/req/req_install.py", line 888, in install
    cwd=self.unpacked_source_directory,
  File "/users4/zsun/anaconda3/lib/python3.6/site-packages/pip/_internal/utils/subprocess.py", line 275, in runner
    spinner=spinner,
  File "/users4/zsun/anaconda3/lib/python3.6/site-packages/pip/_internal/utils/subprocess.py", line 242, in call_subprocess
    raise InstallationError(exc_msg)
pip._internal.exceptions.InstallationError: Command errored out with exit status 1: /users4/zsun/anaconda3/bin/python -u -c 'import sys, setuptools, tokenize; sys.argv[0] = '"'"'/tmp/pip-req-build-v0deounv/setup.py'"'"'; __file__='"'"'/tmp/pip-req-build-v0deounv/setup.py'"'"';f=getattr(tokenize, '"'"'open'"'"', open)(__file__);code=f.read().replace('"'"'\r\n'"'"', '"'"'\n'"'"');f.close();exec(compile(code, __file__, '"'"'exec'"'"'))' --cpp_ext --cuda_ext install --record /tmp/pip-record-rce1cb4d/install-record.txt --single-version-externally-managed --compile Check the logs for full command output.
1 location(s) to search for versions of pip:
* http://mirrors.aliyun.com/pypi/simple/pip/
Getting page http://mirrors.aliyun.com/pypi/simple/pip/
Found index url http://mirrors.aliyun.com/pypi/simple/
Starting new HTTP connection (1): mirrors.aliyun.com:80
http://mirrors.aliyun.com:80 "GET /pypi/simple/pip/ HTTP/1.1" 200 12139
Analyzing links from page http://mirrors.aliyun.com/pypi/simple/pip/
  Found link http://mirrors.aliyun.com/pypi/packages/18/ad/c0fe6cdfe1643a19ef027c7168572dac6283b80a384ddf21b75b921877da/pip-0.2.1.tar.gz#sha256=83522005c1266cc2de97e65072ff7554ac0f30ad369c3b02ff3a764b962048da (from http://mirrors.aliyun.com/pypi/simple/pip/), version: 0.2.1
  Found link http://mirrors.aliyun.com/pypi/packages/3d/9d/1e313763bdfb6a48977b65829c6ce2a43eaae29ea2f907c8bbef024a7219/pip-0.2.tar.gz#sha256=88bb8d029e1bf4acd0e04d300104b7440086f94cc1ce1c5c3c31e3293aee1f81 (from http://mirrors.aliyun.com/pypi/simple/pip/), version: 0.2
  Found link http://mirrors.aliyun.com/pypi/packages/0a/bb/d087c9a1415f8726e683791c0b2943c53f2b76e69f527f2e2b2e9f9e7b5c/pip-0.3.1.tar.gz#sha256=34ce534f17065c78f980702928e988a6b6b2d8a9851aae5f1571a1feb9bb58d8 (from http://mirrors.aliyun.com/pypi/simple/pip/), version: 0.3.1
  Found link http://mirrors.aliyun.com/pypi/packages/17/05/f66144ef69b436d07f8eeeb28b7f77137f80de4bf60349ec6f0f9509e801/pip-0.3.tar.gz#sha256=183c72455cb7f8860ac1376f8c4f14d7f545aeab8ee7c22cd4caf79f35a2ed47 (from http://mirrors.aliyun.com/pypi/simple/pip/), version: 0.3
  Found link http://mirrors.aliyun.com/pypi/packages/cf/c3/153571aaac6cf999f4bb09c019b1ff379b7b599ea833813a41c784eec995/pip-0.4.tar.gz#sha256=28fc67558874f71fddda7168f73595f1650523dce3bc5bf189713ecdfc1e456e (from http://mirrors.aliyun.com/pypi/simple/pip/), version: 0.4
  Found link 



  Found link http://mirrors.aliyun.com/pypi/packages/ac/95/a05b56bb975efa78d3557efa36acaf9cf5d2fd0ee0062060493687432e03/pip-9.0.3-py2.py3-none-any.whl#sha256=c3ede34530e0e0b2381e7363aded78e0c33291654937e7373032fda04e8803e5 (from http://mirrors.aliyun.com/pypi/simple/pip/), version: 9.0.3
  Found link http://mirrors.aliyun.com/pypi/packages/c4/44/e6b8056b6c8f2bfd1445cc9990f478930d8e3459e9dbf5b8e2d2922d64d3/pip-9.0.3.tar.gz#sha256=7bf48f9a693be1d58f49f7af7e0ae9fe29fd671cde8a55e6edca3581c4ef5796 (from http://mirrors.aliyun.com/pypi/simple/pip/), version: 9.0.3
Given no hashes to check 131 links for project 'pip': discarding no candidates

4. Fall back to the Python-only install:

$ pip install -v --no-cache-dir ./

The output was:

copying build/lib/apex/RNN/RNNBackend.py -> build/bdist.linux-x86_64/wheel/apex/RNN
  copying build/lib/apex/RNN/cells.py -> build/bdist.linux-x86_64/wheel/apex/RNN
  creating build/bdist.linux-x86_64/wheel/apex/normalization
  copying build/lib/apex/normalization/__init__.py -> build/bdist.linux-x86_64/wheel/apex/normalization
  copying build/lib/apex/normalization/fused_layer_norm.py -> build/bdist.linux-x86_64/wheel/apex/normalization
  running install_egg_info
  running egg_info
  creating apex.egg-info
  writing apex.egg-info/PKG-INFO
  writing dependency_links to apex.egg-info/dependency_links.txt
  writing top-level names to apex.egg-info/top_level.txt
  writing manifest file 'apex.egg-info/SOURCES.txt'
  reading manifest file 'apex.egg-info/SOURCES.txt'
  writing manifest file 'apex.egg-info/SOURCES.txt'
  Copying apex.egg-info to build/bdist.linux-x86_64/wheel/apex-0.1-py3.6.egg-info
  running install_scripts
  creating build/bdist.linux-x86_64/wheel/apex-0.1.dist-info/WHEEL
done
  Created wheel for apex: filename=apex-0.1-cp36-none-any.whl size=136906 sha256=55830f559061fcb30ed616dd6879086c9b79926c3d3e0017a2dcf6c0e1aa8037
  Stored in directory: /tmp/pip-ephem-wheel-cache-m4cxipvx/wheels/6c/91/1a/143cfe0f99d10c8c415d1594024d1de93c5f8c03f5edfad2ba
  Removing source in /tmp/pip-req-build-yg8bljf6
Successfully built apex
Installing collected packages: apex
  Found existing installation: apex 0.1
    Uninstalling apex-0.1:
      Created temporary directory: /users4/zsun/anaconda3/lib/python3.6/site-packages/~pex-0.1.dist-info
      Removing file or directory /users4/zsun/anaconda3/lib/python3.6/site-packages/apex-0.1.dist-info/
      Created temporary directory: /users4/zsun/anaconda3/lib/python3.6/site-packages/~pex
      Removing file or directory /users4/zsun/anaconda3/lib/python3.6/site-packages/apex/
      Successfully uninstalled apex-0.1

Successfully installed apex-0.1
Cleaning up...
Removed build tracker '/tmp/pip-req-tracker-nm6wywoj'
1 location(s) to search for versions of pip:
* http://mirrors.aliyun.com/pypi/simple/pip/
Getting page http://mirrors.aliyun.com/pypi/simple/pip/
Found index url http://mirrors.aliyun.com/pypi/simple/
Starting new HTTP connection (1): mirrors.aliyun.com:80
http://mirrors.aliyun.com:80 "GET /pypi/simple/pip/ HTTP/1.1" 200 12139
Analyzing links from page http://mirrors.aliyun.com/pypi/simple/pip/
  Found link http://mirrors.aliyun.com/pypi/packages/18/ad/c0fe6cdfe1643a19ef027c7168572dac6283b80a384ddf21b75b921877da/pip-0.2.1.tar.gz#sha256=83522005c1266cc2de97e65072ff7554ac0f30ad369c3b02ff3a764b962048da (from http://mirrors.aliyun.com/pypi/simple/pip/), version: 0.2.1
  Found link http://mirrors.aliyun.com/pypi/packages/3d/9d/1e313763bdfb6a48977b65829c6ce2a43eaae29ea2f907c8bbef024a7219/pip-0.2.tar.gz#sha256=88bb8d029e1bf4acd0e04d300104b7440086f94cc1ce1c5c3c31e3293aee1f81 (from http://mirrors.aliyun.com/pypi/simple/pip/), version: 0.2
  Found link http://mirrors.aliyun.com/pypi/packages/0a/bb/d087c9a1415f8726e683791c0b2943c53f2b76e69f527f2e2b2e9f9e7b5c/pip-0.3.1.tar.gz#sha256=34ce534f17065c78f980702928e988a6b6b2d8a9851aae5f1571a1feb9bb58d8 (from http://mirrors.aliyun.com/pypi/simple/pip/), version: 0.3.1
  Found link http://mirrors.aliyun.com/pypi/packages/17/05/f66144ef69b436d07f8eeeb28b7f77137f80de4bf60349ec6f0f9509e801/pip-0.3.tar.gz#sha256=183c72455cb7f8860ac1376f8c4f14d7f545aeab8ee7c22cd4caf79f35a2ed47 (from http://mirrors.aliyun.com/pypi/simple/pip/), version: 0.3
  Found link http://mirrors.aliyun.com/pypi/packages/cf/c3/153571aaac6cf999f4bb09c019b1ff379b7b599ea833813a41c784eec995/pip-0.4.tar.gz#sha256=28fc67558874f71fddda7168f73595f1650523dce3bc5bf189713ecdfc1e456e (from http://mirrors.aliyun.com/pypi/simple/pip/), version: 0.4
  Found link http://mirrors.aliyun.com/pypi/packages/9a/aa/f536b6d14fe03343367da2ff44eee28f340ae650cd017ca088b6be13084a/pip-0.5.1.tar.gz#sha256=e27650538c41fe1007a41abd4cfd0f905b822622cbe1f8e7e09d1215af207694 (from http://mirrors.aliyun.com/pypi/simple/pi





  Found link http://mirrors.aliyun.com/pypi/packages/ac/95/a05b56bb975efa78d3557efa36acaf9cf5d2fd0ee0062060493687432e03/pip-9.0.3-py2.py3-none-any.whl#sha256=c3ede34530e0e0b2381e7363aded78e0c33291654937e7373032fda04e8803e5 (from http://mirrors.aliyun.com/pypi/simple/pip/), version: 9.0.3
  Found link http://mirrors.aliyun.com/pypi/packages/c4/44/e6b8056b6c8f2bfd1445cc9990f478930d8e3459e9dbf5b8e2d2922d64d3/pip-9.0.3.tar.gz#sha256=7bf48f9a693be1d58f49f7af7e0ae9fe29fd671cde8a55e6edca3581c4ef5796 (from http://mirrors.aliyun.com/pypi/simple/pip/), version: 9.0.3
Given no hashes to check 131 links for project 'pip': discarding no candidates

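Although the tail of the output looks the same as in the failed attempt, "Successfully installed apex-0.1" appeared in the middle.

I treated this as a successful install for now and moved on.

5. I changed my code following the example above and the one below. It is simple; only a few places need to change.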

if args.apex:
    from apex import amp
# Declare model and optimizer as usual, with default (FP32) precision
model = torch.nn.Linear(D_in, D_out).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# Allow Amp to perform casts as required by the opt_level
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
...
# loss.backward() becomes:
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()
...
# Save checkpoint
checkpoint = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'amp': amp.state_dict()
}
torch.save(checkpoint, 'amp_checkpoint.pt')
...
amp.load_state_dict(checkpoint['amp'])
...

6. First run, to see the effect. The output was:

2019-11-27 14:42:14,362 INFO: Loading vocab,train and val dataset.Wait a second,please
#Params: 73.7M
Selected optimization level O1:  Insert automatic casts around Pytorch functions and Tensor methods.

Defaults for this optimization level are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
Warning:  multi_tensor_applier fused unscale kernel is unavailable, possibly because apex was installed without --cuda_ext --cpp_ext. Using Python fallback.  Original ImportError was: ModuleNotFoundError("No module named 'amp_C'",)
Traceback (most recent call last):
  File "main.py", line 582, in <module>
    train()
  File "main.py", line 388, in train
    with amp.scale_loss(loss, optimizer) as scaled_loss:
  File "/users4/zsun/anaconda3/lib/python3.6/contextlib.py", line 81, in __enter__
    return next(self.gen)
  File "/users4/zsun/anaconda3/lib/python3.6/site-packages/apex/amp/handle.py", line 111, in scale_loss
    optimizer._prepare_amp_backward()
  File "/users4/zsun/anaconda3/lib/python3.6/site-packages/apex/amp/_process_optimizer.py", line 219, in prepare_backward_no_master_weights
    self._amp_lazy_init()
  File "/users4/zsun/anaconda3/lib/python3.6/site-packages/apex/amp/_process_optimizer.py", line 309, in _amp_lazy_init
    self._lazy_init_maybe_master_weights()
  File "/users4/zsun/anaconda3/lib/python3.6/site-packages/apex/amp/_process_optimizer.py", line 210, in lazy_init_no_master_weights
    "Received {}".format(param.type()))
TypeError: Optimizer's parameters must be either torch.cuda.FloatTensor or torch.cuda.HalfTensor. Received torch.FloatTensor

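This shows that installing apex without the C++/CUDA extensions does have a cost: the fused unscale kernel is unavailable (this program does not need it). That message is a warning, not an error, though, so I fixed the actual error and ran again. This time it ran successfully.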
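One way to confirm which kind of install you ended up with is to check whether amp_C, the compiled module named in the warning, can be found. This is a minimal sketch using only the standard library; the helper name apex_install_status is made up for this post and is not part of apex.

```python
import importlib.util


def apex_install_status():
    """Report whether apex and its compiled amp_C extension can be found.

    amp_C is the module named in the warning above; the warning suggests it
    is only present when apex was built with --cuda_ext --cpp_ext.
    (apex_install_status is a hypothetical helper, not part of apex.)
    """
    def found(name):
        # find_spec returns None when a top-level module is not installed.
        return importlib.util.find_spec(name) is not None

    return {"apex": found("apex"), "amp_C": found("amp_C")}


if __name__ == "__main__":
    print(apex_install_status())
```

If "apex" is True and "amp_C" is False, you have the Python-only build, which matches the warning in the log.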

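(The TypeError above occurred because my optimizer's parameters were on the CPU, while apex requires them to be on the GPU. It has nothing to do with the installation step.)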
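A quick pre-flight check before calling amp.initialize can catch this. The sketch below relies only on model.named_parameters() and the tensor attribute is_cuda, both standard PyTorch; find_cpu_parameters is a hypothetical helper, and the Fake* classes are stand-ins so the example runs without a GPU.

```python
def find_cpu_parameters(model):
    """Return the names of parameters that are not on a CUDA device.

    (Hypothetical helper; works on any object with named_parameters().)
    """
    return [name for name, p in model.named_parameters() if not p.is_cuda]


# Stand-in objects so the sketch runs without a GPU or PyTorch installed.
class FakeParam:
    def __init__(self, is_cuda):
        self.is_cuda = is_cuda


class FakeModel:
    def named_parameters(self):
        yield "encoder.weight", FakeParam(is_cuda=True)
        yield "decoder.weight", FakeParam(is_cuda=False)


print(find_cpu_parameters(FakeModel()))  # ['decoder.weight']
```

With a real model, an empty list is what you want to see. If it is not empty, move the model to the GPU (for example with model.cuda()) before constructing the optimizer and calling amp.initialize.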

Before adding apex:
    batch=8: <3.5 h per eval, three evals per epoch, 4911 / 12196 MB | zsun(4901M)
    batch=10: <2.8 h per eval, three evals per epoch, 9723 / 12196 MB | zsun(9713M)
    batch=4: <7.25 h per eval, three evals per epoch, 9863 / 16280 MB | zsun(9853M) (a different program, without the CVAE)
    batch=16: runs out of memory after a while

After adding apex:
    batch=16: <1.5 h per eval, three evals per epoch, 11517 / 12196 MB | zsun(11507M), no more OOM
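To put a number on the difference, the timings above can be turned into rough speedup ratios. This is back-of-the-envelope arithmetic only: the figures are upper bounds ("<3.5h"), and treating them as exact values is an assumption.

```python
# Upper bounds on hours per eval, taken from the measurements above.
# Treating "<3.5h" as exactly 3.5 h is an approximation.
hours_per_eval = {
    "no apex, batch=8": 3.5,
    "no apex, batch=10": 2.8,
    "apex, batch=16": 1.5,
}

baseline = hours_per_eval["apex, batch=16"]
speedups = {
    name: hours / baseline
    for name, hours in hours_per_eval.items()
    if name != "apex, batch=16"
}

for name, ratio in speedups.items():
    print(f"{name}: the apex run is about {ratio:.2f}x faster per eval")
```

Note that these ratios mix two effects: the mixed-precision kernels themselves, and the larger batch size that the lower memory use made possible.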

7. However, model quality dropped, and the following message kept appearing:

2019-11-28 06:44:23,244 INFO: eval
2019-11-28 06:49:59,922 INFO: Epoch:  4 fmax: 0.246901 cur_max_f: 0.183399
2019-11-28 06:49:59,923 INFO: Epoch:  4 Min_Val_Loss: 0.318247 Cur_Val_Loss: 0.427435
2019-11-28 06:49:59,930 INFO:   [0] Cur_fmax: 0.183399 Cur_bound: 0.139370
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 0.00048828125
2019-11-28 08:17:46,401 INFO: eval
2019-11-28 08:23:15,293 INFO: Epoch:  4 fmax: 0.246901 cur_max_f: 0.215748
2019-11-28 08:23:15,294 INFO: Epoch:  4 Min_Val_Loss: 0.318247 Cur_Val_Loss: 0.474844
2019-11-28 08:23:15,304 INFO:   [0] Cur_fmax: 0.215748 Cur_bound: 0.016865
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 0.000244140625

Before adding apex, the output of my program looked like this:

2019-11-28 00:03:36,134 INFO: eval
2019-11-28 00:06:15,011 INFO: Epoch:  5 fmax: 0.437465 cur_max_f: 0.437465
2019-11-28 00:06:15,012 INFO: Epoch:  5 Min_Val_Loss: 0.251108 Cur_Val_Loss: 0.251108
2019-11-28 00:06:15,019 INFO:   [0] Cur_fmax: 0.437465 Cur_bound: 0.231224

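The loss is larger overall and very unstable, and the results are worse. And what does this message mean?

Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 0.00048828125

It means the gradients overflowed. With dynamic loss scaling, Amp skips the optimizer step and halves the loss scale whenever it finds inf or NaN gradients; the two values logged above, 0.00048828125 and 0.000244140625, are 2^-11 and 2^-12. Many people have raised this in the GitHub issues. The maintainers appear to still be collecting cases that reproduce it, and it has not been resolved yet.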
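To make the message easier to read, here is a toy model of dynamic loss scaling. It is a simplified sketch, not apex's implementation: the class name DynamicLossScaler, the initial scale of 2**16, and the window of 2000 clean steps are all assumptions for illustration. The min_scale and max_scale arguments mirror the min_loss_scale and max_loss_scale parameters of amp.initialize described in section 3.

```python
import math


class DynamicLossScaler:
    """Toy model of dynamic loss scaling (hypothetical class, not apex code).

    Assumed policy: halve the scale and skip the step when any gradient is
    inf/NaN; double it after `scale_window` consecutive clean steps.
    """

    def __init__(self, init_scale=2.0 ** 16, scale_window=2000,
                 min_scale=None, max_scale=2.0 ** 24):
        self.scale = init_scale
        self.scale_window = scale_window
        self.min_scale = min_scale
        self.max_scale = max_scale
        self.clean_steps = 0

    def update(self, grads):
        """Return True if the optimizer step should run, False if skipped."""
        overflow = any(math.isinf(g) or math.isnan(g) for g in grads)
        if overflow:
            self.scale /= 2.0
            if self.min_scale is not None:
                self.scale = max(self.scale, self.min_scale)
            self.clean_steps = 0
            return False
        self.clean_steps += 1
        if self.clean_steps == self.scale_window:
            self.scale = min(self.scale * 2.0, self.max_scale)
            self.clean_steps = 0
        return True


scaler = DynamicLossScaler(init_scale=2.0 ** -10)
scaler.update([float("inf")])
print(scaler.scale)  # 0.00048828125
scaler.update([float("nan")])
print(scaler.scale)  # 0.000244140625
```

Under this policy, a scale that has fallen this far below 1 means overflow kept recurring step after step. That may point to inf or NaN values produced in the model itself rather than to a scale that is merely too large, though I have not confirmed this for my program.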


3. Other notes (all taken from the official documentation)

3.1. apex.amp

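Note ⚠️:

Currently, the under-the-hood properties that govern pure or mixed precision training are:

cast_model_type: Casts the model's parameters and buffers to the desired type.
patch_torch_functions: Patches all Torch functions and Tensor methods so that Tensor Core-friendly ops such as GEMMs and convolutions run in FP16, and any ops that benefit from FP32 precision run in FP32.
keep_batchnorm_fp32: To enhance precision and enable cudnn batchnorm (which improves performance), it is often beneficial to keep batchnorm weights in FP32 even if the rest of the model is FP16.
master_weights: Maintains FP32 master weights to accompany any FP16 model weights. The optimizer steps the FP32 master weights, which enhances precision and captures small gradients.
loss_scale: If loss_scale is a float value, it is used as the static (fixed) loss scale. If loss_scale is the string "dynamic", the loss scale is adjusted adaptively over time. Dynamic loss scale adjustments are performed by Amp automatically.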
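Why a loss scale helps can be shown with plain arithmetic. The sketch below uses the standard library's struct module (format code "e" is IEEE half precision) to round values to FP16; the gradient value and the scale of 1024 are made up for illustration.

```python
import struct


def to_fp16(x):
    """Round a Python float to the nearest IEEE half-precision value."""
    return struct.unpack("e", struct.pack("e", x))[0]


grad = 1e-8          # a small gradient value (made up for illustration)
loss_scale = 1024.0  # a static loss scale (also made up)

unscaled = to_fp16(grad)             # too small for FP16: underflows to 0.0
scaled = to_fp16(grad * loss_scale)  # survives as a small nonzero FP16 value
recovered = scaled / loss_scale      # unscale again in higher precision

print(unscaled, scaled, recovered)
```

Scaling the loss scales every gradient by the same factor, so values that would vanish in FP16 stay representable and can be divided back down before the weight update.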

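Again, you usually do not need to specify these properties by hand. Instead, select an opt_level, which sets them up for you. After selecting an opt_level, you can optionally pass property kwargs as manual overrides.

If you try to override a property in a way that does not make sense for the selected opt_level, Amp raises an error with an explanation. For example, selecting opt_level="O1" together with the override master_weights=True does not make sense. O1 inserts casts around Torch functions rather than casting model weights. Data, activations, and weights are recast out-of-place on the fly as they flow through the patched functions. The model weights themselves can (and should) therefore remain FP32, and there is no need to maintain separate FP32 master weights.

opt_levels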

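The recognized opt_levels are "O0", "O1", "O2", and "O3".

O0 and O3 are not true mixed precision, but they are useful for establishing accuracy and speed baselines, respectively.
O1 and O2 are different implementations of mixed precision. Try both and see which gives the best speedup and accuracy for your model.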

O0: FP32 training
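Your incoming model should already be FP32, so this is likely a no-op. O0 can be useful for establishing an accuracy baseline.

Default properties set by O0:
cast_model_type=torch.float32
patch_torch_functions=False
keep_batchnorm_fp32=None (effectively "not applicable", since everything is FP32)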
master_weights = False
loss_scale = 1.0

O1: Mixed Precision (recommended for typical use)

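Patches all Torch functions and Tensor methods to cast their inputs according to a whitelist-blacklist model. Whitelist ops (for example, Tensor Core-friendly ops such as GEMMs and convolutions) are performed in FP16. Blacklist ops that benefit from FP32 precision (for example, softmax) are performed in FP32. O1 also uses dynamic loss scaling, unless overridden.

Default properties set by O1: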
cast_model_type=None (not applicable)
patch_torch_functions=True
keep_batchnorm_fp32=None (again, not applicable, all model weights remain FP32)
master_weights=None (not applicable, model weights remain FP32)
loss_scale="dynamic"
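The whitelist-blacklist idea behind O1 can be sketched in a few lines. Everything here is illustrative: the two sets are small made-up subsets, the function name o1_compute_dtype is hypothetical, and the rule for ops on neither list (run in the widest input type) is a simplifying assumption rather than a statement about apex's internals.

```python
# Illustrative subsets only; apex's real lists are longer and live inside
# the library. All names below are hypothetical.
FP16_OPS = {"conv2d", "linear", "matmul"}     # "whitelist"
FP32_OPS = {"softmax", "log_softmax", "exp"}  # "blacklist"


def o1_compute_dtype(op_name, input_dtypes):
    """Pick the dtype an op would run in under an O1-style policy."""
    if op_name in FP16_OPS:
        return "float16"
    if op_name in FP32_OPS:
        return "float32"
    # Assumption: ops on neither list run in the widest input type.
    return "float32" if "float32" in input_dtypes else "float16"


print(o1_compute_dtype("linear", ["float32", "float32"]))  # float16
print(o1_compute_dtype("softmax", ["float16"]))            # float32
print(o1_compute_dtype("add", ["float16", "float32"]))     # float32
```

The point of the design is that the model's stored weights never change type; only the inputs to each patched function are cast as they pass through.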


O2: “Almost FP16” Mixed Precision

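O2 casts the model weights to FP16, patches the model's forward method to cast input data to FP16, keeps batchnorms in FP32, maintains FP32 master weights, updates the optimizer's param_groups so that optimizer.step() acts directly on the FP32 weights (followed by FP32 master weight -> FP16 model weight copies if necessary), and implements dynamic loss scaling (unless overridden). Unlike O1, O2 does not patch Torch functions or Tensor methods.

Default properties set by O2: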
cast_model_type=torch.float16
patch_torch_functions=False
keep_batchnorm_fp32=True
master_weights=True
loss_scale="dynamic"
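The reason master weights matter can be seen with a small numeric experiment. This sketch rounds to FP16 using the standard library's struct module; the update size of 1e-4 is made up, and a Python float (FP64) stands in for the FP32 master weight.

```python
import struct


def to_fp16(x):
    """Round a Python float to the nearest IEEE half-precision value."""
    return struct.unpack("e", struct.pack("e", x))[0]


update = 1e-4  # lr * grad for one step (made-up value)

# Updating an FP16 weight directly: each small update is rounded away,
# because FP16 values near 1.0 are spaced 2**-10 (about 0.001) apart.
w_fp16 = to_fp16(1.0)
for _ in range(10):
    w_fp16 = to_fp16(w_fp16 + update)

# Updating a higher-precision master weight, then copying it to FP16.
# A Python float (FP64) stands in for the FP32 master weight here.
w_master = 1.0
for _ in range(10):
    w_master += update
w_model = to_fp16(w_master)

print(w_fp16)   # 1.0
print(w_model)  # 1.0009765625
```

The FP16-only weight never moves, while the master weight accumulates the ten small updates and the FP16 copy eventually changes.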


O3: FP16 training

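O3 may not achieve the stability of the true mixed precision options O1 and O2. However, it can be useful for establishing a speed baseline for your model, against which the performance of O1 and O2 can be compared. If your model uses batch normalization, to establish "speed of light" you can try O3 with the additional property override keep_batchnorm_fp32=True (which enables cudnn batchnorm, as stated earlier).

Default properties set by O3: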
cast_model_type=torch.float16
patch_torch_functions=False
keep_batchnorm_fp32=False
master_weights=False
loss_scale=1.0
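For quick reference, the four sets of defaults can be collected into one table. This is a sketch for comparison only, not apex's code: dtypes are kept as strings so it runs without PyTorch, resolve_properties is a hypothetical helper, and the only incompatibility it checks is the one named in the text (master_weights=True with O1).

```python
# Defaults per opt_level, transcribed from the lists above. dtypes are kept
# as strings so the sketch runs without PyTorch installed.
OPT_LEVEL_DEFAULTS = {
    "O0": dict(cast_model_type="torch.float32", patch_torch_functions=False,
               keep_batchnorm_fp32=None, master_weights=False,
               loss_scale=1.0),
    "O1": dict(cast_model_type=None, patch_torch_functions=True,
               keep_batchnorm_fp32=None, master_weights=None,
               loss_scale="dynamic"),
    "O2": dict(cast_model_type="torch.float16", patch_torch_functions=False,
               keep_batchnorm_fp32=True, master_weights=True,
               loss_scale="dynamic"),
    "O3": dict(cast_model_type="torch.float16", patch_torch_functions=False,
               keep_batchnorm_fp32=False, master_weights=False,
               loss_scale=1.0),
}


def resolve_properties(opt_level, **overrides):
    """Merge user overrides into the defaults (hypothetical helper)."""
    if opt_level not in OPT_LEVEL_DEFAULTS:
        raise ValueError(f"unknown opt_level: {opt_level!r}")
    props = dict(OPT_LEVEL_DEFAULTS[opt_level])
    for key, value in overrides.items():
        if key not in props:
            raise ValueError(f"unknown property: {key!r}")
        # The one incompatible combination named in the text above.
        if opt_level == "O1" and key == "master_weights" and value:
            raise ValueError("master_weights=True does not make sense with O1")
        if value is not None:  # None means "keep the default"
            props[key] = value
    return props


print(resolve_properties("O3", keep_batchnorm_fp32=True))
```

The last line corresponds to the "speed of light" configuration mentioned for O3.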

Note ⚠️: amp.initialize should be called after you have finished constructing your model(s) and optimizer(s), but before you send your model through any DistributedDataParallel wrapper. Currently, amp.initialize should only be called once.

Parameters
models (torch.nn.Module or list of torch.nn.Modules) – Models to modify/cast.

optimizers (optional, torch.optim.Optimizer or list of torch.optim.Optimizers) – Optimizers to modify/cast. REQUIRED for training, optional for inference.

enabled (bool, optional, default=True) – If False, renders all Amp calls no-ops, so your script should run as if Amp were not present.

opt_level (str, optional, default="O1") – Pure or mixed precision optimization level. Accepted values are “O0”, “O1”, “O2”, and “O3”, explained in detail above.

cast_model_type (torch.dtype, optional, default=None) – Optional property override, see above.

patch_torch_functions (bool, optional, default=None) – Optional property override.

keep_batchnorm_fp32 (bool or str, optional, default=None) – Optional property override. If passed as a string, must be the string “True” or “False”.

master_weights (bool, optional, default=None) – Optional property override.

loss_scale (float or str, optional, default=None) – Optional property override. If passed as a string, must be a string representing a number, e.g., “128.0”, or the string “dynamic”.

cast_model_outputs (torch.dtype, optional, default=None) – Option to ensure that the outputs of your model(s) are always cast to a particular type regardless of opt_level.

num_losses (int, optional, default=1) – Option to tell Amp in advance how many losses/backward passes you plan to use. When used in conjunction with the loss_id argument to amp.scale_loss, enables Amp to use a different loss scale per loss/backward pass, which can improve stability. See “Multiple models/optimizers/losses” under Advanced Amp Usage for examples. If num_losses is left to 1, Amp will still support multiple losses/backward passes, but use a single global loss scale for all of them.

verbosity (int, default=1) – Set to 0 to suppress Amp-related output.

min_loss_scale (float, default=None) – Sets a floor for the loss scale values that can be chosen by dynamic loss scaling. The default value of None means that no floor is imposed. If dynamic loss scaling is not used, min_loss_scale is ignored.

max_loss_scale (float, default=2.**24) – Sets a ceiling for the loss scale values that can be chosen by dynamic loss scaling. If dynamic loss scaling is not used, max_loss_scale is ignored.

Returns
Model(s) and optimizer(s) modified according to the opt_level. If either the models or optimizers args were lists, the corresponding return value will also be a list.
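Since loss_scale and keep_batchnorm_fp32 may be passed as strings, a training script that reads them from the command line might want to validate them up front. The two helpers below are hypothetical and only encode the string rules stated in the parameter list above.

```python
def parse_loss_scale(text):
    """Turn a command-line string into a loss_scale value (hypothetical helper).

    "dynamic" is passed through unchanged; anything else must be a number.
    """
    if text == "dynamic":
        return text
    return float(text)  # raises ValueError for non-numeric strings


def parse_keep_batchnorm_fp32(text):
    """Accept only the strings "True" and "False" (hypothetical helper)."""
    if text not in ("True", "False"):
        raise ValueError('expected the string "True" or "False"')
    return text == "True"


print(parse_loss_scale("128.0"))          # 128.0
print(parse_loss_scale("dynamic"))        # dynamic
print(parse_keep_batchnorm_fp32("True"))  # True
```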

checkpoint

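To save and load Amp training properly, the library provides amp.state_dict(), which contains all loss_scalers and their corresponding unskipped steps, and amp.load_state_dict() to restore these attributes.
The documentation recommends restoring the model with the same opt_level, and calling the load_state_dict methods after amp.initialize.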

...
# Save checkpoint
checkpoint = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'amp': amp.state_dict()
}
torch.save(checkpoint, 'amp_checkpoint.pt')
...
amp.load_state_dict(checkpoint['amp'])
...
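The restore order recommended above (initialize first, then load the states) can be written as one function. This is a sketch: restore_amp_run is a hypothetical helper, and amp is passed in as an argument (normally obtained with from apex import amp) so the function can be exercised with stand-in objects.

```python
def restore_amp_run(model, optimizer, amp, checkpoint, opt_level):
    """Restore a checkpoint in the recommended order (hypothetical helper).

    `checkpoint` is the dict saved above, with the keys "model",
    "optimizer", and "amp".
    """
    # 1. Initialize Amp first, with the same opt_level used for training.
    model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level)
    # 2. Only then load the saved states.
    model.load_state_dict(checkpoint["model"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    amp.load_state_dict(checkpoint["amp"])
    return model, optimizer
```

In a real script the checkpoint would come from torch.load('amp_checkpoint.pt'), and the model and optimizer would be freshly constructed before the call.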

Advanced use cases:
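The unified Amp API supports gradient accumulation across iterations, multiple backward passes per iteration, multiple models/optimizers, custom/user-defined autograd functions, and custom data batch classes. Gradient clipping and GANs also require special treatment, but this treatment does not need to change for different opt_levels.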

Transition guide for old API users:

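The documentation strongly encourages moving to the new Amp API, because it is more versatile, easier to use, and future-proof. The original FP16_Optimizer and the old "Amp" API are deprecated and may be removed at any time.
Functions formerly exposed through amp_handle are now accessible through the amp module. Any existing calls to amp_handle = amp.init() should be deleted.
See the documentation for details; they are not repeated here.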

3.2.apex.optimizers

To be updated.