• Getting Started with the TensorFlow Integration
    • Building the MLeap-TensorFlow Module
    • Using MLeap-TensorFlow

    Getting Started with the TensorFlow Integration

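    The MLeap TensorFlow integration lets you add a TensorFlow graph to an ML pipeline as a Transformer. We plan to improve TensorFlow compatibility in future releases. For now, the TensorFlow integration is an experimental feature, and we are still working on stabilizing it.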

    Building the MLeap-TensorFlow Module

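    The MLeap TensorFlow module is not hosted on Maven Central. You have to build it from source yourself, using the JNI (Java Native Interface) bindings that TensorFlow provides. Follow the relevant tutorial to build the TensorFlow module from source.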

    Using MLeap-TensorFlow

    Once the build is in place, you can integrate TensorFlow graphs into an MLeap pipeline.

    First, add MLeap-TensorFlow as a project dependency.

    libraryDependencies += "ml.combust.mleap" %% "mleap-tensorflow" % "0.13.0"

    Now you can use a TensorFlow graph in your code. Let's build a simple graph that takes two input tensors and multiplies them.

    import ml.combust.mleap.core.types._
    import ml.combust.mleap.runtime.frame.{DefaultLeapFrame, Row}
    import ml.combust.mleap.tensor.Tensor
    import ml.combust.mleap.tensorflow.{TensorflowModel, TensorflowTransformer}
    import org.tensorflow

    // Initialize our TensorFlow demo graph
    val graph = new tensorflow.Graph

    // Build placeholders for our input values
    val inputA = graph.opBuilder("Placeholder", "InputA")
      .setAttr("dtype", tensorflow.DataType.FLOAT)
      .build()
    val inputB = graph.opBuilder("Placeholder", "InputB")
      .setAttr("dtype", tensorflow.DataType.FLOAT)
      .build()

    // Multiply the two placeholders and put the result in
    // the "MyResult" tensor
    graph.opBuilder("Mul", "MyResult")
      .setAttr("T", tensorflow.DataType.FLOAT)
      .addInput(inputA.output(0))
      .addInput(inputB.output(0))
      .build()

    // Build the MLeap model wrapper around the TensorFlow graph
    val model = TensorflowModel(graph,
      // Must specify inputs and input types for converting to TF tensors
      inputs = Seq(("InputA", TensorType.Float()), ("InputB", TensorType.Float())),
      // Likewise, specify the output values so we can convert back to MLeap
      // types properly
      outputs = Seq(("MyResult", TensorType.Float())))

    // Connect our leap frame columns to the TensorFlow graph inputs and outputs
    val shape = NodeShape()
      // Column "input_a" gets sent to the TF graph as the input "InputA"
      .withInput("InputA", "input_a")
      // Column "input_b" gets sent to the TF graph as the input "InputB"
      .withInput("InputB", "input_b")
      // TF graph output "MyResult" gets placed in the leap frame as column
      // "my_result"
      .withOutput("MyResult", "my_result")

    // Create the MLeap transformer that executes the TF model against
    // a leap frame
    val transformer = TensorflowTransformer(shape = shape, model = model)

    // Create a sample leap frame to transform with the TensorFlow graph
    val schema = StructType(
      StructField("input_a", ScalarType.Float),
      StructField("input_b", ScalarType.Float)).get
    val dataset = Seq(
      Row(5.6f, 7.9f),
      Row(3.4f, 6.7f),
      Row(1.2f, 9.7f))
    val frame = DefaultLeapFrame(schema, dataset)

    // Transform the leap frame and make sure it behaves as expected;
    // index 2 is the "my_result" column, appended after the two inputs
    val data = transformer.transform(frame).get.dataset
    assert(data(0)(2).asInstanceOf[Tensor[Float]].get(0).get == 5.6f * 7.9f)
    assert(data(1)(2).asInstanceOf[Tensor[Float]].get(0).get == 3.4f * 6.7f)
    assert(data(2)(2).asInstanceOf[Tensor[Float]].get(0).get == 1.2f * 9.7f)

    // Clean up the transformer
    // This closes the TF session and graph resources
    transformer.close()
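    To make the role of NodeShape more concrete, here is a dependency-free Scala sketch of the routing it describes: each graph input port is fed from a named leap frame column, and each graph output port is written back to the frame under a new column name. Everything in it (ShapeMappingSketch, the Map-based Row type, and the graph function standing in for the TensorFlow Mul graph) is hypothetical and purely illustrative. It is not how MLeap's TensorflowTransformer is implemented; the real transformer converts column values into TensorFlow tensors and runs the graph in a TensorFlow session.

```scala
// Hypothetical, dependency-free sketch of NodeShape-style routing.
// NOT MLeap's implementation: no TensorFlow, no LeapFrame, no real tensors.
object ShapeMappingSketch {
  // Hypothetical stand-in for a leap frame row: column name -> value
  type Row = Map[String, Float]

  // Hypothetical stand-in for the TF graph: named inputs in, named outputs out.
  // It mirrors the "Mul" op that writes to "MyResult".
  def graph(inputs: Map[String, Float]): Map[String, Float] =
    Map("MyResult" -> inputs("InputA") * inputs("InputB"))

  // Graph input port -> frame column, mirroring withInput("InputA", "input_a")
  val inputPorts: Map[String, String] =
    Map("InputA" -> "input_a", "InputB" -> "input_b")

  // Graph output port -> frame column, mirroring withOutput("MyResult", "my_result")
  val outputPorts: Map[String, String] =
    Map("MyResult" -> "my_result")

  def transformRow(row: Row): Row = {
    // Feed each mapped column value into its graph input port
    val feeds = inputPorts.map { case (port, col) => port -> row(col) }
    val results = graph(feeds)
    // Keep the original columns and add each graph output under its mapped name
    row ++ outputPorts.map { case (port, col) => col -> results(port) }
  }

  def main(args: Array[String]): Unit = {
    val frame = Seq(
      Map("input_a" -> 5.6f, "input_b" -> 7.9f),
      Map("input_a" -> 3.4f, "input_b" -> 6.7f))
    frame.map(transformRow).foreach(println)
  }
}
```

    The sketch keeps the input columns and only adds the mapped output column, which matches how the example above reads "my_result" next to the original inputs.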

    For more details on how the TensorFlow integration works:

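    1. For details on data integration and conversion, see the corresponding section.
    2. For details on serializing a TensorFlow graph to an MLeap Bundle, see the corresponding section.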