Tensorflow Keras:评估时如何在自定义层中设置断点(调试)?


问题内容

我只想在自定义层中进行一些数值验证。

假设我们有一个非常简单的自定义层:

class test_layer(keras.layers.Layer):
    def __init__(self, **kwargs):
        super(test_layer, self).__init__(**kwargs)

    def build(self, input_shape):
        self.w = K.variable(1.)
        self._trainable_weights.append(self.w)
        super(test_layer, self).build(input_shape)

    def call(self, x, **kwargs):
        m = x * x            # Set break point here
        n = self.w * K.sqrt(x)
        return m + n

和主程序:

import tensorflow as tf
import keras
import keras.backend as K

input = keras.layers.Input((100,1))
y = test_layer()(input)

model = keras.Model(input,y)
model.predict(np.ones((100,1)))

如果在该行上设置断点调试,则m = x * x执行时程序将在此处暂停y = test_layer()(input),这是因为生成了图形,因此call()调用了该方法。

但是当我使用model.predict()它来赋予它真正的价值,并且想要在图层内部查看是否工作正常时,它不会停在一行m = x * x

我的问题是:

  1. call()计算图形正在兴建时,方法只叫什么名字?(提供实际价值时不会调用它吗?)

  2. 给实数输入时,如何调试(或在何处插入断点)以查看变量的值?


问题答案:
  1. 是。该call()方法仅用于构建计算图。

  2. 至于调试。我更喜欢使用TFDBG,这是针对tensorflow的推荐调试工具,尽管它不提供断点功能。

对于Keras,您可以将以下行添加到脚本中以使用TFDBG

import tf.keras.backend as K
from tensorflow.python import debug as tf_debug
sess = K.get_session()
sess = tf_debug.LocalCLIDebugWrapperSession(sess)
K.set_session(sess)