Python源码示例:tensorflow.core.protobuf.meta_graph_pb2.MetaGraphDef()
示例1
def _MetaGraphFilename(self, checkpoint_filename, meta_graph_suffix="meta"):
    """Returns the meta graph filename that accompanies a checkpoint.

    A trailing shard marker of the form ``-<digits or ?>-of-<digits>`` is
    stripped from ``checkpoint_filename`` before the suffix is appended.

    Args:
      checkpoint_filename: Name of the checkpoint file.
      meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.

    Returns:
      MetaGraph file name: ``<unsharded basename>.<meta_graph_suffix>``.
    """
    # Sharded names look like model.ckpt-123456-?????-of-00005 or
    # model.ckpt-123456-00001-of-00002; both reduce to model.ckpt-123456.
    shard_marker = r"-[\d\?]+-of-\d+$"
    unsharded = re.sub(shard_marker, "", checkpoint_filename)
    return ".".join((unsharded, meta_graph_suffix))
示例2
def _MetaGraphFilename(self, checkpoint_filename, meta_graph_suffix="meta"):
    """Maps a (possibly sharded) checkpoint filename to its MetaGraph filename.

    Args:
      checkpoint_filename: Name of the checkpoint file.
      meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.

    Returns:
      MetaGraph file name.
    """
    # Matches a trailing shard marker such as "-?????-of-00005" or
    # "-00001-of-00002", which must not appear in the meta graph name.
    shard_regex = re.compile(r"-[\d\?]+-of-\d+$")
    base = shard_regex.sub("", checkpoint_filename)
    pieces = [base, meta_graph_suffix]
    return ".".join(pieces)
示例3
def _MetaGraphFilename(self, checkpoint_filename, meta_graph_suffix="meta"):
    """Builds the MetaGraph filename for ``checkpoint_filename``.

    Any trailing ``-<digits or ?>-of-<digits>`` shard marker (e.g. in
    model.ckpt-123456-?????-of-00005) is removed first, then
    ``.<meta_graph_suffix>`` is appended.

    Args:
      checkpoint_filename: Name of the checkpoint file.
      meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.

    Returns:
      MetaGraph file name.
    """
    return ".".join(
        [re.sub(r"-[\d\?]+-of-\d+$", "", checkpoint_filename),
         meta_graph_suffix])
示例4
def testAddCollectionDefFails(self):
    """Checks that `_add_collection_def` skips collections it cannot export.

    Two cases are exercised against the same fresh `MetaGraphDef`:
      * a collection keyed by a non-string object (the saver itself);
      * a collection whose items mix an int and a float.
    After each attempt the proto must still have no `collection_def` entries.
    """
    with self.test_session():
        var = tf.Variable(10.0, name="v0")
        saver = tf.train.Saver({"v0": var})
        mgd = meta_graph_pb2.MetaGraphDef()

        # Case 1: the collection key is the saver object, an unsupported key.
        tf.add_to_collection(saver, 3)
        saver._add_collection_def(mgd, saver)
        self.assertEqual(len(mgd.collection_def), 0)

        # Case 2: items of different types (3 and 3.5) in one collection do
        # not match a single expected item type, so nothing is added.
        tf.add_to_collection("int_collection", 3)
        tf.add_to_collection("int_collection", 3.5)
        saver._add_collection_def(mgd, "int_collection")
        self.assertEqual(len(mgd.collection_def), 0)
示例5
def testSliceVariable(self):
    """Round-trips a Saver over a sliced variable through a meta graph file.

    Exports the meta graph of a saver whose second variable carries slice
    info, re-imports that file into a fresh graph, re-exports it, and checks
    that both `MetaGraphDef` protos are equal.
    """
    meta_file = os.path.join(_TestDir("slice_saver"), "metafile")

    with self.test_session():
        first_var = tf.Variable([20.0], name="v1")
        second_var = tf.Variable([20.0], name="v2")
        second_var._set_save_slice_info(
            tf.Variable.SaveSliceInfo("v1", [1], [0], [1]))
        # Distinct checkpoint keys ("first"/"second") keep the entries apart.
        saver = tf.train.Saver({"first": first_var, "second": second_var})
        tf.global_variables_initializer().run()
        exported = saver.export_meta_graph(meta_file)

    with tf.Graph().as_default():
        restored_saver = tf.train.import_meta_graph(meta_file)
        self.assertIsNotNone(restored_saver)
        re_exported = restored_saver.export_meta_graph()
        # Exporting the imported graph again must reproduce the original.
        self.assertProtoEquals(exported, re_exported)
示例6
def _load_meta(model_network_path):
    """Load the graph stored inside a TensorFlow meta graph file.

    Parameters
    ----------
    model_network_path: str
        Path of the protobuf meta graph file on disk.

    Returns
    -------
    graph: the ``graph_def`` field of the parsed ``MetaGraphDef`` proto.
    """
    from tensorflow.core.protobuf import meta_graph_pb2
    from mmdnn.conversion.common.IR.IR_graph import load_protobuf_from_file

    # load_protobuf_from_file (mmdnn) fills the empty proto from the file.
    parsed_meta_graph = meta_graph_pb2.MetaGraphDef()
    load_protobuf_from_file(parsed_meta_graph, model_network_path)
    graph_def = parsed_meta_graph.graph_def
    message = "Tensorflow model file [%s] loaded successfully." % model_network_path
    print(message)
    return graph_def
示例7
def load_tfma_version(
    signature_def: tf.compat.v1.MetaGraphDef.SignatureDefEntry,
    graph: tf.Graph,
) -> types.TensorType:
    """Look up the TFMA version tensor referenced by `signature_def.inputs`.

    Args:
      signature_def: SignatureDef whose `inputs` map is searched.
      graph: TensorFlow graph in which the referenced tensor is resolved.

    Returns:
      The tensor that `constants.SIGNATURE_DEF_TFMA_VERSION_KEY` refers to.

    Raises:
      ValueError: If the version key is absent from `signature_def.inputs`.
    """
    version_key = constants.SIGNATURE_DEF_TFMA_VERSION_KEY
    inputs = signature_def.inputs
    if version_key in inputs:
        return tf.compat.v1.saved_model.utils.get_tensor_from_tensor_info(
            inputs[version_key], graph)
    raise ValueError('tfma version not found in signature_def: %s' %
                     signature_def)
示例8
def load_iterator_initializer_name(
    signature_def: tf.compat.v1.MetaGraphDef.SignatureDefEntry,
    graph: tf.Graph,
) -> Optional[types.TensorType]:
    """Resolve the iterator-initializer-name tensor from `signature_def.inputs`.

    Args:
      signature_def: SignatureDef whose `inputs` map is searched.
      graph: TensorFlow graph in which the referenced tensor is resolved.

    Returns:
      The tensor holding the iterator initializer op name, or None when the
      signature has no `constants.SIGNATURE_DEF_ITERATOR_INITIALIZER_KEY` input.
    """
    initializer_key = constants.SIGNATURE_DEF_ITERATOR_INITIALIZER_KEY
    if initializer_key not in signature_def.inputs:
        # The signature does not reference an iterator initializer.
        return None
    tensor_info = signature_def.inputs[initializer_key]
    return tf.compat.v1.saved_model.utils.get_tensor_from_tensor_info(
        tensor_info, graph)
示例9
def load_predictions(signature_def: tf.compat.v1.MetaGraphDef.SignatureDefEntry,
                     graph: tf.Graph) -> Dict[Text, types.TensorType]:
    """Resolve the prediction tensors listed in `signature_def.outputs`.

    Args:
      signature_def: SignatureDef whose `outputs` map is searched.
      graph: TensorFlow graph in which the referenced tensors are resolved.

    Returns:
      An OrderedDict from prediction key to tensor, in the order returned by
      `extract_signature_inputs_or_outputs_with_prefix`.
    """
    # The ordering is simply the one the extracted predictions come in; a
    # single-key dict is still used so predictions match features/labels.
    prediction_infos = extract_signature_inputs_or_outputs_with_prefix(
        constants.PREDICTIONS_NAME, signature_def.outputs,
        util.default_dict_key(constants.PREDICTIONS_NAME))
    return collections.OrderedDict(
        (name, tf.compat.v1.saved_model.utils.get_tensor_from_tensor_info(
            info, graph))
        for name, info in prediction_infos.items())
示例10
def get_node_wrapped_tensor_info(meta_graph_def: meta_graph_pb2.MetaGraphDef,
                                 path: Text) -> any_pb2.Any:
    """Fetch the single Any-wrapped TensorInfo stored under collection `path`.

    Args:
      meta_graph_def: MetaGraphDef whose `collection_def` map is searched.
      path: Name of the collection holding the node reference.

    Returns:
      The only element of `meta_graph_def.collection_def[path].any_list.value`.

    Raises:
      KeyError: `meta_graph_def.collection_def` has no entry named `path`.
      ValueError: That collection's `any_list` does not hold exactly one item.
    """
    collection_map = meta_graph_def.collection_def
    if path not in collection_map:
        raise KeyError('could not find path %s in collection defs. meta_graph_def '
                       'was %s' % (path, meta_graph_def))
    any_values = collection_map[path].any_list.value
    if len(any_values) != 1:
        raise ValueError(
            'any_list should be of length 1. path was %s, any_list was: %s.' %
            (path, any_values))
    return any_values[0]
示例11
def get_node_in_graph(meta_graph_def: meta_graph_pb2.MetaGraphDef, path: Text,
                      graph: tf.Graph) -> types.TensorType:
    """Resolve the node referenced by collection `path` to a node in `graph`.

    Args:
      meta_graph_def: MetaGraphDef containing the CollectionDefs to extract the
        node name from.
      path: Name of the collection containing the node name.
      graph: TensorFlow graph to lookup the nodes in.

    Returns:
      The node decoded by `encoding.decode_tensor_node` from the TensorInfo
      returned by `get_node_wrapped_tensor_info`.

    Raises:
      KeyError, ValueError: Propagated from `get_node_wrapped_tensor_info`.
    """
    wrapped_info = get_node_wrapped_tensor_info(meta_graph_def, path)
    return encoding.decode_tensor_node(graph, wrapped_info)
示例12
def _MetaGraphFilename(self, checkpoint_filename, meta_graph_suffix="meta"):
    """Returns the MetaGraph filename paired with a checkpoint file.

    Args:
      checkpoint_filename: Name of the checkpoint file; a trailing shard
        marker such as ``-00001-of-00002`` or ``-?????-of-00005`` is dropped.
      meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.

    Returns:
      MetaGraph file name.
    """
    stem, _num_stripped = re.subn(r"-[\d\?]+-of-\d+$", "", checkpoint_filename)
    return ".".join([stem, meta_graph_suffix])
示例13
def add_meta_graph(self, meta_graph_def, global_step=None):
    """Adds a `MetaGraphDef` to the event file.

    The proto is serialized and recorded as the `meta_graph_def` field of an
    `Event`, which allows running the graph via `saver.import_meta_graph()`.

    Args:
      meta_graph_def: A `MetaGraphDef` object, often as returned by
        `saver.export_meta_graph()`.
      global_step: Number. Optional global step counter to record with the
        graph.

    Raises:
      TypeError: If `meta_graph_def` is not an instance of `MetaGraphDef`.
    """
    if not isinstance(meta_graph_def, meta_graph_pb2.MetaGraphDef):
        raise TypeError("meta_graph_def must be type MetaGraphDef, saw type: %s" %
                        type(meta_graph_def))
    self._add_event(
        event_pb2.Event(meta_graph_def=meta_graph_def.SerializeToString()),
        global_step)
示例14
def _MetaGraphFilename(self, checkpoint_filename, meta_graph_suffix="meta"):
    """Derives the `MetaGraphDef` filename from a checkpoint filename.

    Args:
      checkpoint_filename: Name of the checkpoint file.
      meta_graph_suffix: Suffix for `MetaGraphDef` file. Defaults to 'meta'.

    Returns:
      MetaGraph file name.
    """
    # Drop a trailing shard marker (model.ckpt-123456-00001-of-00002 ->
    # model.ckpt-123456) so all shards share one meta graph filename.
    without_shard = re.sub(
        pattern=r"-[\d\?]+-of-\d+$", repl="", string=checkpoint_filename)
    name_parts = (without_shard, meta_graph_suffix)
    return ".".join(name_parts)
示例15
def export_meta_graph(self,
                      filename=None,
                      collection_list=None,
                      as_text=False,
                      export_scope=None,
                      clear_devices=False):
    """Writes `MetaGraphDef` to save_path/filename.

    Delegates to the free function `export_meta_graph`, supplying the current
    default graph (serialized with `add_shapes=True`) and `self.saver_def`.

    Args:
      filename: Optional meta_graph filename including the path.
      collection_list: List of string keys to collect.
      as_text: If `True`, writes the meta_graph as an ASCII proto.
      export_scope: Optional `string`. Name scope to remove.
      clear_devices: Whether or not to clear the device field for an `Operation`
        or `Tensor` during export.

    Returns:
      A `MetaGraphDef` proto.
    """
    default_graph_def = ops.get_default_graph().as_graph_def(add_shapes=True)
    return export_meta_graph(
        filename=filename,
        graph_def=default_graph_def,
        saver_def=self.saver_def,
        collection_list=collection_list,
        as_text=as_text,
        export_scope=export_scope,
        clear_devices=clear_devices)
示例16
def _add_collection_def(meta_graph_def, key, export_scope=None):
    """Adds the collection named ``key`` to a MetaGraphDef protocol buffer.

    Delegates to ``meta_graph.add_collection_def``; its return value is
    discarded, so this returns None.

    Args:
      meta_graph_def: MetaGraphDef protocol buffer.
      key: One of the GraphKeys or user-defined string.
      export_scope: Optional `string`. Name scope to remove.
    """
    meta_graph.add_collection_def(
        meta_graph_def,
        key,
        export_scope=export_scope,
    )
示例17
def stripped_op_list_for_graph(graph_def):
    """Collect the stripped OpDefs for ops used by a graph.

    This computes the `stripped_op_list` field of `MetaGraphDef` and similar
    protos. A consumer can pass the result to the C++ function
    `RemoveNewDefaultAttrsFromGraphDef` to improve forwards compatibility.

    Args:
      graph_def: A `GraphDef` proto, as from `graph.as_graph_def()`.

    Returns:
      An `OpList` of the registered ops used by the graph, sorted by name.

    Raises:
      ValueError: If an unregistered op (outside the allowed internal set) is
        used.
    """
    # Python mirror of the C++ StrippedOpListForGraph; the Python and C++ op
    # registries can differ, so the logic is duplicated rather than wrapped.
    # TODO(irving): Support taking graphs directly.
    used_ops = ops_used_by_graph_def(graph_def)
    registered_ops = op_def_registry.get_registered_ops()
    # Internal ops used by functions are never registered; tolerate them.
    # TODO(irving): Do something better here.
    unregistered_ok = ("_Arg", "_Retval", "_ListToArray", "_ArrayToList")
    for op_name in used_ops:
        if op_name in registered_ops or op_name in unregistered_ok:
            continue
        raise ValueError("Op %s is used by the graph, but is not registered" % op_name)
    stripped = [registered_ops[name] for name in sorted(used_ops)
                if name in registered_ops]
    return op_def_pb2.OpList(op=stripped)
示例18
def read_meta_graph_file(filename):
    """Reads a file containing `MetaGraphDef` and returns the protocol buffer.

    The raw bytes are first parsed as a binary (serialized) proto. If that
    fails, they are decoded as UTF-8 and parsed as a text-format proto.

    Args:
      filename: `meta_graph_def` filename including the path.

    Returns:
      A `MetaGraphDef` protocol buffer.

    Raises:
      IOError: If the file doesn't exist, or cannot be successfully parsed.
    """
    meta_graph_def = meta_graph_pb2.MetaGraphDef()
    if not file_io.file_exists(filename):
        raise IOError("File %s does not exist." % filename)
    # Read the bytes inside a `with` block so the handle is closed promptly
    # (the previous bare FileIO(...).read() leaked it).
    with file_io.FileIO(filename, "rb") as f:
        file_content = f.read()
    # First try to read it as a binary file.
    try:
        meta_graph_def.ParseFromString(file_content)
        return meta_graph_def
    except Exception:  # pylint: disable=broad-except
        # Deliberate best effort: any binary-parse failure falls through to
        # the text-format attempt below.
        pass
    # Next try to read it as a text file.
    try:
        text_format.Merge(file_content.decode("utf-8"), meta_graph_def)
    except text_format.ParseError as e:
        raise IOError("Cannot parse file %s: %s." % (filename, str(e)))
    return meta_graph_def
示例19
def MetaGraph(self):
    """Return the metagraph definition, if there is one.

    A fresh `MetaGraphDef` is parsed from the serialized `self._meta_graph`
    on every call.

    Raises:
      ValueError: If there is no metagraph for this run.

    Returns:
      The `meta_graph_def` proto.
    """
    serialized = self._meta_graph
    if serialized is None:
        raise ValueError('There is no metagraph in this EventAccumulator')
    parsed = meta_graph_pb2.MetaGraphDef()
    parsed.ParseFromString(serialized)
    return parsed
示例20
def _convert_signatures_to_signature_defs(metagraph_def):
    """Produce default and named upconverted SignatureDef objects from Signatures.

    The legacy `Signatures` proto is unpacked from the first Any in the
    `legacy_constants.SIGNATURES_KEY` collection of `metagraph_def`.

    Args:
      metagraph_def: object of type meta_graph_pb2.MetaGraphDef containing legacy
        format Session Bundle signatures

    Returns:
      A `(default_signature_def, named_signature_def)` tuple. Each entry is a
      SignatureDef upconverted from the corresponding legacy signatures, or
      None when that kind of signature is not converted.
    """
    signatures_proto = manifest_pb2.Signatures()
    packed_signatures = metagraph_def.collection_def[
        legacy_constants.SIGNATURES_KEY].any_list.value[0]
    packed_signatures.Unpack(signatures_proto)

    if signatures_proto.HasField("default_signature"):
        default_signature_def = _convert_default_signature_to_signature_def(
            signatures_proto)
    else:
        default_signature_def = None

    # NOTE(review): named signatures are converted only when MORE than one is
    # present — presumably because the converter needs separate inputs and
    # outputs entries; confirm this threshold is intended.
    if len(signatures_proto.named_signatures) > 1:
        named_signature_def = _convert_named_signatures_to_signature_def(
            signatures_proto)
    else:
        named_signature_def = None

    return default_signature_def, named_signature_def
示例21
def export_meta_graph(self,
                      filename=None,
                      collection_list=None,
                      as_text=False,
                      export_scope=None,
                      clear_devices=False):
    """Writes `MetaGraphDef` to save_path/filename.

    All arguments are forwarded to the free function `export_meta_graph`,
    together with the default graph's `GraphDef` (built with
    `add_shapes=True`) and this object's `saver_def`.

    Args:
      filename: Optional meta_graph filename including the path.
      collection_list: List of string keys to collect.
      as_text: If `True`, writes the meta_graph as an ASCII proto.
      export_scope: Optional `string`. Name scope to remove.
      clear_devices: Whether or not to clear the device field for an `Operation`
        or `Tensor` during export.

    Returns:
      A `MetaGraphDef` proto.
    """
    export_kwargs = dict(
        filename=filename,
        graph_def=ops.get_default_graph().as_graph_def(add_shapes=True),
        saver_def=self.saver_def,
        collection_list=collection_list,
        as_text=as_text,
        export_scope=export_scope,
        clear_devices=clear_devices,
    )
    return export_meta_graph(**export_kwargs)
示例22
def _add_collection_def(meta_graph_def, key, export_scope=None):
    """Adds a collection to MetaGraphDef protocol buffer.

    Thin wrapper over ``meta_graph.add_collection_def``; returns None.

    Args:
      meta_graph_def: MetaGraphDef protocol buffer.
      key: One of the GraphKeys or user-defined string.
      export_scope: Optional `string`. Name scope to remove.
    """
    add_fn = meta_graph.add_collection_def
    add_fn(meta_graph_def, key, export_scope=export_scope)
示例23
def MetaGraph(self):
    """Return the metagraph definition, if there is one.

    Raises:
      ValueError: If there is no metagraph for this run.

    Returns:
      The `meta_graph_def` proto, freshly parsed from `self._meta_graph`.
    """
    if self._meta_graph is not None:
        proto = meta_graph_pb2.MetaGraphDef()
        proto.ParseFromString(self._meta_graph)
        return proto
    raise ValueError('There is no metagraph in this EventAccumulator')
示例24
def stripped_op_list_for_graph(graph_def):
    """Collect the stripped OpDefs for ops used by a graph.

    The result is suitable for the `stripped_op_list` field of `MetaGraphDef`
    and similar protos. A consumer can pass it to the C++ function
    `RemoveNewDefaultAttrsFromGraphDef` to improve forwards compatibility.

    Args:
      graph_def: A `GraphDef` proto, as from `graph.as_graph_def()`.

    Returns:
      An `OpList` of the registered ops used by the graph, sorted by name.

    Raises:
      ValueError: If an unregistered op is used (the first one found in
        iteration order is reported).
    """
    # Python counterpart of the C++ StrippedOpListForGraph. The two op
    # registries can differ, so the logic is not shared via swig.
    # TODO(irving): Support taking graphs directly.
    used_ops = ops_used_by_graph_def(graph_def)
    registered_ops = op_def_registry.get_registered_ops()
    # Function-internal ops are not registered and must be tolerated.
    # TODO(irving): Do something better here.
    allowed_unregistered = ("_Arg", "_Retval", "_ListToArray", "_ArrayToList")
    missing = [op for op in used_ops
               if op not in registered_ops and op not in allowed_unregistered]
    if missing:
        raise ValueError("Op %s is used by the graph, but is not registered" % missing[0])
    ordered_defs = [registered_ops[op] for op in sorted(used_ops)
                    if op in registered_ops]
    return op_def_pb2.OpList(op=ordered_defs)
示例25
def read_meta_graph_file(filename):
    """Reads a file containing `MetaGraphDef` and returns the protocol buffer.

    The raw bytes are first parsed as a binary (serialized) proto. If that
    fails, they are decoded as UTF-8 and parsed as a text-format proto.

    Args:
      filename: `meta_graph_def` filename including the path.

    Returns:
      A `MetaGraphDef` protocol buffer.

    Raises:
      IOError: If the file doesn't exist, or cannot be successfully parsed.
    """
    meta_graph_def = meta_graph_pb2.MetaGraphDef()
    if not file_io.file_exists(filename):
        raise IOError("File %s does not exist." % filename)
    # Always read raw bytes: both ParseFromString and the `.decode("utf-8")`
    # below require `bytes`, which `read_file_to_string` does not guarantee
    # (it returns `str` in its default text mode on some TF versions). The
    # `with` block closes the handle promptly.
    with file_io.FileIO(filename, "rb") as f:
        file_content = f.read()
    # First try to read it as a binary file.
    try:
        meta_graph_def.ParseFromString(file_content)
        return meta_graph_def
    except Exception:  # pylint: disable=broad-except
        # Deliberate best effort: any binary-parse failure falls through to
        # the text-format attempt below.
        pass
    # Next try to read it as a text file.
    try:
        text_format.Merge(file_content.decode("utf-8"), meta_graph_def)
    except text_format.ParseError as e:
        raise IOError("Cannot parse file %s: %s." % (filename, str(e)))
    return meta_graph_def
示例26
def _convert_signatures_to_signature_defs(metagraph_def):
    """Produce default and named upconverted SignatureDef objects from Signatures.

    Args:
      metagraph_def: object of type meta_graph_pb2.MetaGraphDef containing legacy
        format Session Bundle signatures

    Returns:
      A `(default_signature_def, named_signature_def)` tuple of upconverted
      SignatureDefs; an entry is None when that kind of signature is not
      converted.
    """
    signatures_proto = manifest_pb2.Signatures()
    packed = metagraph_def.collection_def[
        legacy_constants.SIGNATURES_KEY].any_list.value[0]
    packed.Unpack(signatures_proto)

    default_signature_def = (
        _convert_default_signature_to_signature_def(signatures_proto)
        if signatures_proto.HasField("default_signature") else None)
    # NOTE(review): requires more than one named signature — presumably the
    # converter needs distinct inputs/outputs entries; confirm intent.
    named_signature_def = (
        _convert_named_signatures_to_signature_def(signatures_proto)
        if len(signatures_proto.named_signatures) > 1 else None)
    return default_signature_def, named_signature_def
示例27
def _get_multi_gpu_meta_graph(single_gpu_meta_graph_def, op_names_to_replicate,
                              op_names_to_share, num_replicas,
                              tensor_or_op_name_to_replica_names):
    """Builds a multi-GPU `MetaGraphDef` from a single-GPU one.

    The result is a new `MetaGraphDef` that copies every field of
    `single_gpu_meta_graph_def` and then has `graph_def` replaced by the graph
    from `construct_multi_gpu_graph_def`.

    Args:
      single_gpu_meta_graph_def: Source `MetaGraphDef`.
      op_names_to_replicate: Passed through to `construct_multi_gpu_graph_def`.
      op_names_to_share: Passed through to `construct_multi_gpu_graph_def`.
      num_replicas: Passed through to `construct_multi_gpu_graph_def`.
      tensor_or_op_name_to_replica_names: Passed through to
        `construct_multi_gpu_graph_def`.

    Returns:
      The newly constructed multi-GPU `MetaGraphDef`.
    """
    replicated_graph_def = construct_multi_gpu_graph_def(
        single_gpu_meta_graph_def.graph_def,
        op_names_to_replicate,
        op_names_to_share,
        num_replicas,
        tensor_or_op_name_to_replica_names)

    result = meta_graph_pb2.MetaGraphDef()
    result.CopyFrom(single_gpu_meta_graph_def)
    # Replace only the graph; all other copied fields are kept as-is.
    graph_field = result.graph_def
    graph_field.Clear()
    graph_field.CopyFrom(replicated_graph_def)
    return result
示例28
def _get_multi_gpu_meta_graph(single_gpu_meta_graph_def, op_names_to_replicate,
                              op_names_to_share, num_replicas,
                              tensor_or_op_name_to_replica_names):
    """Return a copy of a single-GPU `MetaGraphDef` with a replicated graph.

    Args:
      single_gpu_meta_graph_def: `MetaGraphDef` to copy; its `graph_def` is
        the input to `construct_multi_gpu_graph_def`.
      op_names_to_replicate: Forwarded to `construct_multi_gpu_graph_def`.
      op_names_to_share: Forwarded to `construct_multi_gpu_graph_def`.
      num_replicas: Forwarded to `construct_multi_gpu_graph_def`.
      tensor_or_op_name_to_replica_names: Forwarded to
        `construct_multi_gpu_graph_def`.

    Returns:
      A new `MetaGraphDef` whose `graph_def` is the multi-GPU graph.
    """
    source_graph = single_gpu_meta_graph_def.graph_def
    multi_graph = construct_multi_gpu_graph_def(
        source_graph, op_names_to_replicate, op_names_to_share, num_replicas,
        tensor_or_op_name_to_replica_names)

    multi_meta = meta_graph_pb2.MetaGraphDef()
    multi_meta.CopyFrom(single_gpu_meta_graph_def)
    multi_meta.graph_def.Clear()
    multi_meta.graph_def.CopyFrom(multi_graph)
    return multi_meta
示例29
def export_meta_graph(self,
                      filename=None,
                      collection_list=None,
                      as_text=False,
                      export_scope=None,
                      clear_devices=False):
    """Writes `MetaGraphDef` to save_path/filename.

    Args:
      filename: Optional meta_graph filename including the path.
      collection_list: List of string keys to collect.
      as_text: If `True`, writes the meta_graph as an ASCII proto.
      export_scope: Optional `string`. Name scope to remove.
      clear_devices: Whether or not to clear the device field for an `Operation`
        or `Tensor` during export.

    Returns:
      A `MetaGraphDef` proto, as produced by the free function
      `export_meta_graph` for the current default graph and `self.saver_def`.
    """
    current_graph = ops.get_default_graph()
    # add_shapes=True asks for shape information in the serialized GraphDef.
    graph_def_with_shapes = current_graph.as_graph_def(add_shapes=True)
    return export_meta_graph(filename=filename,
                             graph_def=graph_def_with_shapes,
                             saver_def=self.saver_def,
                             collection_list=collection_list,
                             as_text=as_text,
                             export_scope=export_scope,
                             clear_devices=clear_devices)
示例30
def _add_collection_def(meta_graph_def, key, export_scope=None):
    """Records collection ``key`` into ``meta_graph_def``.

    All work is done by ``meta_graph.add_collection_def``; nothing is
    returned.

    Args:
      meta_graph_def: MetaGraphDef protocol buffer.
      key: One of the GraphKeys or user-defined string.
      export_scope: Optional `string`. Name scope to remove.
    """
    meta_graph.add_collection_def(meta_graph_def,
                                  key,
                                  export_scope=export_scope)