torch.roll 函数官方解释
翻译
torch.roll(input, shifts, dims=None) → Tensor
- input (Tensor) —— 输入张量。
- shifts (python:int 或 tuple of python:int) —— 张量元素移位的位数。如果该参数是一个元组(例如shifts=(x,y)),dims必须是一个相同大小的元组(例如dims=(a,b)),相当于在第a维度移x位,在b维度移y位
- dims (int 或 tuple of python:int) 确定的维度。
沿给定维数滚动张量,移动到最后一个位置以外的元素将在第一个位置重新引入。如果没有指定尺寸,张量将在轧制前被压平,然后恢复到原始形状。
官方例子
>>> x = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]).view(4, 2) >>> x tensor([[1, 2], [3, 4], [5, 6], [7, 8]]) '''第0维度向下移1位,多出的[7,8]补充到顶部''' >>> torch.roll(x, 1, 0) tensor([[7, 8], [1, 2], [3, 4], [5, 6]]) '''第0维度向上移1位,多出的[1,2]补充到底部''' >>> torch.roll(x, -1, 0) tensor([[3, 4], [5, 6], [7, 8], [1, 2]]) '''tuple元祖,维度一一对应: 第0维度向下移2位,多出的[5,6][7,8]补充到顶部, 第1维向右移1位,多出的[6,8,2,4]补充到最左边''' >>> torch.roll(x, shifts=(2, 1), dims=(0, 1)) tensor([[6, 5], [8, 7], [2, 1], [4, 3]])
讯享网
简单理解:shifts的值为正数相当于向下挤牙膏,挤出的牙膏又从顶部塞回牙膏里面;shifts的值为负数相当于向上挤牙膏,挤出的牙膏又从底部塞回牙膏里面
以下一个多维张量的例子(参考swin transformer论文源码):
torch.roll(x, shifts=(-20, -20), dims=(1, 2))

完整代码
讯享网import torch import numpy as np import matplotlib.pyplot as plt shift_size = 3 '''构造多维张量''' x=np.arange().reshape(1,56,56,96) x=torch.from_numpy(x) if shift_size > 0: shifted_x = torch.roll(x, shifts=(-20, -20), dims=(1, 2)) #shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) print("---------经过循环位移了---------") else: shifted_x = x '''可视化部分''' plt.figure(figsize=(16,8)) plt.subplot(1,2,1) plt.imshow(x[0,:,:,0]) plt.title("orgin_img") plt.subplot(1,2,2) plt.imshow(shifted_x[0,:,:,0]) if torch.equal(shifted_x, x): plt.title("non_shifted") else: plt.title("shifted_img") plt.show() plt.pause(5) plt.close()

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容,请联系我们,一经查实,本站将立刻删除。
如需转载请保留出处:https://51itzy.com/kjqy/14972.html