• TensorFlow.js 模型部署
    • 通过 TensorFlow.js 加载 Python 模型
    • 使用 TensorFlow.js 模型库

    TensorFlow.js 模型部署

    通过 TensorFlow.js 加载 Python 模型

    一般Tensorflow的模型,以Python版本为例,会被存储为以下四种格式之一:

    • TensorFlow SavedModel

    • Tensorflow Hub Module

    • Keras Module

    Google 目前最佳实践中,推荐使用 SavedModel 方法进行模型保存。同时所有以上格式,都可以通过 tensorflowjs-converter 转换器,将其转换为可以直接被 TensorFlow.js 加载的格式,在JavaScript语言中进行使用。

    tensorflowjs_converter 可以将Python存储的模型格式,转换为JavaScript可以直接调用的模型格式。

    安装 tensorflowjs_converter

    1. $ pip install tensorflowjs

    tensorflowjs_converter 的使用细节,可以通过 —help 参数查看程序帮助:

    1. $ tensorflowjs_converter --help

    以下我们以MobilenetV1为例,看一下如何对模型文件进行转换操作,并将可以被TensorFlow.js加载的模型文件,存放到 /mobilenet/tfjs_model 目录下。

    转换 SavedModel:将 /mobilenet/saved_model 转换到 /mobilenet/tfjs_model

    1. tensorflowjs_converter \
    2. --input_format=tf_saved_model \
    3. --output_node_names='MobilenetV1/Predictions/Reshape_1' \
    4. --saved_model_tags=serve \
    5. /mobilenet/saved_model \
    6. /mobilenet/tfjs_model

    为了加载转换完成的模型文件,我们需要安装 tfjs-converter@tensorflow/tfjs 模块:

    1. $ npm install @tensorflow/tfjs

    然后,我们就可以通过JavaScript来加载Tensorflow模型了!

    1. import * as tf from '@tensorflow/tfjs';
    2.  
    3. const MODEL_URL = 'model_directory/model.json';
    4.  
    5. const model = await tf.loadGraphModel(MODEL_URL);
    6. // 对Keras或者tfjs原生的层模型,使用下面的加载函数:
    7. // const model = await tf.loadLayersModel(MODEL_URL);
    8.  
    9. const cat = document.getElementById('cat');
    10. model.execute(tf.browser.fromPixels(cat))

    使用 TensorFlow.js 模型库

    TensorFlow.js 提供了一系列预训练好的模型,方便大家快速的给自己的程序引入人工智能能力。

    模型库 GitHub 地址:<https://github.com/tensorflow/tfjs-models>,其中模型分类包括图像识别、语音识别、人体姿态识别、物体识别、文字分类等。

    由于这些API默认模型文件都存储在谷歌云上,直接使用会导致中国用户无法直接读取。在程序内使用模型API时要提供 modelUrl 的参数,可以指向谷歌中国的镜像服务器。

    谷歌云的base url是 https://storage.googleapis.com, 中国镜像的base url是 https://www.gstaticcnapps.cn 模型的url path是一致的。以 posenet模型为例:

    • 谷歌云地址是:https://storage.googleapis.com/tfjs-models/savedmodel/posenet/mobilenet/float/050/model-stride16.json

    • 中国镜像地址是:https://www.gstaticcnapps.cn/tfjs-models/savedmodel/posenet/mobilenet/float/050/model-stride16.json