Pytorch计算距离(例如欧式距离)torch.nn.PairwiseDistance
通常,我们计算欧式距离,例如[0,0]到[1,1]的距离为2\sqrt22。pdist = nn.PairwiseDistance(p=2)#p=2就是计算欧氏距离,p=1就是曼哈顿距离,例如上面的例子,距离是1.input1 = torch.randn(100, 128)input2 = torch.randn(100, 128)#上面两个形状要一样。output = pdist(input1
·
通常,我们计算欧式距离,例如[0,0]到[1,1]的距离为 2 \sqrt2 2。
pdist = nn.PairwiseDistance(p=2)#p=2就是计算欧氏距离,p=1就是曼哈顿距离,例如上面的例子,距离是1.
input1 = torch.randn(100, 128)
input2 = torch.randn(100, 128)#上面两个形状要一样。
output = pdist(input1,input2)#计算各自每一行之间的欧式距离。
output.shape
torch.Size([100])
注意:
- 两个输入形状要一样。因为行之间需要相互计算距离。
- 形状必须是[N,D],不能是[D],后者需要改成[1,D]否则报错。
更多推荐
已为社区贡献2条内容
所有评论(0)