-
工具:PyTorch扩展,旨在Pytorch中提供便捷的混合精度和分布式训练功能
资源介绍
介绍
该存储库包含NVIDIA维护的实用程序,可简化Pytorch中的混合精度和分布式培训。 这里的某些代码最终将包含在上游Pytorch中。 Apex的目的是使用户尽快使用最新的实用程序。
完整的API文档: :
和幻灯片
内容
1.放大器:自动混合精度
apex.amp是通过仅更改脚本的3行来启用混合精度训练的工具。 用户可以通过提供不同的标志进行amp.initialize轻松地尝试不同的纯精度和混合精度训练模式。
(标志cast_batchnorm已重命名为keep_batchnorm_fp32 )。
(适用于已弃用的“ Amp”和“ FP16_Optimizer” API的用户)
2.分布式培训
apex.parallel.DistributedDataParallel是一个模块包装器,类似于torch.nn.parallel.DistributedDataParall