-
RotoGrad是官方Pytorch实现的工具
资源介绍
罗托格拉德
Pytorch中用于多任务学习的动态梯度均质化库
安装
安装此库就像在终端中运行一样简单
pip install rotograd
该代码已经在Pytorch 1.7.0中进行了测试,但是它应该可以在大多数版本上使用。 如果不是这种情况,请随时提出一个问题。
概述
这是RotoGrad的官方Pytorch实现,该算法可减少当多任务学习系统的不同任务争夺共享资源时由于与共享参数有关的梯度冲突而导致的负迁移。
假设您有一个硬参数共享体系结构,其中跨任务共享一个backbone模型,并且要解决两个不同的任务。 这些任务采取骨架的输出z = backbone(x)并将其馈送到一个任务特异性模型( head1和head2 ),以获得其任务的预测,即, y1 = head1(z)和y2 = head2(z) 。
然后,只需将所有零件放到一个模型中,就可以简单地使用RotoGrad或