Python源码示例:tensorflow.contrib.tpu.python.tpu.tpu.TPUEstimatorSpec()

示例1
def estimator_spec_predict(self, features, mesh, mesh_impl, use_tpu):
    """Build the PREDICT-mode estimator spec from the model's samples.

    Args:
      features: dict of input tensors. "inputs" (when present and not None)
        and "infer_targets" (when present and not None) are copied into the
        predictions; "inputs" is also read unconditionally when
        self.has_input is true.
      mesh: the mtf mesh on which self.sample builds the sampling graph.
      mesh_impl: mesh implementation used by mtf.Lowering for `mesh`.
      use_tpu: if True, summaries are removed via t2t_model.remove_summaries()
        and a tpu_estimator.TPUEstimatorSpec is returned; otherwise a
        tf.estimator.EstimatorSpec is returned.

    Returns:
      A PREDICT-mode spec whose predictions dict contains "outputs" plus any
      passthrough features, with an mtf.MtfRestoreHook for the lowering as
      its only prediction hook.
    """
    samples = mtf.anonymize(self.sample(features, mesh))
    lowering = mtf.Lowering(mesh.graph, {mesh: mesh_impl})
    sampled = lowering.export_to_tf_tensor(samples)

    if self.has_input:
      # Keep the first tf.shape(inputs)[0] entries along axis 0 and all
      # entries along the remaining axes. NOTE(review): presumably this trims
      # batch padding added upstream -- confirm.
      rank = len(sampled.shape.as_list())
      batch = tf.shape(features["inputs"])[0]
      begin = [0] * rank
      size = [batch] + [-1] * (rank - 1)
      sampled = tf.slice(sampled, begin, size)

    predictions = {"outputs": sampled}
    # Pass selected features straight through to the predictions when set.
    for key in ("infer_targets", "inputs"):
      value = features.get(key)
      if value is not None:
        predictions[key] = value

    if use_tpu:
      # NOTE(review): presumably summaries are stripped because they are not
      # usable in TPU prediction -- confirm against t2t_model.
      t2t_model.remove_summaries()
    hooks = [mtf.MtfRestoreHook(lowering)]
    mode = tf.estimator.ModeKeys.PREDICT

    if not use_tpu:
      return tf.estimator.EstimatorSpec(
          mode,
          predictions=predictions,
          prediction_hooks=hooks)
    return tpu_estimator.TPUEstimatorSpec(
        mode=mode,
        predictions=predictions,
        prediction_hooks=hooks)
示例2
def estimator_spec_predict(self, features, mesh, mesh_impl, use_tpu):
    """Build the PREDICT-mode estimator spec from the model's samples.

    Args:
      features: dict of input tensors. "inputs" (when present and not None)
        and "infer_targets" (when present and not None) are copied into the
        predictions; "inputs" is also read unconditionally when
        self.has_input is true.
      mesh: the mtf mesh on which self.sample builds the sampling graph.
      mesh_impl: mesh implementation used by mtf.Lowering for `mesh`.
      use_tpu: if True, summaries are removed via t2t_model.remove_summaries()
        and a tpu_estimator.TPUEstimatorSpec is returned; otherwise a
        tf.estimator.EstimatorSpec is returned.

    Returns:
      A PREDICT-mode spec whose predictions dict contains "outputs" plus any
      passthrough features, with an mtf.MtfRestoreHook for the lowering as
      its only prediction hook.
    """
    samples = mtf.anonymize(self.sample(features, mesh))
    lowering = mtf.Lowering(mesh.graph, {mesh: mesh_impl})
    sampled = lowering.export_to_tf_tensor(samples)

    if self.has_input:
      # Keep the first tf.shape(inputs)[0] entries along axis 0 and all
      # entries along the remaining axes. NOTE(review): presumably this trims
      # batch padding added upstream -- confirm.
      rank = len(sampled.shape.as_list())
      batch = tf.shape(features["inputs"])[0]
      begin = [0] * rank
      size = [batch] + [-1] * (rank - 1)
      sampled = tf.slice(sampled, begin, size)

    predictions = {"outputs": sampled}
    # Pass selected features straight through to the predictions when set.
    for key in ("infer_targets", "inputs"):
      value = features.get(key)
      if value is not None:
        predictions[key] = value

    if use_tpu:
      # NOTE(review): presumably summaries are stripped because they are not
      # usable in TPU prediction -- confirm against t2t_model.
      t2t_model.remove_summaries()
    hooks = [mtf.MtfRestoreHook(lowering)]
    mode = tf.estimator.ModeKeys.PREDICT

    if not use_tpu:
      return tf.estimator.EstimatorSpec(
          mode,
          predictions=predictions,
          prediction_hooks=hooks)
    return tpu_estimator.TPUEstimatorSpec(
        mode=mode,
        predictions=predictions,
        prediction_hooks=hooks)
示例3
def estimator_spec_eval(
      self, features, logits, labels, loss, restore_hook, use_tpu):
    """Construct EstimatorSpec for EVAL mode.

    Builds the evaluation metrics for the configured problem (or, for a
    multiproblem, for every problem in its task_list) and wraps them with the
    loss and restore hook in an EVAL-mode spec.

    Args:
      features: dict of feature tensors; on the non-TPU path
        features["targets"] is passed to every metric function.
      logits: model logits; rank-3 logits are expanded to rank 5 by
        inserting singleton axes at positions 2 and 3.
      labels: label tensor passed to the TPU metric_fn.
      loss: loss tensor passed through to the spec.
      restore_hook: the spec's only evaluation hook.
      use_tpu: if True return a tpu_estimator.TPUEstimatorSpec, otherwise a
        tf.estimator.EstimatorSpec.

    Returns:
      An EVAL-mode estimator spec.
    """
    hparams = self.hparams
    problem = hparams.problem
    if logits.get_shape().ndims == 3:
      logits = tf.expand_dims(tf.expand_dims(logits, 2), 3)

    # Support for multiproblem: a problem exposing `task_list` has its metrics
    # built from that list instead of from the wrapper problem alone. Plain
    # problems (no `task_list`) behave exactly as before.
    task_list = [problem]
    if hasattr(problem, "task_list"):
      task_list = problem.task_list

    eval_metrics_fns = metrics.create_evaluation_metrics(task_list, hparams)

    if use_tpu:
      def metric_fn(tf_logits, labels):
        # Metric ops are placed on cpu:0 and built inside
        # mtf.utils.outside_all_rewrites().
        with tf.device("cpu:0"), mtf.utils.outside_all_rewrites():
          eval_metrics = {}
          # The loop variable is `fn`, not `metric_fn`, so it does not shadow
          # this enclosing function's own name.
          for metric_name, fn in six.iteritems(eval_metrics_fns):
            # Skip metrics whose last name component is blacklisted on TPU.
            if metric_name.split("/")[-1] not in t2t_model.TPU_METRIC_BLACKLIST:
              eval_metrics[metric_name] = fn(
                  tf_logits, None, tf.identity(labels))
          return eval_metrics
      return tpu_estimator.TPUEstimatorSpec(
          tf.estimator.ModeKeys.EVAL,
          evaluation_hooks=[restore_hook],
          loss=loss,
          eval_metrics=(metric_fn, [logits, labels]))
    else:
      eval_metrics = {}
      predictions = {"predictions": logits}
      # NOTE(review): this path uses features["targets"] rather than the
      # `labels` argument used on TPU -- confirm the two always agree.
      for metric_name, fn in six.iteritems(eval_metrics_fns):
        eval_metrics[metric_name] = fn(logits, features,
                                       features["targets"])

      return tf.estimator.EstimatorSpec(
          tf.estimator.ModeKeys.EVAL,
          predictions=predictions,
          eval_metric_ops=eval_metrics,
          evaluation_hooks=[restore_hook],
          loss=loss)
示例4
def estimator_spec_predict(self, features, mesh, mesh_impl, use_tpu):
    """Build the PREDICT-mode estimator spec from the model's samples.

    Args:
      features: dict of input tensors. "inputs" (when present and not None)
        and "infer_targets" (when present and not None) are copied into the
        predictions; "inputs" is also read unconditionally when
        self.has_input is true.
      mesh: the mtf mesh on which self.sample builds the sampling graph.
      mesh_impl: mesh implementation used by mtf.Lowering for `mesh`.
      use_tpu: if True, summaries are removed via t2t_model.remove_summaries()
        and a tpu_estimator.TPUEstimatorSpec is returned; otherwise a
        tf.estimator.EstimatorSpec is returned.

    Returns:
      A PREDICT-mode spec whose predictions dict contains "outputs" plus any
      passthrough features, with an mtf.MtfRestoreHook for the lowering as
      its only prediction hook.
    """
    samples = mtf.anonymize(self.sample(features, mesh))
    lowering = mtf.Lowering(mesh.graph, {mesh: mesh_impl})
    sampled = lowering.export_to_tf_tensor(samples)

    if self.has_input:
      # Keep the first tf.shape(inputs)[0] entries along axis 0 and all
      # entries along the remaining axes. NOTE(review): presumably this trims
      # batch padding added upstream -- confirm.
      rank = len(sampled.shape.as_list())
      batch = tf.shape(features["inputs"])[0]
      begin = [0] * rank
      size = [batch] + [-1] * (rank - 1)
      sampled = tf.slice(sampled, begin, size)

    predictions = {"outputs": sampled}
    # Pass selected features straight through to the predictions when set.
    for key in ("infer_targets", "inputs"):
      value = features.get(key)
      if value is not None:
        predictions[key] = value

    if use_tpu:
      # NOTE(review): presumably summaries are stripped because they are not
      # usable in TPU prediction -- confirm against t2t_model.
      t2t_model.remove_summaries()
    hooks = [mtf.MtfRestoreHook(lowering)]
    mode = tf.estimator.ModeKeys.PREDICT

    if not use_tpu:
      return tf.estimator.EstimatorSpec(
          mode,
          predictions=predictions,
          prediction_hooks=hooks)
    return tpu_estimator.TPUEstimatorSpec(
        mode=mode,
        predictions=predictions,
        prediction_hooks=hooks)
示例5
def estimator_spec_eval(
      self, features, logits, labels, loss, restore_hook, use_tpu):
    """Build the EVAL-mode estimator spec with the problem's metrics.

    Args:
      features: dict of feature tensors; on the non-TPU path
        features["targets"] is passed to every metric function.
      logits: model logits; rank-3 logits are expanded to rank 5 by
        inserting singleton axes at positions 2 and 3.
      labels: label tensor passed to the TPU metric_fn.
      loss: loss tensor passed through to the spec.
      restore_hook: the spec's only evaluation hook.
      use_tpu: if True return a tpu_estimator.TPUEstimatorSpec, otherwise a
        tf.estimator.EstimatorSpec.

    Returns:
      An EVAL-mode estimator spec.
    """
    hparams = self.hparams
    problem = hparams.problem
    if logits.get_shape().ndims == 3:
      logits = tf.expand_dims(tf.expand_dims(logits, 2), 3)

    # Multiproblem support: a problem exposing `task_list` has its metrics
    # built from that list rather than from [problem].
    tasks = problem.task_list if hasattr(problem, "task_list") else [problem]
    eval_metrics_fns = metrics.create_evaluation_metrics(tasks, hparams)

    if not use_tpu:
      # NOTE(review): this path uses features["targets"] rather than the
      # `labels` argument used on TPU -- confirm the two always agree.
      metric_ops = {
          name: fn(logits, features, features["targets"])
          for name, fn in six.iteritems(eval_metrics_fns)
      }
      return tf.estimator.EstimatorSpec(
          tf.estimator.ModeKeys.EVAL,
          predictions={"predictions": logits},
          eval_metric_ops=metric_ops,
          evaluation_hooks=[restore_hook],
          loss=loss)

    def metric_fn(tf_logits, labels):
      # Metric ops are placed on cpu:0 inside mtf.utils.outside_all_rewrites();
      # metrics whose last name component is in TPU_METRIC_BLACKLIST are
      # skipped.
      with tf.device("cpu:0"), mtf.utils.outside_all_rewrites():
        return {
            name: fn(tf_logits, None, tf.identity(labels))
            for name, fn in six.iteritems(eval_metrics_fns)
            if name.split("/")[-1] not in t2t_model.TPU_METRIC_BLACKLIST
        }

    return tpu_estimator.TPUEstimatorSpec(
        tf.estimator.ModeKeys.EVAL,
        evaluation_hooks=[restore_hook],
        loss=loss,
        eval_metrics=(metric_fn, [logits, labels]))
示例6
def estimator_spec_eval(
      self, features, logits, labels, loss, restore_hook, use_tpu):
    """Build the EVAL-mode estimator spec with the problem's metrics.

    Args:
      features: dict of feature tensors; on the non-TPU path
        features["targets"] is passed to every metric function.
      logits: model logits; rank-3 logits are expanded to rank 5 by
        inserting singleton axes at positions 2 and 3.
      labels: label tensor passed to the TPU metric_fn.
      loss: loss tensor passed through to the spec.
      restore_hook: the spec's only evaluation hook.
      use_tpu: if True return a tpu_estimator.TPUEstimatorSpec, otherwise a
        tf.estimator.EstimatorSpec.

    Returns:
      An EVAL-mode estimator spec.
    """
    hparams = self.hparams
    problem = hparams.problem
    if logits.get_shape().ndims == 3:
      logits = tf.expand_dims(tf.expand_dims(logits, 2), 3)

    # Multiproblem support: a problem exposing `task_list` has its metrics
    # built from that list rather than from [problem].
    tasks = problem.task_list if hasattr(problem, "task_list") else [problem]
    eval_metrics_fns = metrics.create_evaluation_metrics(tasks, hparams)

    if not use_tpu:
      # NOTE(review): this path uses features["targets"] rather than the
      # `labels` argument used on TPU -- confirm the two always agree.
      metric_ops = {
          name: fn(logits, features, features["targets"])
          for name, fn in six.iteritems(eval_metrics_fns)
      }
      return tf.estimator.EstimatorSpec(
          tf.estimator.ModeKeys.EVAL,
          predictions={"predictions": logits},
          eval_metric_ops=metric_ops,
          evaluation_hooks=[restore_hook],
          loss=loss)

    def metric_fn(tf_logits, labels):
      # Metric ops are placed on cpu:0 inside mtf.utils.outside_all_rewrites();
      # metrics whose last name component is in TPU_METRIC_BLACKLIST are
      # skipped.
      with tf.device("cpu:0"), mtf.utils.outside_all_rewrites():
        return {
            name: fn(tf_logits, None, tf.identity(labels))
            for name, fn in six.iteritems(eval_metrics_fns)
            if name.split("/")[-1] not in t2t_model.TPU_METRIC_BLACKLIST
        }

    return tpu_estimator.TPUEstimatorSpec(
        tf.estimator.ModeKeys.EVAL,
        evaluation_hooks=[restore_hook],
        loss=loss,
        eval_metrics=(metric_fn, [logits, labels]))