Python源码示例:tensorflow.core.protobuf.meta_graph_pb2.TensorInfo()

示例1
def testSignatureDefValidation(self):
    """Runs the input/output TensorInfo validation helpers on incomplete protos."""
    builder = saved_model_builder.SavedModelBuilder(
        os.path.join(test.get_temp_dir(), "test_signature_def_validation"))

    # dtype set, name missing.
    missing_name = meta_graph_pb2.TensorInfo()
    missing_name.dtype = types_pb2.DT_FLOAT

    # name set, dtype missing.
    missing_dtype = meta_graph_pb2.TensorInfo()
    missing_dtype.name = "x"

    # Nothing set at all.
    empty = meta_graph_pb2.TensorInfo()

    for incomplete_info in (missing_name, missing_dtype, empty):
      self._validate_inputs_tensor_info(builder, incomplete_info)
      self._validate_outputs_tensor_info(builder, incomplete_info)
示例2
def testConvertDefaultSignatureRegressionToSignatureDef(self):
    """Converts a default regression Signature and checks the SignatureDef."""
    inputs_name = signature_constants.REGRESS_INPUTS
    outputs_name = signature_constants.REGRESS_OUTPUTS

    signatures_proto = manifest_pb2.Signatures()
    regression = manifest_pb2.RegressionSignature()
    regression.input.CopyFrom(
        manifest_pb2.TensorBinding(tensor_name=inputs_name))
    regression.output.CopyFrom(
        manifest_pb2.TensorBinding(tensor_name=outputs_name))
    signatures_proto.default_signature.regression_signature.CopyFrom(
        regression)

    signature_def = bundle_shim._convert_default_signature_to_signature_def(
        signatures_proto)

    # The regression method name and its single input/output binding must
    # be carried over by the conversion.
    self.assertEqual(signature_def.method_name,
                     signature_constants.REGRESS_METHOD_NAME)
    self.assertEqual(len(signature_def.inputs), 1)
    self.assertEqual(len(signature_def.outputs), 1)
    self.assertProtoEquals(signature_def.inputs[inputs_name],
                           meta_graph_pb2.TensorInfo(name=inputs_name))
    self.assertProtoEquals(signature_def.outputs[outputs_name],
                           meta_graph_pb2.TensorInfo(name=outputs_name))
示例3
def get_node_wrapped_tensor_info(meta_graph_def: meta_graph_pb2.MetaGraphDef,
                                 path: Text) -> any_pb2.Any:
  """Return the single Any-wrapped TensorInfo stored under `path`.

  Args:
     meta_graph_def: MetaGraphDef whose `collection_def` map is searched.
     path: Key of the CollectionDef holding the node's wrapped TensorInfo.

  Returns:
    The only element of that CollectionDef's `any_list.value`.

  Raises:
    KeyError: `path` is not a key of `meta_graph_def.collection_def`.
    ValueError: The CollectionDef's `any_list` does not hold exactly one
      value.
  """
  collection_defs = meta_graph_def.collection_def
  if path not in collection_defs:
    raise KeyError('could not find path %s in collection defs. meta_graph_def '
                   'was %s' % (path, meta_graph_def))
  wrapped_values = collection_defs[path].any_list.value
  if len(wrapped_values) != 1:
    raise ValueError(
        'any_list should be of length 1. path was %s, any_list was: %s.' %
        (path, wrapped_values))
  return wrapped_values[0]
示例4
def encode_tensor_node(node: types.TensorType) -> any_pb2.Any:
  """Wrap a TensorInfo "reference" to a Tensor/SparseTensor in an Any proto.

  The TensorInfo built for `node` is packed into an Any so that the result
  can be added to a CollectionDef.

  Args:
    node: The Tensor or SparseTensor to reference.

  Returns:
    An Any proto packing the TensorInfo built for `node`.
  """
  packed = any_pb2.Any()
  packed.Pack(tf.compat.v1.saved_model.utils.build_tensor_info(node))
  return packed
示例5
def decode_tensor_node(graph: tf.Graph,
                       encoded_tensor_node: any_pb2.Any) -> types.TensorType:
  """Decode a Tensor node encoded with encode_tensor_node.

  Unpacks the TensorInfo "reference" from the Any proto and returns the node
  in the given graph that corresponds to it.

  Args:
    graph: Graph in which the referenced Tensor is looked up.
    encoded_tensor_node: Any proto packing a TensorInfo, as produced by
      encode_tensor_node.

  Returns:
    Decoded Tensor (the node in `graph` described by the TensorInfo).

  Raises:
    ValueError: `encoded_tensor_node` does not pack a TensorInfo.
  """
  tensor_info = meta_graph_pb2.TensorInfo()
  # Any.Unpack returns False (leaving tensor_info empty) when the payload is a
  # different message type. Fail with a clear message instead of letting the
  # lookup below fail on an empty TensorInfo.
  if not encoded_tensor_node.Unpack(tensor_info):
    raise ValueError(
        'encoded_tensor_node does not contain a TensorInfo; type_url was %r' %
        encoded_tensor_node.type_url)
  return tf.compat.v1.saved_model.utils.get_tensor_from_tensor_info(
      tensor_info, graph)
示例6
def build_tensor_info(tensor):
  """Build a TensorInfo proto describing `tensor`.

  Args:
    tensor: Tensor or SparseTensor supplying the dtype, static shape and
        tensor name(s). For a SparseTensor, the names of its three
        constituent Tensors (values, indices, dense_shape) are recorded in
        the `coo_sparse` encoding instead of `name`.

  Returns:
    A TensorInfo protocol buffer built from the supplied argument.
  """
  dtype_enum = dtypes.as_dtype(tensor.dtype).as_datatype_enum
  shape_proto = tensor.get_shape().as_proto()
  tensor_info = meta_graph_pb2.TensorInfo(dtype=dtype_enum,
                                          tensor_shape=shape_proto)
  if not isinstance(tensor, sparse_tensor.SparseTensor):
    tensor_info.name = tensor.name
    return tensor_info
  coo = tensor_info.coo_sparse
  coo.values_tensor_name = tensor.values.name
  coo.indices_tensor_name = tensor.indices.name
  coo.dense_shape_tensor_name = tensor.dense_shape.name
  return tensor_info
示例7
def testSignatureDefValidation(self):
    """Runs the input/output TensorInfo validation helpers on incomplete protos."""
    builder = saved_model_builder.SavedModelBuilder(
        os.path.join(test.get_temp_dir(), "test_signature_def_validation"))

    # dtype set, name missing.
    missing_name = meta_graph_pb2.TensorInfo()
    missing_name.dtype = types_pb2.DT_FLOAT

    # name set, dtype missing.
    missing_dtype = meta_graph_pb2.TensorInfo()
    missing_dtype.name = "x"

    # Nothing set at all.
    empty = meta_graph_pb2.TensorInfo()

    for incomplete_info in (missing_name, missing_dtype, empty):
      self._validate_inputs_tensor_info(builder, incomplete_info)
      self._validate_outputs_tensor_info(builder, incomplete_info)
示例8
def testConvertDefaultSignatureRegressionToSignatureDef(self):
    """Converts a default regression Signature and checks the SignatureDef."""
    inputs_name = signature_constants.REGRESS_INPUTS
    outputs_name = signature_constants.REGRESS_OUTPUTS

    signatures_proto = manifest_pb2.Signatures()
    regression = manifest_pb2.RegressionSignature()
    regression.input.CopyFrom(
        manifest_pb2.TensorBinding(tensor_name=inputs_name))
    regression.output.CopyFrom(
        manifest_pb2.TensorBinding(tensor_name=outputs_name))
    signatures_proto.default_signature.regression_signature.CopyFrom(
        regression)

    signature_def = bundle_shim._convert_default_signature_to_signature_def(
        signatures_proto)

    # The regression method name and its single input/output binding must
    # be carried over by the conversion.
    self.assertEqual(signature_def.method_name,
                     signature_constants.REGRESS_METHOD_NAME)
    self.assertEqual(len(signature_def.inputs), 1)
    self.assertEqual(len(signature_def.outputs), 1)
    self.assertProtoEquals(signature_def.inputs[inputs_name],
                           meta_graph_pb2.TensorInfo(name=inputs_name))
    self.assertProtoEquals(signature_def.outputs[outputs_name],
                           meta_graph_pb2.TensorInfo(name=outputs_name))
示例9
def build_tensor_info(tensor):
  """Build a TensorInfo proto from a Tensor's name, dtype and static shape.

  Args:
    tensor: Tensor whose name, dtype and shape populate the TensorInfo.

  Returns:
    A TensorInfo protocol buffer built from the supplied argument.
  """
  dtype_enum = dtypes.as_dtype(tensor.dtype).as_datatype_enum
  fields = dict(name=tensor.name,
                dtype=dtype_enum,
                tensor_shape=tensor.get_shape().as_proto())
  return meta_graph_pb2.TensorInfo(**fields)
示例10
def _add_input_to_signature_def(tensor_name, map_key, signature_def):
  """Store a name-only TensorInfo under `signature_def.inputs[map_key]`.

  Args:
    tensor_name: string name of the tensor to record.
    map_key: string key of the `signature_def.inputs` entry to write.
    signature_def: meta_graph_pb2.SignatureDef to modify in place.

  Side effect:
    The entry at `map_key` in `signature_def.inputs` is overwritten (via
    CopyFrom) with a TensorInfo whose name is `tensor_name`.
  """
  new_info = meta_graph_pb2.TensorInfo(name=tensor_name)
  inputs_map = signature_def.inputs
  inputs_map[map_key].CopyFrom(new_info)
示例11
def _add_output_to_signature_def(tensor_name, map_key, signature_def):
  """Store a name-only TensorInfo under `signature_def.outputs[map_key]`.

  Args:
    tensor_name: string name of the tensor to record.
    map_key: string key of the `signature_def.outputs` entry to write.
    signature_def: meta_graph_pb2.SignatureDef to modify in place.

  Side effect:
    The entry at `map_key` in `signature_def.outputs` is overwritten (via
    CopyFrom) with a TensorInfo whose name is `tensor_name`.
  """
  new_info = meta_graph_pb2.TensorInfo(name=tensor_name)
  outputs_map = signature_def.outputs
  outputs_map[map_key].CopyFrom(new_info)
示例12
def build_tensor_info(tensor):
  """Build a TensorInfo proto from a Tensor's name, dtype and static shape.

  Args:
    tensor: Tensor whose name, dtype and shape populate the TensorInfo.

  Returns:
    A TensorInfo protocol buffer built from the supplied argument.
  """
  dtype_enum = dtypes.as_dtype(tensor.dtype).as_datatype_enum
  fields = dict(name=tensor.name,
                dtype=dtype_enum,
                tensor_shape=tensor.get_shape().as_proto())
  return meta_graph_pb2.TensorInfo(**fields)
示例13
def testConvertNamedSignatureToSignatureDef(self):
    """Named generic input/output Signatures become one predict SignatureDef."""
    signatures_proto = manifest_pb2.Signatures()

    # (named_signatures key, generic map key, bound tensor name); the inputs
    # signature is registered before the outputs signature.
    named_bindings = (
        (signature_constants.PREDICT_INPUTS, "input_key", "input"),
        (signature_constants.PREDICT_OUTPUTS, "output_key", "output"),
    )
    for signature_key, binding_key, tensor_name in named_bindings:
      generic = manifest_pb2.GenericSignature()
      generic.map[binding_key].CopyFrom(
          manifest_pb2.TensorBinding(tensor_name=tensor_name))
      named = signatures_proto.named_signatures[signature_key]
      named.generic_signature.CopyFrom(generic)

    signature_def = bundle_shim._convert_named_signatures_to_signature_def(
        signatures_proto)
    self.assertEqual(signature_def.method_name,
                     signature_constants.PREDICT_METHOD_NAME)
    self.assertEqual(len(signature_def.inputs), 1)
    self.assertEqual(len(signature_def.outputs), 1)
    self.assertProtoEquals(signature_def.inputs["input_key"],
                           meta_graph_pb2.TensorInfo(name="input"))
    self.assertProtoEquals(signature_def.outputs["output_key"],
                           meta_graph_pb2.TensorInfo(name="output"))
示例14
def _add_input_to_signature_def(tensor_name, map_key, signature_def):
  """Store a name-only TensorInfo under `signature_def.inputs[map_key]`.

  Args:
    tensor_name: string name of the tensor to record.
    map_key: string key of the `signature_def.inputs` entry to write.
    signature_def: meta_graph_pb2.SignatureDef to modify in place.

  Side effect:
    The entry at `map_key` in `signature_def.inputs` is overwritten (via
    CopyFrom) with a TensorInfo whose name is `tensor_name`.
  """
  new_info = meta_graph_pb2.TensorInfo(name=tensor_name)
  inputs_map = signature_def.inputs
  inputs_map[map_key].CopyFrom(new_info)
示例15
def _add_output_to_signature_def(tensor_name, map_key, signature_def):
  """Store a name-only TensorInfo under `signature_def.outputs[map_key]`.

  Args:
    tensor_name: string name of the tensor to record.
    map_key: string key of the `signature_def.outputs` entry to write.
    signature_def: meta_graph_pb2.SignatureDef to modify in place.

  Side effect:
    The entry at `map_key` in `signature_def.outputs` is overwritten (via
    CopyFrom) with a TensorInfo whose name is `tensor_name`.
  """
  new_info = meta_graph_pb2.TensorInfo(name=tensor_name)
  outputs_map = signature_def.outputs
  outputs_map[map_key].CopyFrom(new_info)
示例16
def build_tensor_info(name=None, dtype=None, shape=None):
  """Utility function to build TensorInfo proto.

  Args:
    name: Name of the tensor to be used in the TensorInfo.
    dtype: DataType enum value to be set in the TensorInfo.
    shape: TensorShapeProto to specify the shape of the tensor in the
        TensorInfo; it is stored in the proto's `tensor_shape` field.

  Returns:
    A TensorInfo protocol buffer constructed based on the supplied arguments.
  """
  # The TensorInfo message names its shape field `tensor_shape`; the generated
  # constructor rejects an unknown `shape=` keyword. Only touch
  # `tensor_shape` when a shape was actually supplied.
  tensor_info = meta_graph_pb2.TensorInfo(name=name, dtype=dtype)
  if shape is not None:
    tensor_info.tensor_shape.CopyFrom(shape)
  return tensor_info

# SignatureDef helpers. 
示例17
def test_ragged_roundtrip(self):
    """A ragged input/output survives writing and re-applying a transform."""
    if not hasattr(meta_graph_pb2.TensorInfo, 'CompositeTensor'):
      self.skipTest('This version of TensorFlow does not support '
                    'CompositeTenors in TensorInfo.')
    export_path = os.path.join(tempfile.mkdtemp(), 'export')

    # Export: a graph that halves the values of a ragged float input.
    with tf.compat.v1.Graph().as_default():
      with tf.compat.v1.Session().as_default() as session:
        ragged_in = tf.compat.v1.ragged.placeholder(
            tf.float32, ragged_rank=1, value_shape=[])
        halved = ragged_in / 2.0
        saved_transform_io.write_saved_transform_from_session(
            session, {'input': ragged_in}, {'output': halved}, export_path)

    # Import: apply the saved transform to a computed ragged tensor.
    with tf.compat.v1.Graph().as_default():
      with tf.compat.v1.Session().as_default() as session:
        row_splits = np.array([0, 2, 3], dtype=np.int64)
        flat_values = np.array([1.0, 2.0, 4.0], dtype=np.float32)
        ragged_value = tf.RaggedTensor.from_row_splits(flat_values, row_splits)

        # Feeding a computed (non-placeholder) input gives confidence that
        # the graphs are fused.
        _, outputs = (
            saved_transform_io.partially_apply_saved_transform_internal(
                export_path, {'input': ragged_value * 10}))
        ragged_out = outputs['output']
        self.assertIsInstance(ragged_out, tf.RaggedTensor)
        result = session.run(ragged_out)

        # Row splits are unchanged; values are multiplied by 10, then halved.
        self.assertAllEqual(row_splits, result.row_splits)
        self.assertEqual([5.0, 10.0, 20.0], result.values.tolist())
示例18
def get_tensor_from_tensor_info(tensor_info, graph=None, import_scope=None):
  """Resolve a TensorInfo proto to the Tensor or SparseTensor it describes.

  Args:
    tensor_info: A TensorInfo proto describing a Tensor or SparseTensor.
    graph: The tf.Graph in which tensors are looked up. If None, the
        current default graph is used.
    import_scope: If not None, names in `tensor_info` are prefixed with this
        string before lookup.

  Returns:
    The Tensor or SparseTensor in `graph` described by `tensor_info`.

  Raises:
    KeyError: If `tensor_info` does not correspond to a tensor in `graph`.
    ValueError: If `tensor_info` is malformed (its `encoding` oneof is unset
        or is neither "name" nor "coo_sparse").
  """
  if graph is None:
    graph = ops.get_default_graph()

  def _lookup(name):
    scoped_name = ops.prepend_name_scope(name, import_scope=import_scope)
    return graph.get_tensor_by_name(scoped_name)

  encoding = tensor_info.WhichOneof("encoding")
  if encoding == "coo_sparse":
    coo = tensor_info.coo_sparse
    return sparse_tensor.SparseTensor(
        _lookup(coo.indices_tensor_name),
        _lookup(coo.values_tensor_name),
        _lookup(coo.dense_shape_tensor_name))
  if encoding == "name":
    return _lookup(tensor_info.name)
  raise ValueError("Invalid TensorInfo.encoding: %s" % encoding)
示例19
def build_tensor_info(tensor):
  """Build a TensorInfo proto from a Tensor's name, dtype and static shape.

  Args:
    tensor: Tensor whose name, dtype and shape populate the TensorInfo.

  Returns:
    A TensorInfo protocol buffer built from the supplied argument.
  """
  dtype_enum = dtypes.as_dtype(tensor.dtype).as_datatype_enum
  fields = dict(name=tensor.name,
                dtype=dtype_enum,
                tensor_shape=tensor.get_shape().as_proto())
  return meta_graph_pb2.TensorInfo(**fields)
示例20
def testConvertNamedSignatureToSignatureDef(self):
    """Named generic input/output Signatures become one predict SignatureDef."""
    signatures_proto = manifest_pb2.Signatures()

    # (named_signatures key, generic map key, bound tensor name); the inputs
    # signature is registered before the outputs signature.
    named_bindings = (
        (signature_constants.PREDICT_INPUTS, "input_key", "input"),
        (signature_constants.PREDICT_OUTPUTS, "output_key", "output"),
    )
    for signature_key, binding_key, tensor_name in named_bindings:
      generic = manifest_pb2.GenericSignature()
      generic.map[binding_key].CopyFrom(
          manifest_pb2.TensorBinding(tensor_name=tensor_name))
      named = signatures_proto.named_signatures[signature_key]
      named.generic_signature.CopyFrom(generic)

    signature_def = bundle_shim._convert_named_signatures_to_signature_def(
        signatures_proto)
    self.assertEqual(signature_def.method_name,
                     signature_constants.PREDICT_METHOD_NAME)
    self.assertEqual(len(signature_def.inputs), 1)
    self.assertEqual(len(signature_def.outputs), 1)
    self.assertProtoEquals(signature_def.inputs["input_key"],
                           meta_graph_pb2.TensorInfo(name="input"))
    self.assertProtoEquals(signature_def.outputs["output_key"],
                           meta_graph_pb2.TensorInfo(name="output"))
示例21
def _add_input_to_signature_def(tensor_name, map_key, signature_def):
  """Store a name-only TensorInfo under `signature_def.inputs[map_key]`.

  Args:
    tensor_name: string name of the tensor to record.
    map_key: string key of the `signature_def.inputs` entry to write.
    signature_def: meta_graph_pb2.SignatureDef to modify in place.

  Side effect:
    The entry at `map_key` in `signature_def.inputs` is overwritten (via
    CopyFrom) with a TensorInfo whose name is `tensor_name`.
  """
  new_info = meta_graph_pb2.TensorInfo(name=tensor_name)
  inputs_map = signature_def.inputs
  inputs_map[map_key].CopyFrom(new_info)
示例22
def _add_output_to_signature_def(tensor_name, map_key, signature_def):
  """Store a name-only TensorInfo under `signature_def.outputs[map_key]`.

  Args:
    tensor_name: string name of the tensor to record.
    map_key: string key of the `signature_def.outputs` entry to write.
    signature_def: meta_graph_pb2.SignatureDef to modify in place.

  Side effect:
    The entry at `map_key` in `signature_def.outputs` is overwritten (via
    CopyFrom) with a TensorInfo whose name is `tensor_name`.
  """
  new_info = meta_graph_pb2.TensorInfo(name=tensor_name)
  outputs_map = signature_def.outputs
  outputs_map[map_key].CopyFrom(new_info)
示例23
def build_prediction_graph(self):
    """Builds prediction graph and registers appropriate endpoints.

    Returns:
      A pair (input_signatures, output_signatures) of dicts mapping alias
      strings to meta_graph_pb2.TensorInfo protos.
    """
    examples = tf.placeholder(tf.string, shape=(None,))
    features = {
        'image': tf.FixedLenFeature(
            shape=[IMAGE_PIXELS], dtype=tf.float32),
        'key': tf.FixedLenFeature(
            shape=[], dtype=tf.string),
    }

    parsed = tf.parse_example(examples, features)
    images = parsed['image']
    keys = parsed['key']

    # Build a Graph that computes predictions from the inference model.
    logits = inference(images, self.hidden1, self.hidden2)
    softmax = tf.nn.softmax(logits)
    prediction = tf.argmax(softmax, 1)

    # Mark the inputs and the outputs
    # Marking the input tensor with an alias with suffix _bytes. This is to
    # indicate that this tensor value is raw bytes and will be base64 encoded
    # over HTTP.
    # Note that any output tensor marked with an alias with suffix _bytes, shall
    # be base64 encoded in the HTTP response. To get the binary value, it
    # should be base64 decoded.
    input_signatures = {}
    predict_input_tensor = meta_graph_pb2.TensorInfo()
    predict_input_tensor.name = examples.name
    predict_input_tensor.dtype = examples.dtype.as_datatype_enum
    # NOTE(review): this alias is 'example_bytes' while the 'inputs'
    # collection below uses 'examples_bytes'; confirm which one clients expect.
    input_signatures['example_bytes'] = predict_input_tensor

    tf.add_to_collection('inputs',
                         json.dumps({
                             'examples_bytes': examples.name
                         }))

    # Single table of output aliases, shared by the 'outputs' collection and
    # the output signatures so the two cannot drift apart.
    output_tensors = {'key': keys,
                      'prediction': prediction,
                      'scores': softmax}
    output_names = {
        alias: tensor.name for alias, tensor in output_tensors.items()}
    tf.add_to_collection('outputs', json.dumps(output_names))

    output_signatures = {}
    # dict.items() works on both Python 2 and 3; iteritems() is Python 2 only.
    for alias, tensor in output_tensors.items():
      predict_output_tensor = meta_graph_pb2.TensorInfo()
      predict_output_tensor.name = tensor.name
      predict_output_tensor.dtype = tensor.dtype.as_datatype_enum
      output_signatures[alias] = predict_output_tensor
    return input_signatures, output_signatures
示例24
def testAddInputToSignatureDef(self):
    """Exercises add, re-add and overwrite of SignatureDef.inputs entries."""
    signature_def = meta_graph_pb2.SignatureDef()
    signature_def_compare = meta_graph_pb2.SignatureDef()

    # Each step calls the helper with (tensor_name, map_key), then expects
    # `expected_inputs` input entries, no outputs, and the entry at map_key
    # to equal a TensorInfo named tensor_name.
    steps = (
        # A fresh `foo-key` creates the first entry.
        ("foo-name", "foo-key", 1),
        # Repeating the same tensor name and key keeps a single entry.
        ("foo-name", "foo-key", 1),
        # A new `bar-key` adds a second entry.
        ("bar-name", "bar-key", 2),
        # Reusing `foo-key` with an updated tensor name replaces that entry.
        ("bar-name", "foo-key", 2),
    )
    for tensor_name, map_key, expected_inputs in steps:
      bundle_shim._add_input_to_signature_def(tensor_name, map_key,
                                              signature_def)
      self.assertEqual(len(signature_def.inputs), expected_inputs)
      self.assertEqual(len(signature_def.outputs), 0)
      self.assertProtoEquals(
          signature_def.inputs[map_key],
          meta_graph_pb2.TensorInfo(name=tensor_name))

    # Removing both entries must leave an otherwise-untouched SignatureDef.
    del signature_def.inputs["foo-key"]
    del signature_def.inputs["bar-key"]
    self.assertProtoEquals(signature_def, signature_def_compare)
示例25
def testAddOutputToSignatureDef(self):
    """Exercises add, re-add and overwrite of SignatureDef.outputs entries."""
    signature_def = meta_graph_pb2.SignatureDef()
    signature_def_compare = meta_graph_pb2.SignatureDef()

    # Each step calls the helper with (tensor_name, map_key), then expects
    # `expected_outputs` output entries, no inputs, and the entry at map_key
    # to equal a TensorInfo named tensor_name.
    steps = (
        # A fresh `foo-key` creates the first entry.
        ("foo-name", "foo-key", 1),
        # Repeating the same tensor name and key keeps a single entry.
        ("foo-name", "foo-key", 1),
        # A new `bar-key` adds a second entry.
        ("bar-name", "bar-key", 2),
        # Reusing `foo-key` with an updated tensor name replaces that entry.
        ("bar-name", "foo-key", 2),
    )
    for tensor_name, map_key, expected_outputs in steps:
      bundle_shim._add_output_to_signature_def(tensor_name, map_key,
                                               signature_def)
      self.assertEqual(len(signature_def.outputs), expected_outputs)
      self.assertEqual(len(signature_def.inputs), 0)
      self.assertProtoEquals(
          signature_def.outputs[map_key],
          meta_graph_pb2.TensorInfo(name=tensor_name))

    # Removing both entries must leave an otherwise-untouched SignatureDef.
    del signature_def.outputs["foo-key"]
    del signature_def.outputs["bar-key"]
    self.assertProtoEquals(signature_def, signature_def_compare)
示例26
def testConvertSignaturesToSignatureDefs(self):
    """Converts a session bundle's Signatures into default/named SignatureDefs.

    Also re-packs trimmed copies of the Signatures collection: with only the
    default signature, the named SignatureDef must be None, and with only the
    named signatures, the default SignatureDef must be None.
    """
    base_path = test.test_src_dir_path(SESSION_BUNDLE_PATH)
    meta_graph_filename = os.path.join(base_path,
                                       constants.META_GRAPH_DEF_FILENAME)
    metagraph_def = meta_graph.read_meta_graph_file(meta_graph_filename)
    default_signature_def, named_signature_def = (
        bundle_shim._convert_signatures_to_signature_defs(metagraph_def))
    self.assertEqual(default_signature_def.method_name,
                     signature_constants.REGRESS_METHOD_NAME)
    self.assertEqual(len(default_signature_def.inputs), 1)
    self.assertEqual(len(default_signature_def.outputs), 1)
    self.assertProtoEquals(
        default_signature_def.inputs[signature_constants.REGRESS_INPUTS],
        meta_graph_pb2.TensorInfo(name="tf_example:0"))
    self.assertProtoEquals(
        default_signature_def.outputs[signature_constants.REGRESS_OUTPUTS],
        meta_graph_pb2.TensorInfo(name="Identity:0"))
    self.assertEqual(named_signature_def.method_name,
                     signature_constants.PREDICT_METHOD_NAME)
    self.assertEqual(len(named_signature_def.inputs), 1)
    self.assertEqual(len(named_signature_def.outputs), 1)
    self.assertProtoEquals(
        named_signature_def.inputs["x"], meta_graph_pb2.TensorInfo(name="x:0"))
    self.assertProtoEquals(
        named_signature_def.outputs["y"], meta_graph_pb2.TensorInfo(name="y:0"))

    # Now try default signature only
    collection_def = metagraph_def.collection_def
    signatures_proto = manifest_pb2.Signatures()
    signatures = collection_def[constants.SIGNATURES_KEY].any_list.value[0]
    signatures.Unpack(signatures_proto)
    named_only_signatures_proto = manifest_pb2.Signatures()
    named_only_signatures_proto.CopyFrom(signatures_proto)

    default_only_signatures_proto = manifest_pb2.Signatures()
    default_only_signatures_proto.CopyFrom(signatures_proto)
    # ClearField alone empties the map; a preceding .clear() was redundant.
    default_only_signatures_proto.ClearField("named_signatures")
    metagraph_def.collection_def[constants.SIGNATURES_KEY].any_list.value[
        0].Pack(default_only_signatures_proto)
    default_signature_def, named_signature_def = (
        bundle_shim._convert_signatures_to_signature_defs(metagraph_def))
    self.assertEqual(default_signature_def.method_name,
                     signature_constants.REGRESS_METHOD_NAME)
    self.assertIsNone(named_signature_def)

    # Now try named signatures only
    named_only_signatures_proto.ClearField("default_signature")
    metagraph_def.collection_def[constants.SIGNATURES_KEY].any_list.value[
        0].Pack(named_only_signatures_proto)
    default_signature_def, named_signature_def = (
        bundle_shim._convert_signatures_to_signature_defs(metagraph_def))
    self.assertEqual(named_signature_def.method_name,
                     signature_constants.PREDICT_METHOD_NAME)
    self.assertIsNone(default_signature_def)
示例27
def testAddInputToSignatureDef(self):
    """Exercises add, re-add and overwrite of SignatureDef.inputs entries."""
    signature_def = meta_graph_pb2.SignatureDef()
    signature_def_compare = meta_graph_pb2.SignatureDef()

    # Each step calls the helper with (tensor_name, map_key), then expects
    # `expected_inputs` input entries, no outputs, and the entry at map_key
    # to equal a TensorInfo named tensor_name.
    steps = (
        # A fresh `foo-key` creates the first entry.
        ("foo-name", "foo-key", 1),
        # Repeating the same tensor name and key keeps a single entry.
        ("foo-name", "foo-key", 1),
        # A new `bar-key` adds a second entry.
        ("bar-name", "bar-key", 2),
        # Reusing `foo-key` with an updated tensor name replaces that entry.
        ("bar-name", "foo-key", 2),
    )
    for tensor_name, map_key, expected_inputs in steps:
      bundle_shim._add_input_to_signature_def(tensor_name, map_key,
                                              signature_def)
      self.assertEqual(len(signature_def.inputs), expected_inputs)
      self.assertEqual(len(signature_def.outputs), 0)
      self.assertProtoEquals(
          signature_def.inputs[map_key],
          meta_graph_pb2.TensorInfo(name=tensor_name))

    # Removing both entries must leave an otherwise-untouched SignatureDef.
    del signature_def.inputs["foo-key"]
    del signature_def.inputs["bar-key"]
    self.assertProtoEquals(signature_def, signature_def_compare)
示例28
def testAddOutputToSignatureDef(self):
    """Exercises add, re-add and overwrite of SignatureDef.outputs entries."""
    signature_def = meta_graph_pb2.SignatureDef()
    signature_def_compare = meta_graph_pb2.SignatureDef()

    # Each step calls the helper with (tensor_name, map_key), then expects
    # `expected_outputs` output entries, no inputs, and the entry at map_key
    # to equal a TensorInfo named tensor_name.
    steps = (
        # A fresh `foo-key` creates the first entry.
        ("foo-name", "foo-key", 1),
        # Repeating the same tensor name and key keeps a single entry.
        ("foo-name", "foo-key", 1),
        # A new `bar-key` adds a second entry.
        ("bar-name", "bar-key", 2),
        # Reusing `foo-key` with an updated tensor name replaces that entry.
        ("bar-name", "foo-key", 2),
    )
    for tensor_name, map_key, expected_outputs in steps:
      bundle_shim._add_output_to_signature_def(tensor_name, map_key,
                                               signature_def)
      self.assertEqual(len(signature_def.outputs), expected_outputs)
      self.assertEqual(len(signature_def.inputs), 0)
      self.assertProtoEquals(
          signature_def.outputs[map_key],
          meta_graph_pb2.TensorInfo(name=tensor_name))

    # Removing both entries must leave an otherwise-untouched SignatureDef.
    del signature_def.outputs["foo-key"]
    del signature_def.outputs["bar-key"]
    self.assertProtoEquals(signature_def, signature_def_compare)
示例29
def testConvertSignaturesToSignatureDefs(self):
    """Converts a session bundle's Signatures into default/named SignatureDefs.

    Also re-packs trimmed copies of the Signatures collection: with only the
    default signature, the named SignatureDef must be None, and with only the
    named signatures, the default SignatureDef must be None.
    """
    base_path = test.test_src_dir_path(SESSION_BUNDLE_PATH)
    meta_graph_filename = os.path.join(base_path,
                                       constants.META_GRAPH_DEF_FILENAME)
    metagraph_def = meta_graph.read_meta_graph_file(meta_graph_filename)
    default_signature_def, named_signature_def = (
        bundle_shim._convert_signatures_to_signature_defs(metagraph_def))
    self.assertEqual(default_signature_def.method_name,
                     signature_constants.REGRESS_METHOD_NAME)
    self.assertEqual(len(default_signature_def.inputs), 1)
    self.assertEqual(len(default_signature_def.outputs), 1)
    self.assertProtoEquals(
        default_signature_def.inputs[signature_constants.REGRESS_INPUTS],
        meta_graph_pb2.TensorInfo(name="tf_example:0"))
    self.assertProtoEquals(
        default_signature_def.outputs[signature_constants.REGRESS_OUTPUTS],
        meta_graph_pb2.TensorInfo(name="Identity:0"))
    self.assertEqual(named_signature_def.method_name,
                     signature_constants.PREDICT_METHOD_NAME)
    self.assertEqual(len(named_signature_def.inputs), 1)
    self.assertEqual(len(named_signature_def.outputs), 1)
    self.assertProtoEquals(
        named_signature_def.inputs["x"], meta_graph_pb2.TensorInfo(name="x:0"))
    self.assertProtoEquals(
        named_signature_def.outputs["y"], meta_graph_pb2.TensorInfo(name="y:0"))

    # Now try default signature only
    collection_def = metagraph_def.collection_def
    signatures_proto = manifest_pb2.Signatures()
    signatures = collection_def[constants.SIGNATURES_KEY].any_list.value[0]
    signatures.Unpack(signatures_proto)
    named_only_signatures_proto = manifest_pb2.Signatures()
    named_only_signatures_proto.CopyFrom(signatures_proto)

    default_only_signatures_proto = manifest_pb2.Signatures()
    default_only_signatures_proto.CopyFrom(signatures_proto)
    # ClearField alone empties the map; a preceding .clear() was redundant.
    default_only_signatures_proto.ClearField("named_signatures")
    metagraph_def.collection_def[constants.SIGNATURES_KEY].any_list.value[
        0].Pack(default_only_signatures_proto)
    default_signature_def, named_signature_def = (
        bundle_shim._convert_signatures_to_signature_defs(metagraph_def))
    self.assertEqual(default_signature_def.method_name,
                     signature_constants.REGRESS_METHOD_NAME)
    self.assertIsNone(named_signature_def)

    # Now try named signatures only
    named_only_signatures_proto.ClearField("default_signature")
    metagraph_def.collection_def[constants.SIGNATURES_KEY].any_list.value[
        0].Pack(named_only_signatures_proto)
    default_signature_def, named_signature_def = (
        bundle_shim._convert_signatures_to_signature_defs(metagraph_def))
    self.assertEqual(named_signature_def.method_name,
                     signature_constants.PREDICT_METHOD_NAME)
    self.assertIsNone(default_signature_def)