MXNet的基本数据操作

mac2025-05-03  5

在MXNet中,NDArray是一个类,也是存储和变换数据的主要工具。

创建NDArray

使用arrange()函数创建一个行向量,返回一个NDArray实例,其中包含一个指定长度的一维数组。使用zeros()和ones()创建指定元素值的NDArray。通过Python的列表(list)指定需要创建的NDArray中每个元素的值。随机生成NDArray中每个元素的值,例如nd.random.normal()。

NDArray属性

shape:获取NDArray实例的形状。size:获取NDArray实例中元素的总数。reshape()函数改变NDArray的形状。

运算

按元素的+、-、*、/、exp()。矩阵乘法dot()。连结(concatenate)concat()。条件判别式得到元素为0或1的心的NDArray。asscalar()函数将结果变为Python中的标量。

广播机制

对两个形状不同的NDArray按元素运算是,会触发广播(broadcasting)机制。

先适当复制元素使这两个NDArray形状相同后再按元素运算。

内存开销

Y = X + Y操作中,需要新开内存存储Y和X+Y。

Z = Y.zeros_like() Z[:] = X + Y

通过zeros_like()创建和Y形状相同且元素为0的NDArray,记为Z。使用[:]将X+Y的结果写进Z对应的内存中。但是临时开辟内存存储X+Y。可以使用运算符全名函数中的out参数。

nd.elewise_add(X, Y, out=Z)

自动求梯度

a = nd.random.normal(shape=1) a.attach_grad() with autograd.record(): c = f(a) c.backward() attach_grad():申请存储梯度所需要的内存。record():记录与求梯度相关的计算。backward():自动求导。
最新回复(0)