提问者:小点点

在TensorFlow中创建图形时确定张量形状


我正在尝试编写一块可重用的代码,它读取一个张量的形状,然后使用生成的对象定义其他张量的形状。我可以选择使用tf.shape(tensor)读取张量的动态形状,或者使用tensor.get_shape()读取张量的静态形状。玩具示例如下(使用两种不同的策略):

def my_function_strategy_1(x, y):
    x_shape = tf.shape(x)
    a = tf.reshape(y, x_shape)
    b = tf.zeros(x_shape)
    num_x_values = x_shape[0]
    c = tf.reshape(y, [num_x_values, 4])
    d = tf.zeros([num_x_values, 4])
    return a, b, c, d

def my_function_strategy_2(x, y):
    x_shape = x.get_shape()
    a = tf.reshape(y, x_shape)
    b = tf.zeros(x_shape)
    num_x_values = x_shape[0]
    c = tf.reshape(y, [num_x_values, 4])
    d = tf.zeros([num_x_values, 4])
    return a, b, c, d

我想在不同的图中使用这段代码。有时输入张量的形状是已知的,有时是未知的:

graph_A = tf.Graph()
with graph_A.as_default():
    x = tf.placeholder(tf.float32, [2, 4])
    y = tf.placeholder(tf.float32, [8])
    a, b, c, d = my_function(x, y)

with graph_B.as_default():
    x = tf.placeholder(tf.float32)
    y = tf.placeholder(tf.float32)
    a, b, c, d = my_function(x, y)

我想要的行为是:(A)当输入张量的形状已知时(如< code>graph_A),我想要TensorFlow在图形创建时计算图形中的所有形状(这样它可以有效地分配资源,等等。),以及(B)当输入张量的形状未知时(如< code>graph_B),我希望张量流等到运行时计算图形中的所有形状。

函数的strategy_1版本几乎做到了这一点。它实现了(B),但并不完全实现(A),因为TensorFlow留下了一些张量的形状未知。例如,在上面的玩具示例中,abc的形状是在图形创建时计算的,但d的形状是未知的(即使d使用了非常相似的操作)。您可以通过打印a.get_shape()b.get_shape()等来检查这一点。

相反,strategy_2版本的函数对于图中的所有张量实现了(A),但没有实现(B),因为TensorFlow(可以理解)在尝试使用输入张量的(未知)静态形状来塑造其他张量时抛出了一个异常。

有没有办法在单个函数中同时实现(A)和(B)?strategy_1版本如何/为什么适用于图中的大多数张量,但不是全部?


共1个答案

匿名用户

您可以仔细选择形状的元素,以获得“两全其美”的结果:

def my_get_shape(tensor):
    if tensor.shape.ndims is None:
        # Fully dynamic
        return tf.shape(tensor)
    if tensor.shape.is_fully_defined():
        # Fully static
        return tensor.shape
    # Partially static
    dyn_shape = tf.shape(tensor)
    shape = []
    for i, d in enumerate(tensor.shape):
        shape.append(d.value if d.value is not None else dyn_shape[i])
    return shape

def my_function(x, y):
    x_shape = my_get_shape(x)  # Or just tf.shape(x)! - see edit
    a = tf.reshape(y, x_shape)
    b = tf.zeros(x_shape)
    num_x_values = x_shape[0]
    c = tf.reshape(y, [num_x_values, 4])
    d = tf.zeros([num_x_values, 4])
    return a, b, c, d

# Fully static
with tf.Graph().as_default():
    x = tf.placeholder(tf.float32, [2, 4])
    y = tf.placeholder(tf.float32, [8])
    a, b, c, d = my_function(x, y)
print('a:', a.shape, ', b:', b.shape, ', c:', c.shape, ', d:', d.shape)
# a: (2, 4) , b: (2, 4) , c: (2, 4) , d: (2, 4)

# Fully dynamic
with tf.Graph().as_default():
    x = tf.placeholder(tf.float32)
    y = tf.placeholder(tf.float32)
    a, b, c, d = my_function(x, y)
print('a:', a.shape, ', b:', b.shape, ', c:', c.shape, ', d:', d.shape)
# a: <unknown> , b: <unknown> , c: (?, 4) , d: (?, 4)

# Partially static
with tf.Graph().as_default():
    x = tf.placeholder(tf.float32, [None, 4])
    y = tf.placeholder(tf.float32)
    a, b, c, d = my_function(x, y)
print('a:', a.shape, ', b:', b.shape, ', c:', c.shape, ', d:', d.shape)
# a: (?, 4) , b: (?, 4) , c: (?, 4) , d: (?, 4)

编辑:

实际上,在上一个片段中用 tf.shape 替换my_get_shape的工作方式完全相同。似乎 tf.shape 应该是默认值(注意不要用它塞满图形),除非您明确希望保持未定义维度。

我做了一些调查,但我无法完全解决整个问题。我不知道这是否有用,但我发现了一些事情。显然,TensorFlow在C级(它以前似乎在Python中使用过,但现在不再使用了)有一种“形状推断”机制。例如,如果您在tensorflow/core/ops/array_ops.cc中查看,您将看到每个操作声明的末尾都包含一个.SetShapeFn,这是一个使用InferenceContext尝试猜测操作输出形状的函数。这个类可以检查张量中的值是否已知,例如,当给定的张量是静态的时,对于tf.shape,或者对于tf.full(以及相关的,如tf.ones),已知值是正确的。形状推断算法的分辨率是Python中设置为张量形状的分辨率,它可以通过call_cpp_shape_fn直接调用(尽管我不知道它如何有用):

from tensorflow.python.framework.common_shapes import call_cpp_shape_fn
with tf.Graph().as_default():
    print(call_cpp_shape_fn(tf.reshape(tf.placeholder(tf.float32), tf.fill([2], 3)).op))
    # Shows this:
    # {
    #   'shapes': [dim { size: 3 } dim { size: 3 }],
    #   'handle_data': [None],
    #   'inputs_needed': b'\x12\x01\x01'
    # }
    print(call_cpp_shape_fn(tf.reshape(tf.placeholder(tf.float32), (2 * tf.fill([2], 3))).op))
    # Shows this:
    # {
    #   'shapes': [dim { size: -1 } dim { size: -1 }],
    #   'handle_data': [None],
    #   'inputs_needed': b'\x12\x01\x01'
    # }

你可以看到,虽然 tf.fill([2], 3) 被正确检查,但 TensorFlow 并没有发现 2 * tf.fill([2], 3) 是 [6, 6],大概是因为静态跟踪乘法等运算,即使操作数是已知的常量,也被认为太昂贵了。

我尚未发现的是,操作在何处声明它们的值可以是静态已知的,或者在何处/如何准确地检索这些值。例如,对于tf.shape,它似乎能够专门选择已知值,并将其余值保留为未定义。