相关阅读
Pytorch基础![]()
讯享网https://blog.csdn.net/weixin_45791458/category_12457644.html?spm=1001.2014.3001.5482
在Pytorch中,reshape是Tensor的一个重要方法,它与Numpy中的reshape类似,用于返回一个改变了形状但数据和数据顺序和原来一致的新Tensor对象。注意:此时返回的新tensor中的数据对象并不一定是新的,这取决于应用此方法的Tensor是否是连续的。
reshape方法的语法如下所示:
Tensor.reshape(*shape) → Tensor shape (tuple of ints or int...) - the desired shape
讯享网
reshape的用法如下所示:

讯享网import torch # 创建一个张量 x = torch.randn(3, 4) tensor([[ 0.1961, -0.9038, 0.9196, -1.1851], [ 1.1321, 0.3153, 0.3485, 0.7977], [-0.5279, 0.2062, -0.4224, -0.3993]]) # 使用reshape方法将其重新塑造为2行6列的形状 y = x.reshape(2, 6) y = x.reshape((2,6)) #两种形式均可,y = x.reshape([2,6])也可 tensor([[ 0.1961, -0.9038, 0.9196, -1.1851, 1.1321, 0.3153], [ 0.3485, 0.7977, -0.5279, 0.2062, -0.4224, -0.3993]])
可以看到,给出的参数既可以是多个整数(其中每个整数代表一个维度的大小,而整数的数量代表维度的数量),也可以是一个元组或是列表(其中每个元素代表一个维度的大小,而元素数量代表维度的数量)。而且reshape不改变Tensor中数据的排列顺序(指的是从上到下从左到右遍历的顺序),只改变形状,这也就对reshape各维度大小的乘积有要求,要与原Tensor一致。在上例中即3*4=2*6。
另外reshape还有一个trick,即某一维的实参可以是-1,此时会自动根据原Tensor大小和给出的其他维度参数的大小,推断出这一维度的大小,举例如下:
import torch # 创建一个张量 x = torch.randn(3, 4) tensor([[ 0.1961, -0.9038, 0.9196, -1.1851], [ 1.1321, 0.3153, 0.3485, 0.7977], [-0.5279, 0.2062, -0.4224, -0.3993]]) # 使用reshape方法将其重新塑造为6行n列的形状,n为自动推断出的值 y = x.reshape(6, -1) tensor([[ 0.1961, -0.9038], [ 0.9196, -1.1851], [ 1.1321, 0.3153], [ 0.3485, 0.7977], [-0.5279, 0.2062], [-0.4224, -0.3993]]) # 使用reshape方法将其重新塑造为(2,2,n)的形状,n为自动推断出的值 y = x.reshape(2, 2, -1) tensor([[[ 0.1961, -0.9038, 0.9196], [-1.1851, 1.1321, 0.3153]], [[ 0.3485, 0.7977, -0.5279], [ 0.2062, -0.4224, -0.3993]]]) # 不能在两个维度都指定-1,这时无法推断出唯一结果 y = x.reshape(2, -1, -1) Traceback (most recent call last): File "<stdin>", line 1, in <module> RuntimeError: only one dimension can be inferred
除此之外,还可以使用torch.reshape()函数,这与使用reshape方式效果一致,torch.reshape()的语法如下所示。
讯享网torch.reshape(input, shape) → Tensor input (Tensor) – the tensor to be reshaped shape (tuple of python:int) – the new shape import torch # 创建一个张量 x = torch.randn(3, 4) tensor([[ 0.1961, -0.9038, 0.9196, -1.1851], [ 1.1321, 0.3153, 0.3485, 0.7977], [-0.5279, 0.2062, -0.4224, -0.3993]]) # 使用reshape函数将其重新塑造为6行n列的形状,n为自动推断出的值 y = torch.reshape(x, (6, -1)) tensor([[ 0.1961, -0.9038], [ 0.9196, -1.1851], [ 1.1321, 0.3153], [ 0.3485, 0.7977], [-0.5279, 0.2062], [-0.4224, -0.3993]])
版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容,请联系我们,一经查实,本站将立刻删除。
如需转载请保留出处:https://51itzy.com/kjqy/33856.html