• 自动求导机制

    自动求导机制

    在机器学习中,我们经常需要计算函数的导数。TensorFlow提供了强大的 自动求导机制 来计算导数。以下代码展示了如何使用 tf.GradientTape() 计算函数 y(x) = x^2x = 3 时的导数:

    1. import tensorflow as tf
    2.  
    3. x = tf.Variable(initial_value=3.)
    4. with tf.GradientTape() as tape: # 在 tf.GradientTape() 的上下文内,所有计算步骤都会被记录以用于求导
    5. y = tf.square(x)
    6. y_grad = tape.gradient(y, x) # 计算y关于x的导数
    7. print([y, y_grad])

    输出:

    1. [array([9.], dtype=float32), array([6.], dtype=float32)]

    这里 x 是一个初始化为3的 变量 (Variable),使用 tf.Variable() 声明。与普通张量一样,变量同样具有形状、类型和值三种属性。使用变量需要有一个初始化过程,可以通过在 tf.Variable() 中指定 initial_value 参数来指定初始值。这里将变量 x 初始化为 3. 1。变量与普通张量的一个重要区别是其默认能够被TensorFlow的自动求导机制所求导,因此往往被用于定义机器学习模型的参数。

    tf.GradientTape() 是一个自动求导的记录器,在其中的变量和计算步骤都会被自动记录。在上面的示例中,变量 x 和计算步骤 y = tf.square(x) 被自动记录,因此可以通过 y_grad = tape.gradient(y, x) 求张量 y 对变量 x 的导数。

    在机器学习中,更加常见的是对多元函数求偏导数,以及对向量或矩阵的求导。这些对于TensorFlow也不在话下。以下代码展示了如何使用 tf.GradientTape() 计算函数 L(w, b) = \|Xw + b - y\|^2w = (1, 2)^T, b = 1 时分别对 w, b 的偏导数。其中 X = \begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix}, y = \begin{bmatrix} 1 \\ 2\end{bmatrix}

    1. X = tf.constant([[1., 2.], [3., 4.]])
    2. y = tf.constant([[1.], [2.]])
    3. w = tf.Variable(initial_value=[[1.], [2.]])
    4. b = tf.Variable(initial_value=1.)
    5. with tf.GradientTape() as tape:
    6. L = 0.5 * tf.reduce_sum(tf.square(tf.matmul(X, w) + b - y))
    7. w_grad, b_grad = tape.gradient(L, [w, b]) # 计算L(w, b)关于w, b的偏导数
    8. print([L.numpy(), w_grad.numpy(), b_grad.numpy()])

    输出:

    1. [62.5, array([[35.],
    2. [50.]], dtype=float32), array([15.], dtype=float32)]

    这里, tf.square() 操作代表对输入张量的每一个元素求平方,不改变张量形状。 tf.reduce_sum() 操作代表对输入张量的所有元素求和,输出一个形状为空的纯量张量(可以通过 axis 参数来指定求和的维度,不指定则默认对所有元素求和)。TensorFlow中有大量的张量操作API,包括数学运算、张量形状操作(如 tf.reshape())、切片和连接(如 tf.concat())等多种类型,可以通过查阅TensorFlow的官方API文档 2 来进一步了解。

    从输出可见,TensorFlow帮助我们计算出了

    L((1, 2)^T, 1) &= 62.5\frac{\partial L(w, b)}{\partial w} |_{w = (1, 2)^T, b = 1} &= \begin{bmatrix} 35 \\ 50\end{bmatrix}\frac{\partial L(w, b)}{\partial b} |_{w = (1, 2)^T, b = 1} &= 15