• tf.TensorArray :TensorFlow 动态数组 *
  • TODO

    tf.TensorArray :TensorFlow 动态数组 *

    在部分网络结构,尤其是涉及到时间序列的结构中,我们可能需要将一系列张量以数组的方式依次存放起来,以供进一步处理。当然,在Eager Execution下,你可以直接使用一个Python列表(List)存放数组。不过,如果你需要基于计算图的特性(例如使用 @tf.function 加速模型运行或者使用SavedModel导出模型),就无法使用这种方式了。因此,TensorFlow提供了 tf.TensorArray ,一种支持计算图特性的TensorFlow动态数组。

    由于需要支持计算图, tf.TensorArray 的使用方式和一般编程语言中的列表/数组类型略有不同,包括4个方法:

    • TODO

    一个简单的示例如下:

    1. import tensorflow as tf
    2.  
    3. @tf.function
    4. def array_write_and_read():
    5. arr = tf.TensorArray(dtype=tf.float32, size=3)
    6. arr = arr.write(0, tf.constant(0.0))
    7. arr = arr.write(1, tf.constant(1.0))
    8. arr = arr.write(2, tf.constant(2.0))
    9. arr_0 = arr.read(0)
    10. arr_1 = arr.read(1)
    11. arr_2 = arr.read(2)
    12. return arr_0, arr_1, arr_2
    13.  
    14. a, b, c = array_write_and_read()
    15. print(a, b, c)

    输出:

    1. tf.Tensor(0.0, shape=(), dtype=float32) tf.Tensor(1.0, shape=(), dtype=float32) tf.Tensor(2.0, shape=(), dtype=float32)