DL/LSTM: An Explanation of the tf.contrib.rnn.BasicLSTMCell(rnn_unit) Function

Reader submission   725   2022-05-30


Contents

Explanation of tf.contrib.rnn.BasicLSTMCell(rnn_unit)

What the function does

Source code

Explanation of tf.contrib.rnn.BasicLSTMCell(rnn_unit)

What the function does

"""Basic LSTM recurrent network cell.

The implementation is based on: http://arxiv.org/abs/1409.2329.

We add forget_bias (default: 1) to the biases of the forget gate in order to reduce the scale of forgetting in the beginning of the training.

It does not allow cell clipping, a projection layer, and does not use peep-hole connections: it is the basic baseline. For advanced models, please use the full @{tf.nn.rnn_cell.LSTMCell} that follows.
"""
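The forget_bias remark above can be made concrete. The sketch below evaluates the forget gate sigmoid(f + forget_bias) that call() computes, assuming the forget-gate pre-activation f is close to 0 early in training (the bias variable is zero-initialized; the effect of the initial kernel values is ignored here). The helper names are illustrative, not TensorFlow functions.

```python
import math


def sigmoid(x: float) -> float:
    """Logistic function, as used for the LSTM gates."""
    return 1.0 / (1.0 + math.exp(-x))


def forget_gate(f_preactivation: float, forget_bias: float) -> float:
    """Forget-gate value sigmoid(f + forget_bias), as in BasicLSTMCell.call."""
    return sigmoid(f_preactivation + forget_bias)


def retained_fraction(forget_bias: float, steps: int,
                      f_preactivation: float = 0.0) -> float:
    """Fraction of the initial cell state left after `steps` steps, assuming
    f stays constant and the input-gate term adds nothing (illustrative)."""
    return forget_gate(f_preactivation, forget_bias) ** steps


for fb in (0.0, 1.0):
    print(f"forget_bias={fb}: gate={forget_gate(0.0, fb):.3f}, "
          f"kept after 10 steps={retained_fraction(fb, 10):.4f}")
```

With f = 0, forget_bias=1.0 raises the gate from 0.5 to about 0.731, so after 10 such steps roughly 4.4% of the initial cell state survives instead of about 0.1%.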

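def __init__(self,
             num_units,
             forget_bias=1.0,
             state_is_tuple=True,
             activation=None,
             reuse=None,
             name=None,
             dtype=None):
  """Initialize the basic LSTM cell.

  Args:
    num_units: int, The number of units in the LSTM cell.
    forget_bias: float, The bias added to forget gates (see above).
      Must set to `0.0` manually when restoring from CudnnLSTM-trained
      checkpoints.
    state_is_tuple: If True, accepted and returned states are 2-tuples of
      the `c_state` and `m_state`. If False, they are concatenated
      along the column axis. The latter behavior will soon be deprecated.
    activation: Activation function of the inner states. Default: `tanh`.
    reuse: (optional) Python boolean describing whether to reuse variables
      in an existing scope. If not `True`, and the existing scope already has
      the given variables, an error is raised.
    name: String, the name of the layer. Layers with the same name will
      share weights, but to avoid mistakes we require reuse=True in such
      cases.
    dtype: Default dtype of the layer (default of `None` means use the type
      of the first input). Required when `build` is called before `call`.

    When restoring from CudnnLSTM-trained checkpoints, must use
    `CudnnCompatibleLSTMCell` instead.
  """

Explanation: BasicLSTMCell is the basic LSTM recurrent network cell, implemented following http://arxiv.org/abs/1409.2329. It adds forget_bias (default 1) to the forget gate so that the cell forgets less at the start of training. It supports neither cell clipping nor a projection layer, and it does not use peephole connections; it is the basic baseline. For more advanced models, use the full tf.nn.rnn_cell.LSTMCell. In the title's call tf.contrib.rnn.BasicLSTMCell(rnn_unit), rnn_unit is passed positionally as num_units.

Parameters:

num_units: int, the number of units in the LSTM cell, i.e. the size of both the cell state c and the hidden state h, and therefore of the cell's output.

forget_bias: float, a constant added to the forget gate's pre-activation before the sigmoid (see above). It must be set to 0.0 manually when restoring from checkpoints trained with CudnnLSTM.

state_is_tuple: if True, the states accepted and returned are 2-tuples (c_state, m_state), i.e. the cell state and the hidden state. If False, the two are concatenated along the column axis; this is slower and will soon be deprecated.

activation: the activation function applied to the inner states. Defaults to tanh.

reuse: (optional) a Python boolean describing whether to reuse variables in an existing scope. If it is not True and the existing scope already contains the given variables, an error is raised.

name: string, the name of the layer. Layers with the same name share weights, but to avoid mistakes reuse=True is required in that case.

dtype: the default dtype of the layer (None means use the type of the first input). It is required when build is called before call.

When restoring from checkpoints trained with CudnnLSTM, use CudnnCompatibleLSTMCell instead.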
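The num_units and state_is_tuple descriptions can be tied to concrete sizes. Following the state_size property and the build() method in the source code of this class, the sketch below computes the state size and the shapes of the single kernel and bias variables for a given input width. StateSize is a hypothetical stand-in for TensorFlow's LSTMStateTuple, and the example sizes are arbitrary.

```python
from typing import NamedTuple, Tuple, Union


class StateSize(NamedTuple):
    """Hypothetical stand-in for LSTMStateTuple(c, h), holding sizes only."""
    c: int
    h: int


def state_size(num_units: int, state_is_tuple: bool = True) -> Union[StateSize, int]:
    # Mirrors BasicLSTMCell.state_size: a (c, h) pair, or one concatenated width.
    return StateSize(num_units, num_units) if state_is_tuple else 2 * num_units


def variable_shapes(input_depth: int, num_units: int) -> Tuple[Tuple[int, int], Tuple[int]]:
    # Mirrors BasicLSTMCell.build: one kernel for all four gates (i, j, f, o),
    # applied to concat([inputs, h]), plus one bias vector.
    kernel = (input_depth + num_units, 4 * num_units)
    bias = (4 * num_units,)
    return kernel, bias


def num_params(input_depth: int, num_units: int) -> int:
    """Total number of trainable values in the kernel and bias."""
    (rows, cols), (bias_len,) = variable_shapes(input_depth, num_units)
    return rows * cols + bias_len


print(state_size(10))
print(state_size(10, state_is_tuple=False))
print(variable_shapes(1, 10))
print(num_params(1, 10))
```

For a 1-dimensional input and num_units=10, the kernel is 11 x 40 and the bias has 40 entries, 480 values in total; doubling num_units roughly quadruples the kernel because it grows in both dimensions.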

Source code

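# Excerpt from TensorFlow 1.x, tensorflow/python/ops/rnn_cell_impl.py.
# LayerRNNCell, LSTMStateTuple, _WEIGHTS_VARIABLE_NAME ("kernel") and
# _BIAS_VARIABLE_NAME ("bias") are defined in that same module.
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.layers import base as base_layer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import tf_export


@tf_export("nn.rnn_cell.BasicLSTMCell")
class BasicLSTMCell(LayerRNNCell):
  """Basic LSTM recurrent network cell.

  The implementation is based on: http://arxiv.org/abs/1409.2329.

  We add forget_bias (default: 1) to the biases of the forget gate in order to
  reduce the scale of forgetting in the beginning of the training.

  It does not allow cell clipping, a projection layer, and does not
  use peep-hole connections: it is the basic baseline.

  For advanced models, please use the full @{tf.nn.rnn_cell.LSTMCell}
  that follows.
  """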

  def __init__(self,
               num_units,
               forget_bias=1.0,
               state_is_tuple=True,
               activation=None,
               reuse=None,
               name=None,
               dtype=None):
    """Initialize the basic LSTM cell.

    Args:
      num_units: int, The number of units in the LSTM cell.
      forget_bias: float, The bias added to forget gates (see above).
        Must set to `0.0` manually when restoring from CudnnLSTM-trained
        checkpoints.
      state_is_tuple: If True, accepted and returned states are 2-tuples of
        the `c_state` and `m_state`. If False, they are concatenated
        along the column axis. The latter behavior will soon be deprecated.
      activation: Activation function of the inner states. Default: `tanh`.
      reuse: (optional) Python boolean describing whether to reuse variables
        in an existing scope. If not `True`, and the existing scope already has
        the given variables, an error is raised.
      name: String, the name of the layer. Layers with the same name will
        share weights, but to avoid mistakes we require reuse=True in such
        cases.
      dtype: Default dtype of the layer (default of `None` means use the type
        of the first input). Required when `build` is called before `call`.

      When restoring from CudnnLSTM-trained checkpoints, must use
      `CudnnCompatibleLSTMCell` instead.
    """

    super(BasicLSTMCell, self).__init__(_reuse=reuse, name=name, dtype=dtype)
    if not state_is_tuple:
      logging.warn("%s: Using a concatenated state is slower and will soon be "
                   "deprecated. Use state_is_tuple=True.", self)

    # Inputs must be 2-dimensional.
    self.input_spec = base_layer.InputSpec(ndim=2)

    self._num_units = num_units
    self._forget_bias = forget_bias
    self._state_is_tuple = state_is_tuple
    self._activation = activation or math_ops.tanh

  @property
  def state_size(self):
    return (LSTMStateTuple(self._num_units, self._num_units)
            if self._state_is_tuple else 2 * self._num_units)

  @property
  def output_size(self):
    return self._num_units

  def build(self, inputs_shape):
    if inputs_shape[1].value is None:
      raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s"
                       % inputs_shape)

    input_depth = inputs_shape[1].value
    h_depth = self._num_units
    # A single kernel holds the weights of all four gates (i, j, f, o):
    # shape [input_depth + num_units, 4 * num_units].
    self._kernel = self.add_variable(
        _WEIGHTS_VARIABLE_NAME,
        shape=[input_depth + h_depth, 4 * self._num_units])
    # The matching bias starts at zero; forget_bias is added separately in call().
    self._bias = self.add_variable(
        _BIAS_VARIABLE_NAME,
        shape=[4 * self._num_units],
        initializer=init_ops.zeros_initializer(dtype=self.dtype))

    self.built = True

  def call(self, inputs, state):
    """Long short-term memory cell (LSTM).

    Args:
      inputs: `2-D` tensor with shape `[batch_size, input_size]`.
      state: An `LSTMStateTuple` of state tensors, each shaped
        `[batch_size, num_units]`, if `state_is_tuple` has been set to
        `True`. Otherwise, a `Tensor` shaped
        `[batch_size, 2 * num_units]`.

    Returns:
      A pair containing the new hidden state, and the new state (either a
        `LSTMStateTuple` or a concatenated state, depending on
        `state_is_tuple`).
    """

    sigmoid = math_ops.sigmoid
    one = constant_op.constant(1, dtype=dtypes.int32)
    # Parameters of gates are concatenated into one multiply for efficiency.
    if self._state_is_tuple:
      c, h = state
    else:
      c, h = array_ops.split(value=state, num_or_size_splits=2, axis=one)

    gate_inputs = math_ops.matmul(
        array_ops.concat([inputs, h], 1), self._kernel)
    gate_inputs = nn_ops.bias_add(gate_inputs, self._bias)

    # i = input_gate, j = new_input, f = forget_gate, o = output_gate
    i, j, f, o = array_ops.split(
        value=gate_inputs, num_or_size_splits=4, axis=one)

    forget_bias_tensor = constant_op.constant(self._forget_bias, dtype=f.dtype)
    # Note that using `add` and `multiply` instead of `+` and `*` gives a
    # performance improvement. So using those at the cost of readability.
    add = math_ops.add
    multiply = math_ops.multiply
    # new_c = c * sigmoid(f + forget_bias) + sigmoid(i) * activation(j)
    new_c = add(multiply(c, sigmoid(add(f, forget_bias_tensor))),
                multiply(sigmoid(i), self._activation(j)))
    # new_h = activation(new_c) * sigmoid(o)
    new_h = multiply(self._activation(new_c), sigmoid(o))

    if self._state_is_tuple:
      new_state = LSTMStateTuple(new_c, new_h)
    else:
      new_state = array_ops.concat([new_c, new_h], 1)
    return new_h, new_state
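To check the data flow of call() without TensorFlow, the following pure-Python sketch performs one step for a single example (batch size 1 is an assumption made to keep it dependency-free). It follows the listing above: the kernel is applied to concat([inputs, h]), the result is split into i, j, f, o in that order, forget_bias is added inside the forget-gate sigmoid, and tanh is the activation. The function name lstm_step and the list-of-lists kernel layout ([input_depth + num_units][4 * num_units]) are hypothetical illustration choices, not TensorFlow API.

```python
import math
from typing import List, Sequence, Tuple, Union

Vector = List[float]
State = Union[Tuple[Vector, Vector], Vector]


def _sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def lstm_step(inputs: Sequence[float],
              state: State,
              kernel: Sequence[Sequence[float]],
              bias: Sequence[float],
              forget_bias: float = 1.0,
              state_is_tuple: bool = True) -> Tuple[Vector, State]:
    """One BasicLSTMCell-style step for a single example (hypothetical helper)."""
    n = len(bias) // 4  # num_units
    if state_is_tuple:
        c, h = state
    else:
        c, h = state[:n], state[n:]  # like split(state, 2, axis=1)

    xh = list(inputs) + list(h)  # like concat([inputs, h], 1)
    if len(kernel) != len(xh) or any(len(row) != 4 * n for row in kernel):
        raise ValueError("kernel must be [input_depth + num_units][4 * num_units]")

    # gate_inputs = xh @ kernel + bias, length 4 * n
    gate_inputs = [sum(xh[r] * kernel[r][col] for r in range(len(xh))) + bias[col]
                   for col in range(4 * n)]
    # Same order as the source: input gate, new input, forget gate, output gate.
    i, j, f, o = (gate_inputs[k * n:(k + 1) * n] for k in range(4))

    new_c = [c[k] * _sigmoid(f[k] + forget_bias) + _sigmoid(i[k]) * math.tanh(j[k])
             for k in range(n)]
    new_h = [math.tanh(new_c[k]) * _sigmoid(o[k]) for k in range(n)]

    new_state: State = (new_c, new_h) if state_is_tuple else new_c + new_h
    return new_h, new_state


# Zero kernel and bias: every gate pre-activation is 0, so new_c = c * sigmoid(forget_bias).
zero_kernel = [[0.0] * 8 for _ in range(3)]  # input_depth=1, num_units=2
zero_bias = [0.0] * 8
h1, (c1, _) = lstm_step([0.5], ([1.0, -2.0], [0.0, 0.0]), zero_kernel, zero_bias)
print(c1, h1)
```

Passing the same state as one flat list with state_is_tuple=False gives identical numbers, which is why the concatenated layout only differs in how c and h are packed, not in the math.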
