- 3.3. 线性回归的简洁实现
- 3.3.1. 生成数据集
- 3.3.2. 读取数据
- 3.3.3. 定义模型
- 3.3.4. 初始化模型参数
- 3.3.5. 定义损失函数
- 3.3.6. 定义优化算法
- 3.3.7. 训练模型
- 3.3.8. 小结
- 3.3.9. 练习
3.3. 线性回归的简洁实现
随着深度学习框架的发展,开发深度学习应用变得越来越便利。实践中,我们通常可以用比上一节更简洁的代码来实现同样的模型。在本节中,我们将介绍如何使用MXNet提供的Gluon接口更方便地实现线性回归的训练。
3.3.1. 生成数据集
我们生成与上一节中相同的数据集。其中features
是训练数据特征,labels
是标签。
- In [1]:
- from mxnet import autograd, nd
- num_inputs = 2
- num_examples = 1000
- true_w = [2, -3.4]
- true_b = 4.2
- features = nd.random.normal(scale=1, shape=(num_examples, num_inputs))
- labels = true_w[0] * features[:, 0] + true_w[1] * features[:, 1] + true_b
- labels += nd.random.normal(scale=0.01, shape=labels.shape)
3.3.2. 读取数据
Gluon提供了data
包来读取数据。由于data
常用作变量名,我们将导入的data
模块用添加了Gluon首字母的假名gdata
代替。在每一次迭代中,我们将随机读取包含10个数据样本的小批量。
- In [2]:
- from mxnet.gluon import data as gdata
- batch_size = 10
- # 将训练数据的特征和标签组合
- dataset = gdata.ArrayDataset(features, labels)
- # 随机读取小批量
- data_iter = gdata.DataLoader(dataset, batch_size, shuffle=True)
这里data_iter
的使用跟上一节中的一样。让我们读取并打印第一个小批量数据样本。
- In [3]:
- for X, y in data_iter:
- print(X, y)
- break
- [[ 0.80084276 1.6435156 ]
- [-0.64220035 0.77557075]
- [-0.5541078 -0.02885367]
- [-2.338083 -0.32677847]
- [-1.3031056 0.64198005]
- [ 0.31445166 -0.3114319 ]
- [ 0.43465078 0.02699922]
- [ 1.1807663 1.4896834 ]
- [ 2.0085082 -1.3542701 ]
- [-0.52480155 0.3005414 ]]
- <NDArray 10x2 @cpu(0)>
- [ 0.1861284 0.2710672 3.1786947 0.6198239 -0.57668865 5.881366
- 4.966181 1.5095915 12.808897 2.1383817 ]
- <NDArray 10 @cpu(0)>
3.3.3. 定义模型
在上一节从零开始的实现中,我们需要定义模型参数,并使用它们一步步描述模型是怎样计算的。当模型结构变得更复杂时,这些步骤将变得更繁琐。其实,Gluon提供了大量预定义的层,这使我们只需关注使用哪些层来构造模型。下面将介绍如何使用Gluon更简洁地定义线性回归。
首先,导入nn
模块。实际上,“nn”是neuralnetworks(神经网络)的缩写。顾名思义,该模块定义了大量神经网络的层。我们先定义一个模型变量net
,它是一个Sequential
实例。在Gluon中,Sequential
实例可以看作是一个串联各个层的容器。在构造模型时,我们在该容器中依次添加层。当给定输入数据时,容器中的每一层将依次计算并将输出作为下一层的输入。
- In [4]:
- from mxnet.gluon import nn
- net = nn.Sequential()
回顾图3.1中线性回归在神经网络图中的表示。作为一个单层神经网络,线性回归输出层中的神经元和输入层中各个输入完全连接。因此,线性回归的输出层又叫全连接层。在Gluon中,全连接层是一个Dense
实例。我们定义该层输出个数为1。
- In [5]:
- net.add(nn.Dense(1))
值得一提的是,在Gluon中我们无须指定每一层输入的形状,例如线性回归的输入个数。当模型得到数据时,例如后面执行net(X)
时,模型将自动推断出每一层的输入个数。我们将在之后“深度学习计算”一章详细介绍这种机制。Gluon的这一设计为模型开发带来便利。
3.3.4. 初始化模型参数
在使用net
前,我们需要初始化模型参数,如线性回归模型中的权重和偏差。我们从MXNet导入init
模块。该模块提供了模型参数初始化的各种方法。这里的init
是initializer
的缩写形式。我们通过init.Normal(sigma=0.01)
指定权重参数每个元素将在初始化时随机采样于均值为0、标准差为0.01的正态分布。偏差参数默认会初始化为零。
- In [6]:
- from mxnet import init
- net.initialize(init.Normal(sigma=0.01))
3.3.5. 定义损失函数
在Gluon中,loss
模块定义了各种损失函数。我们用假名gloss
代替导入的loss
模块,并直接使用它提供的平方损失作为模型的损失函数。
- In [7]:
- from mxnet.gluon import loss as gloss
- loss = gloss.L2Loss() # 平方损失又称L2范数损失
3.3.6. 定义优化算法
同样,我们也无须实现小批量随机梯度下降。在导入Gluon后,我们创建一个Trainer
实例,并指定学习率为0.03的小批量随机梯度下降(sgd
)为优化算法。该优化算法将用来迭代net
实例所有通过add
函数嵌套的层所包含的全部参数。这些参数可以通过collect_params
函数获取。
- In [8]:
- from mxnet import gluon
- trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.03})
3.3.7. 训练模型
在使用Gluon训练模型时,我们通过调用Trainer
实例的step
函数来迭代模型参数。上一节中我们提到,由于变量l
是长度为batch_size
的一维NDArray
,执行l.backward()
等价于执行l.sum().backward()
。按照小批量随机梯度下降的定义,我们在step
函数中指明批量大小,从而对批量中样本梯度求平均。
- In [9]:
- num_epochs = 3
- for epoch in range(1, num_epochs + 1):
- for X, y in data_iter:
- with autograd.record():
- l = loss(net(X), y)
- l.backward()
- trainer.step(batch_size)
- l = loss(net(features), labels)
- print('epoch %d, loss: %f' % (epoch, l.mean().asnumpy()))
- epoch 1, loss: 0.040774
- epoch 2, loss: 0.000153
- epoch 3, loss: 0.000051
下面我们分别比较学到的模型参数和真实的模型参数。我们从net
获得需要的层,并访问其权重(weight
)和偏差(bias
)。学到的参数和真实的参数很接近。
- In [10]:
- dense = net[0]
- true_w, dense.weight.data()
- Out[10]:
- ([2, -3.4],
- [[ 1.9993649 -3.399639 ]]
- <NDArray 1x2 @cpu(0)>)
- In [11]:
- true_b, dense.bias.data()
- Out[11]:
- (4.2,
- [4.199203]
- <NDArray 1 @cpu(0)>)
3.3.8. 小结
- 使用Gluon可以更简洁地实现模型。
- 在Gluon中,
data
模块提供了有关数据处理的工具,nn
模块定义了大量神经网络的层,loss
模块定义了各种损失函数。 - MXNet的
initializer
模块提供了模型参数初始化的各种方法。
3.3.9. 练习
- 如果将
l = loss(net(X), y)
替换成l = loss(net(X), y).mean()
,我们需要将trainer.step(batch_size)
相应地改成trainer.step(1)
。这是为什么呢? - 查阅MXNet文档,看看
gluon.loss
和init
模块里提供了哪些损失函数和初始化方法。 - 如何访问
dense.weight
的梯度?