Python源码示例:tensorflow.contrib.rnn.ResidualWrapper()

示例1
def _make_cell(self, hidden_size=None):
        """Build one RNN cell with the optional wrappers configured on ``self``.

        Args:
            hidden_size: Size passed to ``self.cell_type``. When falsy
                (``None`` or ``0``), ``self.hidden_size`` is used instead.

        Returns:
            The cell produced by ``self.cell_type``, wrapped in
            ``rnn.DropoutWrapper`` when ``self.dropout`` is truthy and then in
            ``rnn.ResidualWrapper`` when ``self.residual`` is truthy.
        """
        make_cell = self.cell_type
        units = hidden_size or self.hidden_size
        base = make_cell(units)

        # Dropout goes on first so that the residual wrapper, when present,
        # is the outermost layer.
        # NOTE(review): keep_prob is passed positionally; presumably it lands
        # on DropoutWrapper's input_keep_prob -- confirm against the TF API.
        wrapped = rnn.DropoutWrapper(base, self.keep_prob) if self.dropout else base
        return rnn.ResidualWrapper(wrapped) if self.residual else wrapped
示例2
def build_single_cell(self, n_hidden, use_residual):
        """Build a single RNN cell.

        A ``GRUCell`` is created when ``self.cell_type`` equals ``'gru'``;
        any other value yields an ``LSTMCell``.

        Args:
            n_hidden: number of hidden units in the cell
            use_residual: whether to wrap the cell in a ResidualWrapper

        Returns:
            The new cell, wrapped in a ``DropoutWrapper`` when
            ``self.use_dropout`` is truthy and in a ``ResidualWrapper`` when
            ``use_residual`` is truthy (residual outermost).
        """
        cell_cls = GRUCell if self.cell_type == 'gru' else LSTMCell
        rnn_cell = cell_cls(n_hidden)

        if self.use_dropout:
            # Only the cell outputs are dropped; the keep probability comes
            # from the object's placeholder attribute.
            rnn_cell = DropoutWrapper(
                rnn_cell,
                dtype=tf.float32,
                output_keep_prob=self.keep_prob_placeholder,
                seed=self.seed
            )

        if not use_residual:
            return rnn_cell
        return ResidualWrapper(rnn_cell)
示例3
def create_cell(unit_type, hidden_units, num_layers, use_residual=False, input_keep_prob=1.0, output_keep_prob=1.0, devices=None):
    """Create a single or multi-layer RNN cell.

    Args:
        unit_type: ``'lstm'`` (BasicLSTMCell) or ``'gru'`` (GRUCell).
        hidden_units: Size passed to the cell constructor.
        num_layers: When greater than 1, that many cells are stacked in a
            ``MultiRNNCell``; otherwise a single (wrapped) cell is returned.
        use_residual: When truthy, every layer except the first is wrapped
            in a ``ResidualWrapper``. Ignored for the single-cell case.
        input_keep_prob: Input keep probability for the ``DropoutWrapper``.
        output_keep_prob: Output keep probability for the ``DropoutWrapper``.
            The dropout wrapper is only added when either probability is
            below 1.0.
        devices: Optional sequence of device ids indexed by layer number.
            When given, it is indexed with every layer index, so it must
            hold at least ``num_layers`` entries.

    Returns:
        A ``MultiRNNCell`` when ``num_layers > 1``, else a single cell.

    Raises:
        ValueError: If ``unit_type`` is neither ``'lstm'`` nor ``'gru'``.
    """
    # Validate first, so an unknown type fails before any cell is built.
    if unit_type == 'lstm':
        def _new_cell():
            return tf.nn.rnn_cell.BasicLSTMCell(hidden_units)
    elif unit_type == 'gru':
        def _new_cell():
            return tf.contrib.rnn.GRUCell(hidden_units)
    else:
        # The message names the actual parameter (`unit_type`) so callers can
        # tell which argument was rejected.
        raise ValueError('unit_type must be either lstm or gru')

    use_dropout = input_keep_prob < 1.0 or output_keep_prob < 1.0

    def _new_cell_wrapper(residual_connection=False, device_id=None):
        """Build one cell and apply dropout, residual and device wrappers."""
        c = _new_cell()

        if use_dropout:
            c = rnn.DropoutWrapper(c, input_keep_prob=input_keep_prob, output_keep_prob=output_keep_prob)

        if residual_connection:
            c = rnn.ResidualWrapper(c)

        if device_id:
            c = rnn.DeviceWrapper(c, device_id)

        return c

    if num_layers > 1:
        # The first layer never gets a residual connection.
        cells = [
            _new_cell_wrapper(
                bool(use_residual) and i > 0,
                devices[i] if devices else None,
            )
            for i in range(num_layers)
        ]
        return tf.contrib.rnn.MultiRNNCell(cells)

    return _new_cell_wrapper(device_id=devices[0] if devices else None)
示例4
def _single_cell(unit_type,
                 num_units,
                 forget_bias,
                 dropout,
                 mode,
                 residual_connection=False,
                 residual_fn=None,
                 trainable=True):
  """Build one RNN cell of the requested type with optional wrappers.

  Args:
    unit_type: One of "lstm", "gru", "layer_norm_lstm" or "nas".
    num_units: Size passed to the cell constructor.
    forget_bias: Forget-gate bias; used by the "lstm" and "layer_norm_lstm"
      cells only.
    dropout: Dropout rate (= 1 - keep_prob). Forced to 0.0 unless `mode` is
      `tf.estimator.ModeKeys.TRAIN`.
    mode: Estimator mode key used to decide whether dropout applies.
    residual_connection: When truthy, wrap the cell in a ResidualWrapper.
    residual_fn: Forwarded to the ResidualWrapper as `residual_fn`.
    trainable: Forwarded to the cell constructor.

  Returns:
    The constructed cell, wrapped in a DropoutWrapper (input dropout) when
    the effective dropout rate is positive and then in a ResidualWrapper
    when `residual_connection` is truthy.

  Raises:
    ValueError: If `unit_type` is not one of the supported names.
  """
  # Dropout is only active while training; eval and infer use a rate of 0.
  is_training = mode == tf.estimator.ModeKeys.TRAIN
  drop_rate = dropout if is_training else 0.0

  # Lazily-evaluated constructors, tried in order against `unit_type`.
  factories = (
      ("lstm",
       lambda: contrib_rnn.LSTMCell(
           num_units, forget_bias=forget_bias, trainable=trainable)),
      ("gru",
       lambda: contrib_rnn.GRUCell(num_units, trainable=trainable)),
      ("layer_norm_lstm",
       lambda: contrib_rnn.LayerNormBasicLSTMCell(
           num_units,
           forget_bias=forget_bias,
           layer_norm=True,
           trainable=trainable)),
      ("nas",
       lambda: contrib_rnn.NASCell(num_units, trainable=trainable)),
  )
  for type_name, build in factories:
    if unit_type == type_name:
      cell = build()
      break
  else:
    raise ValueError("Unknown unit type %s!" % unit_type)

  # Input dropout, expressed to the wrapper as a keep probability.
  if drop_rate > 0.0:
    cell = contrib_rnn.DropoutWrapper(
        cell=cell, input_keep_prob=(1.0 - drop_rate))

  if not residual_connection:
    return cell
  return contrib_rnn.ResidualWrapper(cell, residual_fn=residual_fn)
示例5
def rnn_cell(rnn_cell_size, dropout_keep_prob, residual, is_training=True):
  """Builds a MultiRNNCell of LSTMBlockCell layers from the given parameters.

  Args:
    rnn_cell_size: Sequence of layer sizes; one LSTMBlockCell per entry.
    dropout_keep_prob: Input keep probability for each layer's
      DropoutWrapper. Replaced by 1.0 when `is_training` is falsy.
    residual: When truthy, each layer is wrapped in a ResidualWrapper, and
      an InputProjectionWrapper is added for the first layer and for any
      layer whose size differs from the previous one.
    is_training: Whether the keep probability should take effect.

  Returns:
    An `rnn.MultiRNNCell` holding the wrapped layers in order.
  """
  keep_prob = dropout_keep_prob if is_training else 1.0

  layers = []
  for idx, size in enumerate(rnn_cell_size):
    layer = rnn.LSTMBlockCell(size)
    if residual:
      layer = rnn.ResidualWrapper(layer)
      # Project the input when this is the first layer or the width changes
      # from the previous layer.
      size_changed = idx == 0 or size != rnn_cell_size[idx - 1]
      if size_changed:
        layer = rnn.InputProjectionWrapper(layer, size)
    layers.append(rnn.DropoutWrapper(layer, input_keep_prob=keep_prob))

  return rnn.MultiRNNCell(layers)