# Copyright 2020 LMNT, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
#
# Modifications by @iceychris:
# - run CPU-only (for inference)
#
"""Layer Normalized Long Short-Term Memory"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from .base_rnn import BaseRNN
__all__ = ["LayerNormLSTM"]
# @torch.jit.script
def LayerNormLSTMScript(
    training: bool,
    zoneout_prob: float,
    input,
    h0,
    c0,
    kernel,
    recurrent_kernel,
    bias,
    gamma,
    gamma_h,
    beta_h,
    zoneout_mask,
):
    """
    Runs a layer-normalized LSTM over a whole sequence, one time step at a time.

    Arguments:
      training: bool, selects how zoneout is applied. When `True` the
        `zoneout_mask` is used; otherwise the deterministic expectation
        `zoneout_prob * h_prev + (1 - zoneout_prob) * h_new` is used.
      zoneout_prob: float, zoneout rate. `0.0` disables zoneout entirely (and
        `zoneout_mask` is then never read).
      input: Tensor, time-major input of shape (time_steps, batch_size,
        input_size).
      h0: Tensor, initial hidden state (batch_size, hidden_size).
      c0: Tensor, initial cell state (batch_size, hidden_size).
      kernel: Tensor, input projection (input_size, hidden_size * 4) with
        `i,g,f,o` gate layout.
      recurrent_kernel: Tensor, recurrent projection
        (hidden_size, hidden_size * 4) with `i,g,f,o` gate layout.
      bias: Tensor, gate bias (hidden_size * 4).
      gamma: Tensor (2, hidden_size * 4); `gamma[0]` is the layer-norm gain of
        the input projection, `gamma[1]` that of the recurrent projection.
      gamma_h: Tensor (hidden_size), layer-norm gain applied to the cell state
        before the output `tanh`.
      beta_h: Tensor (hidden_size), layer-norm bias for the same normalization.
      zoneout_mask: Tensor indexed by time step (`zoneout_mask[t]`) in training
        mode. Each step computes `m * h_new + (1 - m) * h_prev`, so with a 0/1
        mask a 1 keeps the new hidden state and a 0 carries the previous one
        over.

    Returns:
      (h, c): Tensors of shape (time_steps + 1, batch_size, hidden_size). Index
        0 holds `h0`/`c0`; index `t + 1` holds the state after step `t`.
    """
    hidden_size = recurrent_kernel.shape[0]
    gate_shape = (hidden_size * 4,)

    hs = [h0]
    cs = [c0]
    # The input projection does not depend on the recurrence, so it is
    # projected and normalized for every time step in one shot.
    Wx = F.layer_norm(input @ kernel, gate_shape, weight=gamma[0])
    for t in range(input.shape[0]):
        h_prev = hs[t]
        c_prev = cs[t]
        v = (
            F.layer_norm(h_prev @ recurrent_kernel, gate_shape, weight=gamma[1])
            + Wx[t]
            + bias
        )
        i, g, f, o = torch.chunk(v, 4, 1)
        i = torch.sigmoid(i)
        g = torch.tanh(g)
        f = torch.sigmoid(f)
        o = torch.sigmoid(o)

        c_new = f * c_prev + i * g
        h_new = o * torch.tanh(
            F.layer_norm(c_new, (hidden_size,), weight=gamma_h, bias=beta_h)
        )
        # Zoneout is applied to the hidden state only; the cell state always
        # takes its freshly computed value.
        if zoneout_prob:
            if training:
                h_new = (h_new - h_prev) * zoneout_mask[t] + h_prev
            else:
                h_new = zoneout_prob * h_prev + (1 - zoneout_prob) * h_new

        cs.append(c_new)
        hs.append(h_new)

    return torch.stack(hs), torch.stack(cs)
class LayerNormLSTM(BaseRNN):
    """
    Layer Normalized Long Short-Term Memory layer.

    This LSTM layer applies layer normalization to the input projection, the
    recurrent projection, and the cell state (before the output `tanh`) of a
    standard LSTM. DropConnect and Zoneout regularization are built-in, and
    this layer allows setting a non-zero initial forget gate bias.

    This variant evaluates the recurrence step by step in plain PyTorch (see
    `LayerNormLSTMScript`); it does not use a fused GPU kernel.

    Details about the exact function this layer implements can be found at
    https://github.com/lmnt-com/haste/issues/1.

    See `__init__` and `forward` for usage.
    """

    def __init__(
        self,
        input_size,
        hidden_size,
        batch_first=False,
        forget_bias=1.0,
        dropout=0.0,
        zoneout=0.0,
        return_state_sequence=False,
    ):
        """
        Initialize the parameters of the LSTM layer.

        Arguments:
          input_size: int, the feature dimension of the input.
          hidden_size: int, the feature dimension of the output.
          batch_first: (optional) bool, if `True`, then the input and output
            tensors are provided as `(batch, seq, feature)`.
          forget_bias: (optional) float, sets the initial bias of the forget gate
            for this LSTM cell.
          dropout: (optional) float, sets the dropout rate for DropConnect
            regularization on the recurrent matrix.
          zoneout: (optional) float, sets the zoneout rate for Zoneout
            regularization.
          return_state_sequence: (optional) bool, if `True`, the forward pass will
            return the entire state sequence instead of just the final state. Note
            that if the input is a padded sequence, the returned state will also
            be a padded sequence.

        Raises:
          ValueError: if `dropout` or `zoneout` is outside [0.0, 1.0] (NaN is
            rejected as well).

        Variables:
          kernel: the input projection weight matrix. Dimensions
            (input_size, hidden_size * 4) with `i,g,f,o` gate layout. Initialized
            with Xavier uniform initialization.
          recurrent_kernel: the recurrent projection weight matrix. Dimensions
            (hidden_size, hidden_size * 4) with `i,g,f,o` gate layout. Initialized
            with orthogonal initialization.
          bias: the projection bias vector. Dimensions (hidden_size * 4) with
            `i,g,f,o` gate layout. The forget gate biases are initialized to
            `forget_bias` and the rest are zeros.
          gamma: the input and recurrent normalization gain. Dimensions
            (2, hidden_size * 4) with `gamma[0]` specifying the input gain and
            `gamma[1]` specifying the recurrent gain. Initialized to ones.
          gamma_h: the output normalization gain. Dimensions (hidden_size).
            Initialized to ones.
          beta_h: the output normalization bias. Dimensions (hidden_size).
            Initialized to zeros.
        """
        super().__init__(
            input_size, hidden_size, batch_first, zoneout, return_state_sequence
        )
        # Chained comparisons also reject NaN, which `x < 0 or x > 1` lets through.
        if not 0.0 <= dropout <= 1.0:
            raise ValueError("LayerNormLSTM: dropout must be in [0.0, 1.0]")
        if not 0.0 <= zoneout <= 1.0:
            raise ValueError("LayerNormLSTM: zoneout must be in [0.0, 1.0]")

        self.forget_bias = forget_bias
        self.dropout = dropout

        self.kernel = nn.Parameter(torch.empty(input_size, hidden_size * 4))
        self.recurrent_kernel = nn.Parameter(torch.empty(hidden_size, hidden_size * 4))
        self.bias = nn.Parameter(torch.empty(hidden_size * 4))
        self.gamma = nn.Parameter(torch.empty(2, hidden_size * 4))
        self.gamma_h = nn.Parameter(torch.empty(hidden_size))
        self.beta_h = nn.Parameter(torch.empty(hidden_size))
        self.reset_parameters()

    def reset_parameters(self):
        """Resets this layer's parameters to their initial values."""
        hidden_size = self.hidden_size
        # Each gate's slice is initialized independently so that every
        # (hidden_size x hidden_size) recurrent block is itself orthogonal.
        for i in range(4):
            nn.init.xavier_uniform_(
                self.kernel[:, i * hidden_size : (i + 1) * hidden_size]
            )
            nn.init.orthogonal_(
                self.recurrent_kernel[:, i * hidden_size : (i + 1) * hidden_size]
            )
        nn.init.zeros_(self.bias)
        # With the i,g,f,o layout the forget gate is the third slice.
        nn.init.constant_(
            self.bias[hidden_size * 2 : hidden_size * 3], self.forget_bias
        )
        nn.init.ones_(self.gamma)
        nn.init.ones_(self.gamma_h)
        nn.init.zeros_(self.beta_h)

    def forward(self, input, state=None, lengths=None):
        """
        Runs a forward pass of the LSTM layer.

        Arguments:
          input: Tensor, a batch of input sequences to pass through the LSTM.
            Dimensions (seq_len, batch_size, input_size) if `batch_first` is
            `False`, otherwise (batch_size, seq_len, input_size).
          state: (optional) initial `(h_0, c_0)` state, resolved by the base
            class's `_get_state`. Presumably each tensor is
            (1, batch_size, hidden_size), matching the returned state, and
            `None` selects the base-class default -- TODO confirm in BaseRNN.
          lengths: (optional) Tensor, list of sequence lengths for each batch
            element. Dimension (batch_size). This argument may be omitted if
            all batch elements are unpadded and have the same sequence length.

        Returns:
          output: Tensor, the output of the LSTM layer. Dimensions
            (seq_len, batch_size, hidden_size) if `batch_first` is `False` (default)
            or (batch_size, seq_len, hidden_size) if `batch_first` is `True`. Note
            that if `lengths` was specified, the `output` tensor will not be
            masked. It's the caller's responsibility to either not use the invalid
            entries or to mask them out before using them.
          (h_n, c_n): the hidden and cell states, respectively, for the last
            sequence item. Dimensions (1, batch_size, hidden_size). If
            `return_state_sequence` was set, the entire state sequence is
            returned instead (see `__init__`).
        """
        input = self._permute(input)
        # Expected shape of each of h_0 and c_0: (1, batch_size, hidden_size).
        state_shape = [1, input.shape[1], self.hidden_size]
        h0, c0 = self._get_state(input, state, (state_shape, state_shape))
        # The recurrence works on (batch_size, hidden_size) states, so the
        # leading length-1 dimension is dropped.
        h, c = self._impl(input, (h0[0], c0[0]), self._get_zoneout_mask(input))
        state = self._get_final_state((h, c), lengths)
        # h[0] is the initial state; only the per-step outputs are returned.
        output = self._permute(h[1:])
        return output, state

    def _impl(self, input, state, zoneout_mask):
        """
        Runs the step-by-step recurrence on time-major `input`.

        `state` is an `(h0, c0)` pair without the leading length-1 dimension.
        DropConnect is applied by running `F.dropout` over the recurrent
        kernel, which only drops weights when `self.training` is `True`.
        Returns the stacked `(h, c)` sequences from `LayerNormLSTMScript`.
        """
        return LayerNormLSTMScript(
            self.training,
            self.zoneout,
            input.contiguous(),
            state[0].contiguous(),
            state[1].contiguous(),
            self.kernel.contiguous(),
            F.dropout(self.recurrent_kernel, self.dropout, self.training).contiguous(),
            self.bias.contiguous(),
            self.gamma.contiguous(),
            self.gamma_h.contiguous(),
            self.beta_h.contiguous(),
            zoneout_mask.contiguous(),
        )