提供给`tf.data.Dataset.from_generator(…)`的map函数可以解析张量对象吗?


问题内容

我想创建一个tf.data.Dataset.from_generator(...)数据集。我需要传递一个Python generator

我想像这样将先前数据集的属性传递给生成器:

dataset = dataset.interleave(
  map_func=lambda x: tf.data.Dataset.from_generator(generator=lambda: gen(x), output_types=tf.int64),
  cycle_length=2
)

我定义gen(...)要取值的位置(它是指向某些数据的指针,例如gen知道如何访问的文件名)。

失败是因为gen接收张量对象,而不是python / numpy值。

有没有办法将张量对象解析为内部的值gen(...)

交错生成器的原因是,这样我就可以使用其他数据集操作来操纵数据指针/文件名列表,.shuffle().repeat()无需将其烘焙到gen(...)函数中,如果我直接从列表中启动生成器,则很有必要数据指针/文件名。

我想使用生成器,因为每个数据指针/文件名将生成大量数据值。


问题答案:

TensorFlow现在支持将张量参数传递给生成器:

def map_func(tensor):
    dataset = tf.data.Dataset.from_generator(generator, tf.float32, args=(tensor,))
    return dataset