tf.assign to a variable slice does not work inside tf.while_loop


Question

What is wrong with the following code? Applying the tf.assign op to a slice of a tf.Variable works fine outside of a loop. In this case, however, it gives the following error.

import tensorflow as tf

v = [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]
n = len(v)
a = tf.Variable(v, name = 'a')

def cond(i, a):
    return i < n

def body(i, a):
    tf.assign(a[i], a[i-1] + a[i-2])
    return i + 1, a

i, b = tf.while_loop(cond, body, [2, a])

The result is:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/hrbigelow/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 3210, in while_loop
    result = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants)
  File "/home/hrbigelow/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2942, in BuildLoop
    pred, body, original_loop_vars, loop_vars, shape_invariants)
  File "/home/hrbigelow/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/control_flow_ops.py", line 2879, in _BuildLoop
    body_result = body(*packed_vars_for_body)
  File "/home/hrbigelow/ai/lb-wavenet/while_var_test.py", line 11, in body
    tf.assign(a[i], a[i-1] + a[i-2])
  File "/home/hrbigelow/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/state_ops.py", line 220, in assign
    return ref.assign(value, name=name)
  File "/home/hrbigelow/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/array_ops.py", line 697, in assign
    raise ValueError("Sliced assignment is only supported for variables")
ValueError: Sliced assignment is only supported for variables

Answer:

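tf.while_loop treats every entry of loop_vars as a tensor, so inside body, a is a plain Tensor rather than your tf.Variable. That makes a[i] a slice of a Tensor, and, as the error says, sliced assignment is only supported for variables. Your variable is not the output of an operation that runs inside the loop; it is an external entity that lives outside the loop. Therefore you don't need to pass it as a loop variable at all: just refer to it from the enclosing scope.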

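In addition, you need to force the update to execute, for example by using tf.control_dependencies in body. Otherwise nothing depends on the result of tf.assign, so the op would never run when you evaluate the loop's output: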

import tensorflow as tf

v = [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]
n = len(v)
a = tf.Variable(v, name = 'a')

def cond(i):
    return i < n

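def body(i):
    # Slice-assign on the captured tf.Variable (it is not a loop variable).
    op = tf.assign(a[i], a[i-1] + a[i-2])
    # Make the returned counter depend on the assignment so that it actually runs.
    with tf.control_dependencies([op]):
        return i + 1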

i = tf.while_loop(cond, body, [2])

sess = tf.InteractiveSession()
tf.global_variables_initializer().run()
i.eval()
print(a.eval())
# [ 1  1  2  3  5  8 13 21 34 55 89]
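Why the control dependency matters can be sketched with a toy model of graph-mode evaluation: evaluating a node only runs the nodes it depends on, either through its data inputs or through explicit control inputs, so a side-effecting op that nothing consumes is simply skipped. The Op class, assign_op and next_i below are hypothetical names for illustration only; this is plain Python mimicking that pruning rule, not TensorFlow's API or its actual scheduler.

```python
class Op:
    """Hypothetical toy dataflow node (not a TensorFlow class).

    Evaluating a node runs only the nodes it depends on, through
    data inputs or explicit control inputs.
    """

    def __init__(self, fn, inputs=(), control_inputs=()):
        self.fn = fn
        self.inputs = list(inputs)
        self.control_inputs = list(control_inputs)

    def eval(self, cache=None):
        cache = {} if cache is None else cache
        if id(self) in cache:  # each node runs at most once per evaluation
            return cache[id(self)]
        for dep in self.control_inputs:  # run control deps first, ignore their results
            dep.eval(cache)
        args = [node.eval(cache) for node in self.inputs]
        result = self.fn(*args)
        cache[id(self)] = result
        return result


state = [1, 1, 0]

# Side-effecting "assign" node: writes state[2] = state[1] + state[0].
assign_op = Op(lambda: state.__setitem__(2, state[1] + state[0]))

# No dependency on assign_op: evaluating the counter never runs it.
next_i = Op(lambda: 3)
next_i.eval()
print(state)  # [1, 1, 0]

# assign_op as a control input: it runs before the counter is produced.
next_i_with_dep = Op(lambda: 3, control_inputs=[assign_op])
next_i_with_dep.eval()
print(state)  # [1, 1, 2]
```

This mirrors the role of tf.control_dependencies([op]) around return i + 1 in the answer: the counter is the value the loop actually produces, so tying the assignment to it is what keeps the assignment from being pruned.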

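You may also want to be careful and set parallel_iterations=1 in the tf.while_loop call to force the loop to run sequentially, because each iteration reads a[i-1] and a[i-2], which are written by the two previous iterations.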
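The ordering concern can be illustrated with a plain-Python sketch. It assumes a deliberately pessimistic model in which, without forced sequential execution, every iteration reads a[i-1] and a[i-2] before any earlier write has landed. This is a hypothetical worst case used to show why the order matters, not a description of how TensorFlow actually interleaves iterations, which is not guaranteed either way. The function names are illustrative only.

```python
def run_sequential(v):
    """Each iteration sees the values written by the previous iterations."""
    a = list(v)
    for i in range(2, len(a)):
        a[i] = a[i - 1] + a[i - 2]
    return a


def run_stale_reads(v):
    """Hypothetical worst case: every read sees the values from before the loop."""
    snapshot = list(v)
    a = list(v)
    for i in range(2, len(a)):
        a[i] = snapshot[i - 1] + snapshot[i - 2]
    return a


v = [1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]
print(run_sequential(v))   # [1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89]
print(run_stale_reads(v))  # [1, 1, 2, 1, 0, 0, 0, 0, 0, 0, 0]
```

Only the sequential version reproduces the recurrence, which is what parallel_iterations=1 is meant to guarantee.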