Python源码示例:tensorflow.contrib.learn.python.learn.estimators.run_config.RunConfig()

示例1
def _get_default_schedule(config):
  """Returns the default schedule for the provided RunConfig."""
  if not config or not _is_distributed(config):
    return 'train_and_evaluate'

  if not config.task_type:
    raise ValueError('Must specify a schedule')

  if config.task_type == run_config_lib.TaskType.MASTER:
    # TODO(rhaertel): handle the case where there is more than one master
    # or explicitly disallow such a case.
    return 'train_and_evaluate'
  elif config.task_type == run_config_lib.TaskType.PS:
    return 'run_std_server'
  elif config.task_type == run_config_lib.TaskType.WORKER:
    return 'train'

  raise ValueError('No default schedule for task type: %s' % (config.task_type)) 
示例2
def _get_replica_device_setter(config):
  """Creates a replica device setter if required.

  Args:
    config: A RunConfig instance.

  Returns:
    A replica device setter, or None.
  """
  ps_ops = [
      'Variable', 'VariableV2', 'AutoReloadVariable', 'MutableHashTable',
      'MutableHashTableOfTensors', 'MutableDenseHashTable'
  ]

  if config.task_type:
    worker_device = '/job:%s/task:%d' % (config.task_type, config.task_id)
  else:
    worker_device = '/job:worker'

  if config.num_ps_replicas > 0:
    return device_setter.replica_device_setter(
        ps_tasks=config.num_ps_replicas, worker_device=worker_device,
        merge_devices=True, ps_ops=ps_ops, cluster=config.cluster_spec)
  else:
    return None 
示例3
def _get_replica_device_setter(config):
  """Creates a replica device setter if required.

  Args:
    config: A RunConfig instance.

  Returns:
    A replica device setter, or None.
  """
  ps_ops = [
      'Variable', 'VariableV2', 'AutoReloadVariable', 'MutableHashTable',
      'MutableHashTableOfTensors', 'MutableDenseHashTable'
  ]

  if config.task_type:
    worker_device = '/job:%s/task:%d' % (config.task_type, config.task_id)
  else:
    worker_device = '/job:worker'

  if config.num_ps_replicas > 0:
    return device_setter.replica_device_setter(
        ps_tasks=config.num_ps_replicas, worker_device=worker_device,
        merge_devices=True, ps_ops=ps_ops, cluster=config.cluster_spec)
  else:
    return None 
示例4
def _get_replica_device_setter(config):
  """Creates a replica device setter if required.

  Args:
    config: A RunConfig instance.

  Returns:
    A replica device setter, or None.
  """
  ps_ops = [
      'Variable', 'AutoReloadVariable', 'MutableHashTable',
      'MutableHashTableOfTensors', 'MutableDenseHashTable'
  ]

  if config.job_name:
    worker_device = '/job:%s/task:%d' % (config.job_name, config.task)
  else:
    worker_device = '/job:worker'

  if config.num_ps_replicas > 0:
    return device_setter.replica_device_setter(
        ps_tasks=config.num_ps_replicas, worker_device=worker_device,
        merge_devices=False, ps_ops=ps_ops, cluster=config.cluster_spec)
  else:
    return None 
示例5
def _get_replica_device_setter(config):
  """Creates a replica device setter if required.

  Args:
    config: A RunConfig instance.

  Returns:
    A replica device setter, or None.
  """
  ps_ops = [
      'Variable', 'VariableV2', 'AutoReloadVariable', 'MutableHashTable',
      'MutableHashTableOfTensors', 'MutableDenseHashTable'
  ]

  if config.task_type:
    worker_device = '/job:%s/task:%d' % (config.task_type, config.task_id)
  else:
    worker_device = '/job:worker'

  if config.num_ps_replicas > 0:
    return device_setter.replica_device_setter(
        ps_tasks=config.num_ps_replicas, worker_device=worker_device,
        merge_devices=True, ps_ops=ps_ops, cluster=config.cluster_spec)
  else:
    return None 
示例6
def config(self):
    """Returns a deep copy of the configuration stored in `self._config`."""
    # TODO(wicke): make RunConfig immutable, and then return it without a copy.
    config_copy = copy.deepcopy(self._config)
    return config_copy
示例7
def __init__(self, model_dir=None, config=None):
    """Sets up the model directory, run configuration and device function.

    Args:
      model_dir: Directory to save model parameters, graph and etc. This can
        also be used to load checkpoints from the directory into a estimator to
        continue training a previously saved model. When None, a new temporary
        directory is created and used instead.
      config: A RunConfig instance. When None, `BaseEstimator._Config()` is
        used.
    """
    # Model directory: use the caller's, or fall back to a temporary folder.
    if model_dir is not None:
      self._model_dir = model_dir
    else:
      self._model_dir = tempfile.mkdtemp()
      logging.warning('Using temporary folder as model directory: %s',
                      self._model_dir)

    # Run configuration: use the caller's, or build the default one.
    if config is not None:
      self._config = config
    else:
      self._config = BaseEstimator._Config()
      logging.info('Using default config.')
    config_description = str(vars(self._config))
    logging.info('Using config: %s', config_description)

    # Set device function depending if there are replicas or not.
    self._device_fn = _get_replica_device_setter(self._config)

    # Features and labels TensorSignature objects.
    # TODO(wicke): Rename these to something more descriptive
    self._features_info = None
    self._labels_info = None

    # No graph is attached at construction time.
    self._graph = None
示例8
def config(self):
    """Hands back an independent deep copy of `self._config`."""
    # TODO(wicke): make RunConfig immutable, and then return it without a copy.
    stored_config = self._config
    return copy.deepcopy(stored_config)
示例9
def _create_experiment_fn(output_dir):  # pylint: disable=unused-argument
  """Builds the census Experiment from values read off `FLAGS`.

  Args:
    output_dir: Unused.

  Returns:
    A `tf.contrib.learn.Experiment` wrapping a `DNNLinearCombinedClassifier`
    and the train/test input functions of a `CensusDataSource`.
  """
  model_config = census_model_config()
  (columns, label_column, wide_columns, deep_columns, categorical_columns,
   continuous_columns) = model_config

  census_data_source = CensusDataSource(
      FLAGS.data_dir, TRAIN_DATA_URL, TEST_DATA_URL, columns, label_column,
      categorical_columns, continuous_columns)

  # Run configuration taken from the master URL, PS count and worker index
  # flags.
  estimator_config = run_config.RunConfig(
      master=FLAGS.master_grpc_url,
      num_ps_replicas=FLAGS.num_parameter_servers,
      task=FLAGS.worker_index)

  classifier = tf.contrib.learn.DNNLinearCombinedClassifier(
      model_dir=FLAGS.model_dir,
      linear_feature_columns=wide_columns,
      dnn_feature_columns=deep_columns,
      dnn_hidden_units=[5],
      config=estimator_config)

  experiment = tf.contrib.learn.Experiment(
      estimator=classifier,
      train_input_fn=census_data_source.input_train_fn,
      eval_input_fn=census_data_source.input_test_fn,
      train_steps=FLAGS.train_steps,
      eval_steps=FLAGS.eval_steps)
  return experiment
示例10
def __init__(self, model_dir=None, config=None):
    """Prepares a BaseEstimator: model directory, config and device function.

    Args:
      model_dir: Directory to save model parameters, graph and etc. This can
        also be used to load checkpoints from the directory into a estimator to
        continue training a previously saved model. If None, a temporary
        directory is created.
      config: A RunConfig instance. If None, a `BaseEstimator._Config()` is
        created.
    """
    # Model directory.
    self._model_dir = model_dir
    if model_dir is None:
      self._model_dir = tempfile.mkdtemp()
      logging.warning('Using temporary folder as model directory: %s',
                      self._model_dir)

    # Create a run configuration.
    using_default_config = config is None
    if using_default_config:
      self._config = BaseEstimator._Config()
      logging.info('Using default config.')
    else:
      self._config = config
    logging.info('Using config: %s', str(vars(self._config)))

    # Set device function depending if there are replicas or not.
    self._device_fn = _get_replica_device_setter(self._config)

    # Features and labels TensorSignature objects.
    # TODO(wicke): Rename these to something more descriptive
    self._features_info = None
    self._labels_info = None

    # No graph is attached at construction time.
    self._graph = None
示例11
def config(self):
    """Returns a deep copy of `self._config`, never the stored object."""
    # TODO(wicke): make RunConfig immutable, and then return it without a copy.
    duplicate = copy.deepcopy(self._config)
    return duplicate
示例12
def __init__(self, model_dir=None, config=None):
    """Initializes the estimator's directory, config and placement state.

    Args:
      model_dir: Directory to save model parameters, graph and etc. This can
        also be used to load checkpoints from the directory into a estimator to
        continue training a previously saved model. A temporary directory is
        created when this is None.
      config: A RunConfig instance. `BaseEstimator._Config()` is used when
        this is None.
    """
    # Model directory: a temporary folder stands in when none was given.
    if model_dir is None:
      self._model_dir = tempfile.mkdtemp()
      logging.warning('Using temporary folder as model directory: %s',
                      self._model_dir)
    else:
      self._model_dir = model_dir

    # Create a run configuration.
    if config is None:
      self._config = BaseEstimator._Config()
      logging.info('Using default config.')
    else:
      self._config = config
    config_attributes = vars(self._config)
    logging.info('Using config: %s', str(config_attributes))

    # Set device function depending if there are replicas or not.
    self._device_fn = _get_replica_device_setter(self._config)

    # Features and labels TensorSignature objects.
    # TODO(wicke): Rename these to something more descriptive
    self._features_info = None
    self._labels_info = None

    # No graph is attached at construction time.
    self._graph = None
示例13
def config(self):
    """Returns a deep copy of the config, leaving `self._config` untouched."""
    # TODO(wicke): make RunConfig immutable, and then return it without a copy.
    cloned_config = copy.deepcopy(self._config)
    return cloned_config
示例14
def _wrapped_experiment_fn_with_uid_check(experiment_fn, require_hparams=False):
  """Wraps the `RunConfig` uid check with `experiment_fn`.

  For `experiment_fn` which takes `run_config`, it is expected that the
  `run_config` is passed to the Estimator correctly. Toward that, the wrapped
  `experiment_fn` compares the `uid` of the `RunConfig` instance.

  Args:
    experiment_fn: The original `experiment_fn` which takes `run_config` and
      `hparams`.
    require_hparams: If True, the `hparams` passed to `experiment_fn` cannot be
      `None`.

  Returns:
    A experiment_fn with same signature.
  """
  def wrapped_experiment_fn(run_config, hparams):
    """Calls experiment_fn and checks the uid of `RunConfig`."""
    if not isinstance(run_config, run_config_lib.RunConfig):
      raise ValueError('`run_config` must be `RunConfig` instance')
    if not run_config.model_dir:
      raise ValueError(
          'Must specify a model directory `model_dir` in `run_config`.')
    if hparams is not None and not isinstance(hparams, hparam_lib.HParams):
      raise ValueError('`hparams` must be `HParams` instance')
    if require_hparams and hparams is None:
      raise ValueError('`hparams` cannot be `None`.')

    expected_uid = run_config.uid()
    experiment = experiment_fn(run_config, hparams)

    if not isinstance(experiment, Experiment):
      raise TypeError('Experiment builder did not return an Experiment '
                      'instance, got %s instead.' % type(experiment))

    if experiment.estimator.config.uid() != expected_uid:
      raise RuntimeError(
          '`RunConfig` instance is expected to be used by the `Estimator` '
          'inside the `Experiment`. expected {}, but got {}'.format(
              expected_uid, experiment.estimator.config.uid()))
    return experiment
  return wrapped_experiment_fn 
示例15
def __init__(self, model_dir=None, config=None):
    """Sets up config, session config, model directory and device function.

    Args:
      model_dir: Directory to save model parameters, graph and etc. This can
        also be used to load checkpoints from the directory into a estimator to
        continue training a previously saved model. If `None`, the model_dir in
        `config` will be used if set. If both are set, they must be same. If
        neither is set, a temporary directory is created.
      config: A RunConfig instance. If `None`, `BaseEstimator._Config()` is
        used.

    Raises:
      ValueError: If `model_dir` and `config.model_dir` are both set but
        differ.
    """
    # Create a run configuration.
    if config is not None:
      self._config = config
    else:
      self._config = BaseEstimator._Config()
      logging.info('Using default config.')

    # Session config: default to soft placement when the config has none.
    session_config = self._config.session_config
    if session_config is None:
      session_config = config_pb2.ConfigProto(allow_soft_placement=True)
    self._session_config = session_config

    # Model directory: the constructor argument and the config must agree.
    config_model_dir = self._config.model_dir
    both_given = model_dir is not None and config_model_dir is not None
    if both_given and model_dir != config_model_dir:
      # TODO(b/9965722): remove this suppression after it is no longer
      #                  necessary.
      # pylint: disable=g-doc-exception
      raise ValueError(
          "model_dir are set both in constructor and RunConfig, but with "
          "different values. In constructor: '{}', in RunConfig: "
          "'{}' ".format(model_dir, config_model_dir))

    self._model_dir = model_dir or config_model_dir
    if self._model_dir is None:
      self._model_dir = tempfile.mkdtemp()
      logging.warning('Using temporary folder as model directory: %s',
                      self._model_dir)
    # Write the resolved directory back so the config always carries one.
    if config_model_dir is None:
      self._config = self._config.replace(model_dir=self._model_dir)
    logging.info('Using config: %s', str(vars(self._config)))

    # Set device function depending if there are replicas or not.
    self._device_fn = _get_replica_device_setter(self._config)

    # Features and labels TensorSignature objects.
    # TODO(wicke): Rename these to something more descriptive
    self._features_info = None
    self._labels_info = None

    # No graph is attached at construction time.
    self._graph = None