参考博文:https://blog.csdn.net/SHU15121856/article/details/87810372
从前往后提供的索引,会依次在每个维度上做索引。
import torch a = torch.rand(4, 3, 28, 28) print(a[0].shape) print(a[0, 0].shape) print(a[0, 0, 2, 4]) # 具体到某个元素运行结果:
torch.Size([3, 28, 28]) torch.Size([28, 28]) tensor(0.3690)start :end,取的索引从start到end-1 注意负值的索引即表示倒数第几个元素,-2就是倒数第二个元素。
print(a[:2].shape) # 在第一个维度上取后0和1 print(a[:2, :1, :, :].shape) # 在第一个维度上取0和1,在第二个维度上取0 print(a[:2, 1:, :, :].shape) # 在第一个维度上取0和1,在第二个维度上取1,2 print(a[:2, -2:, :, :].shape) # 在第一个维度上取0和1,在第二个维度上取1,2运行结果:
torch.Size([2, 3, 28, 28]) torch.Size([2, 1, 28, 28]) torch.Size([2, 2, 28, 28]) torch.Size([2, 2, 28, 28])运行结果:
torch.Size([4, 3, 14, 14]) torch.Size([4, 3, 14, 14])选择特定下标有时候很有用,比如上面的a这个Tensor可以看作4张RGB(3通道)的MNIST图像,长宽都是28px。那么在第一维度上可以选择特定的图片,在第二维度上选择特定的通道。
# 选择第一张和第三张图 print(a.index_select(0, torch.tensor([0, 2])).shape) # 选择R通道和B通道 print(a.index_select(1, torch.tensor([0, 2])).shape)运行结果:
torch.Size([2, 3, 28, 28]) torch.Size([4, 2, 28, 28])在索引中使用...可以表示任意多的维度。
print(a[...].shape) print(a[:, 1, ...].shape) print(a[..., :2].shape) print(a[0, ..., ::2].shape)运行结果:
torch.Size([4, 3, 28, 28]) torch.Size([4, 28, 28]) torch.Size([4, 3, 28, 2]) torch.Size([3, 28, 14])可以获取满足一些条件的值的位置索引,然后用这个索引去取出这些位置的元素。
import torch # 取出a这个Tensor中大于0.5的元素 a = torch.randn(3, 4) print(a) x = a.ge(0.5) print(x) print(a[x])运行结果:
tensor([[ 0.1638, 0.9582, -0.2464, -0.8064], [ 1.8385, -2.0180, 0.8382, 1.0563], [ 0.1587, -1.6653, -0.2057, 0.1316]]) tensor([[0, 1, 0, 0], [1, 0, 1, 1], [0, 0, 0, 0]], dtype=torch.uint8) tensor([0.9582, 1.8385, 0.8382, 1.0563])take索引是基于目标Tensor的flatten形式下的,即摊平后的Tensor的索引。
import torch a = torch.tensor([[3, 7, 2], [2, 8, 3]]) print(a) print(torch.take(a, torch.tensor([0, 1, 5])))运行结果:
tensor([[3, 7, 2], [2, 8, 3]]) tensor([3, 7, 3])