在使用numpy多维数组时我们常会需要获取数组中的元素,这一般有两种方法:
import numpy as np a = np.random.randint(10, 20, size=[10, 20]) print(a) print(a[2, 2]) print(a[2][2]) ''' [[17 12 16 14 19 10 14 13 15 13 17 19 19 11 11 18 16 12 16 17] [15 17 15 11 19 16 19 18 12 12 19 19 15 19 18 11 18 12 10 13] [10 12 13 12 14 11 12 12 10 18 14 16 16 16 14 13 11 15 11 15] [10 17 19 11 16 17 15 11 14 12 17 14 15 17 12 17 16 15 10 14] [13 14 13 16 14 18 14 16 16 10 11 13 14 14 12 11 18 12 14 13] [17 19 14 15 19 12 10 17 14 13 19 11 17 13 17 10 19 14 18 11] [17 11 13 18 14 17 14 18 11 18 18 14 16 19 18 18 15 18 15 12] [19 17 10 13 14 12 19 16 10 18 11 11 12 18 16 15 15 13 15 19] [18 12 11 15 11 13 13 18 15 19 19 13 16 13 19 15 10 12 10 15] [14 19 12 10 11 10 14 19 12 10 19 12 18 15 18 17 19 12 18 14]] 13 13 '''但当我们需要用到切片时,第二种写法却是错误的:
import numpy as np a = np.random.randint(10, 20, size=[10, 20]) print(a) print(a[:2, :2]) print(a[:2][:2]) ''' [[16 16 12 11 11 16 13 13 18 17 12 10 15 12 19 12 19 18 11 17] [16 13 15 12 14 18 18 19 15 16 10 17 19 15 15 18 14 17 17 18] [12 10 10 12 13 18 10 13 14 13 13 19 10 16 13 19 13 19 13 17] [17 19 11 10 11 17 16 10 18 15 10 18 12 15 17 11 16 13 12 11] [17 19 11 13 12 15 14 16 12 12 14 11 15 13 19 19 17 14 16 19] [10 15 18 19 10 15 12 13 11 18 19 11 14 15 14 17 15 10 13 16] [10 16 17 18 19 14 15 10 14 11 11 16 18 15 16 12 10 11 16 18] [13 12 13 13 16 16 17 16 15 15 15 16 14 16 15 16 19 14 19 14] [16 13 17 16 10 15 19 15 19 13 19 12 16 11 14 17 18 19 15 15] [14 10 19 14 10 11 16 14 10 16 18 12 10 14 12 10 12 14 10 15]] [[16 16] [16 13]] [[16 16 12 11 11 16 13 13 18 17 12 10 15 12 19 12 19 18 11 17] [16 13 15 12 14 18 18 19 15 16 10 17 19 15 15 18 14 17 17 18]] '''第一种写法获取到了我们实际想要的子矩阵,而第二种写法实际上需要分开来看待:先获取a的前两行得到一个子矩阵,再获取这个子矩阵的前两行。 最近写代码时总弄混这两个写法,因此记录一下,numpy切片的正确用法是用逗号隔开,而不是像多维数组索引那样隔开。
今天又发现了新的问题,numpy真是有趣。在做cs231n的作业时,我需要从一个 N ∗ C N * C N∗C的分数中按照一个 N ∗ 1 N * 1 N∗1的label来取出 N ∗ 1 N * 1 N∗1的正确分数(每行按照label选一个分数),自然会想到花式索引和切片结合的方法,但遇到了一些问题,这里总结一下可能的写法:
a = np.random.randint(5, 10, size=(5, 10)) print(a) y1 = np.random.randint(0, 10, size=(5, )) print(y1) y2 = np.random.randint(0, 10, size=(5, 1)) print(y2) ''' [[7 6 5 7 9 5 9 6 8 8] [8 8 8 7 9 9 8 5 9 8] [8 9 8 7 5 7 7 6 8 8] [6 9 5 9 5 7 7 8 7 7] [9 6 6 9 9 6 7 9 7 6]] [8 1 0 9 8] [[5] [3] [5] [9] [5]] '''可以看到y1和y2的shape是不一样的,y1是一个数组,y2则是一个二维矩阵。
print(a[:, y1]) ''' [[8 6 7 8 8] [9 8 8 8 9] [8 9 8 8 8] [7 9 6 7 7] [7 6 9 6 7]] 可以看到,这种写法得到一个N * N的矩阵,每一行对应a的每一行按照y1的所有元素来取值, 即本来a的每一行取一个值就可以,但是却取了N个值,每一行相当于a[i, y1] a[0, y1] = [8 6 7 8 8] ''' print(a[range(5), y1]) ''' [8 8 8 7 7] 这种写法就是我们想要的结果 ''' print(a[:, y2]) ''' [[[5] [7] [5] [8] [5]] [[9] [7] [9] [8] [9]] [[7] [7] [7] [8] [7]] [[7] [9] [7] [7] [7]] [[6] [9] [6] [6] [6]]] 这种写法得到的结果更加离谱,是一个5 * 5 * 1的三维矩阵, 每个5 * 1的子矩阵相当于第一种写法的结果 ''' print(a[range(5), y2]) ''' [[5 9 7 7 6] [7 7 7 9 9] [5 9 7 7 6] [8 8 8 7 6] [5 9 7 7 6]] 一共5行,每一行都是a的某一列,按照y2来取值。 '''结论就是要使用range和一维数组来进行切片和花式索引。
切片中省略号…的作用 有时候我们会看到这样的索引写法a[..., 1:],其中的...是一种特殊写法,常适用于高维数组的切片。比如,当a是5维数组时,a[:, :, :, :, 1:] = a[..., 1:],即...是所有完整切片的缩短,这种写法更简洁。