• TensorFlow.js 模型训练 *

    TensorFlow.js 模型训练 *

    与 TensorFlow Serving 和 TensorFlow Lite 不同,TensorFlow.js 不仅支持模型的部署和推断,还支持直接在 TensorFlow.js 中进行模型训练、

    在 TensorFlow 基础章节中,我们已经用 Python 实现过,针对某城市在2013年-2017年的房价的任务,通过对该数据进行线性回归,即使用线性模型 y = ax + b 来拟合上述数据,此处 ab 是待求的参数。

    下面我们改用 TensorFlow.js 来实现一个 JavaScript 版本。

    首先,我们定义数据,进行基本的归一化操作。

    1. import * as tf from '@tensorflow/tfjs'
    2.  
    3. const xsRaw = tf.tensor([2013, 2014, 2015, 2016, 2017])
    4. const ysRaw = tf.tensor([12000, 14000, 15000, 16500, 17500])
    5.  
    6. // 归一化
    7. const xs = xsRaw.sub(xsRaw.min())
    8. .div(xsRaw.max().sub(xsRaw.min()))
    9. const ys = ysRaw.sub(ysRaw.min())
    10. .div(ysRaw.max().sub(ysRaw.min()))

    接下来,我们来求线性模型中两个参数 ab 的值。

    使用 loss() 计算损失;使用 optimizer.minimize() 自动更新模型参数。

    1. const a = tf.scalar(Math.random()).variable()
    2. const b = tf.scalar(Math.random()).variable()
    3.  
    4. // y = a * x + b.
    5. const f = (x: tf.Tensor) => a.mul(x).add(b)
    6. const loss = (pred: tf.Tensor, label: tf.Tensor) => pred.sub(label).square().mean() as tf.Scalar
    7.  
    8. const learningRate = 1e-3
    9. const optimizer = tf.train.sgd(learningRate)
    10.  
    11. // 训练模型
    12. for (let i = 0; i < 10000; i++) {
    13. optimizer.minimize(() => loss(f(xs), ys))
    14. }
    15.  
    16. // 预测
    17. console.log(`a: ${a.dataSync()}, b: ${b.dataSync()}`)
    18. const preds = f(xs).dataSync() as Float32Array
    19. const trues = ys.arraySync() as number[]
    20. preds.forEach((pred, i) => {
    21. console.log(`x: ${i}, pred: ${pred.toFixed(2)}, true: ${trues[i].toFixed(2)}`)
    22. })

    从下面的输出样例中我们可以看到,已经拟合的比较接近了。

    1. a: 0.9339302778244019, b: 0.08108722418546677
    2. x: 0, pred: 0.08, true: 0.00
    3. x: 1, pred: 0.31, true: 0.36
    4. x: 2, pred: 0.55, true: 0.55
    5. x: 3, pred: 0.78, true: 0.82
    6. x: 4, pred: 1.02, true: 1.00

    可以直接在浏览器中运行,完整的 HTML 代码如下:

    1. <html>
    2. <head>
    3. <script src="http://unpkg.com/@tensorflow/tfjs/dist/tf.min.js"></script>
    4. <script>
    5. const xsRaw = tf.tensor([2013, 2014, 2015, 2016, 2017])
    6. const ysRaw = tf.tensor([12000, 14000, 15000, 16500, 17500])
    7.  
    8. // 归一化
    9. const xs = xsRaw.sub(xsRaw.min())
    10. .div(xsRaw.max().sub(xsRaw.min()))
    11. const ys = ysRaw.sub(ysRaw.min())
    12. .div(ysRaw.max().sub(ysRaw.min()))
    13. const a = tf.scalar(Math.random()).variable()
    14. const b = tf.scalar(Math.random()).variable()
    15.  
    16. // y = a * x + b.
    17. const f = (x) => a.mul(x).add(b)
    18. const loss = (pred, label) => pred.sub(label).square().mean()
    19.  
    20. const learningRate = 1e-3
    21. const optimizer = tf.train.sgd(learningRate)
    22.  
    23. // 训练模型
    24. for (let i = 0; i < 10000; i++) {
    25. optimizer.minimize(() => loss(f(xs), ys))
    26. }
    27.  
    28. // 预测
    29. console.log(`a: ${a.dataSync()}, b: ${b.dataSync()}`)
    30. const preds = f(xs).dataSync()
    31. const trues = ys.arraySync()
    32. preds.forEach((pred, i) => {
    33. console.log(`x: ${i}, pred: ${pred.toFixed(2)}, true: ${trues[i].toFixed(2)}`)
    34. })
    35. </script>
    36. </head>
    37. </html>