-
在PyTorch中实现最佳传输算法的工具包PyTorchOT
资源介绍
PyTorchOT
在PyTorch中实现接收器优化传输算法。 目前,已实现了Sinkhorn算法的两个版本:和对。 这段代码实际上只是从PyTorch中很棒的POT库( )中重新实现了一些实现。
用法示例:
from ot_pytorch import sink
M = pairwise_distance_matrix()
dist = sink(M, reg=5, cuda=False)
设置cuda = True启用cuda使用。
examples.py文件包含两个基本示例。
范例1:
让Z I〜统一[0,1],并定义数据X i =(0,Z I)中,Y I =(θ,Z i)中,对于i = 1,...,N和一些参数θ,其在[-1,1]上变化。 真正的最佳传输距离是|θ|。 该算法产生: