Source code for libreasr.lib.layers.haste.lstm

# Copyright 2020 LMNT, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
#
# Modifications by @iceychris:
# - run CPU-only (for inference)
#

"""Long Short-Term Memory"""


import torch
import torch.nn as nn
import torch.nn.functional as F

from .base_rnn import BaseRNN


__all__ = ["LSTM"]


# @torch.jit.script
def LSTMScript(
    training: bool,
    zoneout_prob: float,
    input,
    h0,
    c0,
    kernel,
    recurrent_kernel,
    bias,
    zoneout_mask,
):
    """
    Runs an LSTM over a whole sequence using plain PyTorch ops.

    Arguments:
      training: bool, selects which zoneout formulation is applied (only
        relevant when `zoneout_prob` is non-zero).
      zoneout_prob: float, the zoneout rate. A falsy value (e.g. `0.0`)
        disables zoneout entirely.
      input: Tensor, the input sequence, indexed by time step along
        dimension 0. Dimensions (time_steps, batch_size, input_size).
      h0: Tensor, the initial hidden state. Dimensions
        (batch_size, hidden_size).
      c0: Tensor, the initial cell state. Dimensions
        (batch_size, hidden_size).
      kernel: Tensor, the input projection matrix. Dimensions
        (input_size, hidden_size * 4) with `i,g,f,o` gate layout.
      recurrent_kernel: Tensor, the recurrent projection matrix. Dimensions
        (hidden_size, hidden_size * 4) with `i,g,f,o` gate layout.
      bias: Tensor, the projection bias, added to every time step.
      zoneout_mask: Tensor, indexed by time step along dimension 0. Only
        read when `zoneout_prob` is non-zero and `training` is `True`; a
        mask value of 1 keeps the new hidden state, 0 keeps the previous one.

    Returns:
      h: Tensor, `h0` followed by the hidden state after each time step,
        stacked along a new leading dimension (time_steps + 1 entries).
      c: Tensor, `c0` followed by the cell state after each time step,
        stacked the same way.
    """
    time_steps = input.shape[0]

    h = [h0]
    c = [c0]
    # The input projection does not depend on the recurrence, so it is
    # computed for all time steps at once.
    Wx = input @ kernel
    for t in range(time_steps):
        v = h[t] @ recurrent_kernel + Wx[t] + bias
        # Gates are laid out along dimension 1 in `i,g,f,o` order.
        i, g, f, o = torch.chunk(v, 4, 1)
        i = torch.sigmoid(i)
        g = torch.tanh(g)
        f = torch.sigmoid(f)
        o = torch.sigmoid(o)
        c.append(f * c[t] + i * g)
        h.append(o * torch.tanh(c[-1]))
        if zoneout_prob:
            if training:
                # Per-element choice between the new and the previous
                # hidden state, driven by the supplied mask.
                h[-1] = (h[-1] - h[-2]) * zoneout_mask[t] + h[-2]
            else:
                # Outside training, blend the previous and new hidden
                # states using the zoneout rate as the mixing weight.
                h[-1] = zoneout_prob * h[-2] + (1 - zoneout_prob) * h[-1]
    h = torch.stack(h)
    c = torch.stack(c)
    return h, c
class LSTM(BaseRNN):
    """
    Long Short-Term Memory layer.

    In this module the recurrence is evaluated by `LSTMScript`, a plain
    PyTorch loop over time steps. DropConnect (on the recurrent matrix) and
    Zoneout regularization are built in, and the layer allows setting a
    non-zero initial forget gate bias.

    See [\_\_init\_\_](#__init__) and [forward](#forward) for general usage.
    See [from_native_weights](#from_native_weights) and
    [to_native_weights](#to_native_weights) for compatibility with PyTorch
    LSTMs.
    """

    def __init__(
        self,
        input_size,
        hidden_size,
        batch_first=False,
        forget_bias=1.0,
        dropout=0.0,
        zoneout=0.0,
        return_state_sequence=False,
    ):
        """
        Initialize the parameters of the LSTM layer.

        Arguments:
          input_size: int, the feature dimension of the input.
          hidden_size: int, the feature dimension of the output.
          batch_first: (optional) bool, if `True`, then the input and output
            tensors are provided as `(batch, seq, feature)`.
          forget_bias: (optional) float, sets the initial bias of the forget
            gate for this LSTM cell.
          dropout: (optional) float, sets the dropout rate for DropConnect
            regularization on the recurrent matrix.
          zoneout: (optional) float, sets the zoneout rate for Zoneout
            regularization.
          return_state_sequence: (optional) bool, if `True`, the forward pass
            will return the entire state sequence instead of just the final
            state. Note that if the input is a padded sequence, the returned
            state will also be a padded sequence.

        Variables:
          kernel: the input projection weight matrix. Dimensions
            (input_size, hidden_size * 4) with `i,g,f,o` gate layout.
            Initialized with Xavier uniform initialization.
          recurrent_kernel: the recurrent projection weight matrix.
            Dimensions (hidden_size, hidden_size * 4) with `i,g,f,o` gate
            layout. Initialized with orthogonal initialization.
          bias: the projection bias vector. Dimensions (hidden_size * 4) with
            `i,g,f,o` gate layout. The forget gate biases are initialized to
            `forget_bias` and the rest are zeros.

        Raises:
          ValueError: if `dropout` or `zoneout` lies outside [0.0, 1.0].
        """
        super().__init__(
            input_size, hidden_size, batch_first, zoneout, return_state_sequence
        )

        if dropout < 0 or dropout > 1:
            raise ValueError("LSTM: dropout must be in [0.0, 1.0]")
        if zoneout < 0 or zoneout > 1:
            raise ValueError("LSTM: zoneout must be in [0.0, 1.0]")

        self.forget_bias = forget_bias
        self.dropout = dropout

        # All four gates share one fused matrix/vector per projection.
        fused_width = hidden_size * 4
        self.kernel = nn.Parameter(torch.empty(input_size, fused_width))
        self.recurrent_kernel = nn.Parameter(torch.empty(hidden_size, fused_width))
        self.bias = nn.Parameter(torch.empty(fused_width))
        self.reset_parameters()

    def to_native_weights(self):
        """
        Converts Haste LSTM weights to native PyTorch LSTM weights.

        The gate blocks are reordered from `i,g,f,o` to `i,f,g,o`, the weight
        matrices are transposed, and the single bias is split into two equal
        halves.

        Returns:
          weight_ih_l0: Parameter, the input-hidden weights of the LSTM layer.
          weight_hh_l0: Parameter, the hidden-hidden weights of the LSTM layer.
          bias_ih_l0: Parameter, the input-hidden bias of the LSTM layer.
          bias_hh_l0: Parameter, the hidden-hidden bias of the LSTM layer.
        """

        def haste_to_native(w):
            gate_i, gate_g, gate_f, gate_o = torch.chunk(w, 4, dim=-1)
            return torch.cat([gate_i, gate_f, gate_g, gate_o], dim=-1)

        weight_ih = haste_to_native(self.kernel).permute(1, 0).contiguous()
        weight_hh = haste_to_native(self.recurrent_kernel).permute(1, 0).contiguous()
        # Each of the two returned biases carries half of this layer's bias,
        # so their sum equals the original.
        bias_half = haste_to_native(self.bias) / 2.0

        return (
            torch.nn.Parameter(weight_ih),
            torch.nn.Parameter(weight_hh),
            torch.nn.Parameter(bias_half),
            torch.nn.Parameter(bias_half.clone()),
        )

    def from_native_weights(self, weight_ih_l0, weight_hh_l0, bias_ih_l0, bias_hh_l0):
        """
        Converts the provided PyTorch LSTM weights and installs them in this layer.

        The weight matrices are transposed, the gate blocks are reordered
        from `i,f,g,o` to `i,g,f,o`, and the two biases are summed. This
        layer's `kernel`, `recurrent_kernel` and `bias` are replaced by newly
        created Parameters.

        Arguments:
          weight_ih_l0: Parameter, the input-hidden weights of the PyTorch
            LSTM layer.
          weight_hh_l0: Parameter, the hidden-hidden weights of the PyTorch
            LSTM layer.
          bias_ih_l0: Parameter, the input-hidden bias of the PyTorch LSTM
            layer.
          bias_hh_l0: Parameter, the hidden-hidden bias of the PyTorch LSTM
            layer.
        """

        def native_to_haste(w):
            gate_i, gate_f, gate_g, gate_o = torch.chunk(w, 4, dim=-1)
            return torch.cat([gate_i, gate_g, gate_f, gate_o], dim=-1)

        # Convert everything first; attributes are only touched once all
        # three conversions have succeeded.
        new_kernel = native_to_haste(weight_ih_l0.permute(1, 0)).contiguous()
        new_recurrent = native_to_haste(weight_hh_l0.permute(1, 0)).contiguous()
        new_bias = native_to_haste(bias_ih_l0 + bias_hh_l0).contiguous()

        self.kernel = nn.Parameter(new_kernel)
        self.recurrent_kernel = nn.Parameter(new_recurrent)
        self.bias = nn.Parameter(new_bias)

    def reset_parameters(self):
        """Resets this layer's parameters to their initial values."""
        width = self.hidden_size
        # Initialize each gate's block separately, alternating between the
        # input and the recurrent kernel.
        for gate in range(4):
            lo = gate * width
            hi = lo + width
            nn.init.xavier_uniform_(self.kernel[:, lo:hi])
            nn.init.orthogonal_(self.recurrent_kernel[:, lo:hi])

        nn.init.zeros_(self.bias)
        # The forget gate is the third block in the `i,g,f,o` layout.
        nn.init.constant_(self.bias[width * 2 : width * 3], self.forget_bias)

    def forward(self, input, state=None, lengths=None):
        """
        Runs a forward pass of the LSTM layer.

        Arguments:
          input: Tensor, a batch of input sequences to pass through the LSTM.
            Dimensions (seq_len, batch_size, input_size) if `batch_first` is
            `False`, otherwise (batch_size, seq_len, input_size).
          state: (optional) the initial state, forwarded to `_get_state`
            together with the expected per-state shape
            (1, batch_size, hidden_size). Presumably an `(h_0, c_0)` pair;
            how `None` is handled is defined by `_get_state` in the base
            class.
          lengths: (optional) Tensor, list of sequence lengths for each batch
            element. Dimension (batch_size). This argument may be omitted if
            all batch elements are unpadded and have the same sequence length.

        Returns:
          output: Tensor, the output of the LSTM layer. Dimensions
            (seq_len, batch_size, hidden_size) if `batch_first` is `False`
            (default) or (batch_size, seq_len, hidden_size) if `batch_first`
            is `True`. Note that if `lengths` was specified, the `output`
            tensor will not be masked. It's the caller's responsibility to
            either not use the invalid entries or to mask them out before
            using them.
          (h_n, c_n): the hidden and cell states, respectively, as produced
            by `_get_final_state`. For the last sequence item the dimensions
            are (1, batch_size, hidden_size).
        """
        seq = self._permute(input)

        per_state_shape = [1, seq.shape[1], self.hidden_size]
        h0, c0 = self._get_state(seq, state, (per_state_shape, per_state_shape))

        # Strip the leading singleton dimension before running the recurrence.
        initial = (h0[0], c0[0])
        mask = self._get_zoneout_mask(seq)
        h, c = self._impl(seq, initial, mask)

        final_state = self._get_final_state((h, c), lengths)
        # Entry 0 of `h` is the initial state, not an output.
        return self._permute(h[1:]), final_state

    def _impl(self, input, state, zoneout_mask):
        """Applies DropConnect to the recurrent kernel and runs `LSTMScript`."""
        h0, c0 = state[0], state[1]
        dropped_recurrent = F.dropout(
            self.recurrent_kernel, self.dropout, self.training
        )
        return LSTMScript(
            training=self.training,
            zoneout_prob=self.zoneout,
            input=input.contiguous(),
            h0=h0.contiguous(),
            c0=c0.contiguous(),
            kernel=self.kernel.contiguous(),
            recurrent_kernel=dropped_recurrent.contiguous(),
            bias=self.bias.contiguous(),
            zoneout_mask=zoneout_mask.contiguous(),
        )