Source code for libreasr.lib.layers.custom_rnn

import math
import random

import torch
from torch.nn import Parameter, ParameterList
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

from IPython.core.debugger import set_trace


ZONEOUT = 0.01
DEVICES = ["CPU", "GPU"]
RNN_TYPES = ["LSTM", "GRU", "NBRC"]
USE_PYTORCH = True


def get_rnn_impl(device, rnn_type, layer_norm=False):
    assert device in DEVICES
    assert rnn_type in RNN_TYPES
    if device == "GPU":
        if rnn_type == "LSTM":
            if layer_norm:
                # from haste_pytorch import LayerNormLSTM as RNN
                from torch.nn import LSTM as RNN
            else:
                # from haste_pytorch import LSTM as RNN
                from torch.nn import LSTM as RNN
        if rnn_type == "GRU":
            # from haste_pytorch import GRU as RNN
            from torch.nn import GRU as RNN
        if rnn_type == "NBRC":
            raise Exception("NBRC GPU not available")
    if device == "CPU":
        if rnn_type == "LSTM":
            if layer_norm:
                # from .haste import LayerNormLSTM as RNN
                from torch.nn import LSTM as RNN
            else:
                # from .haste import LSTM as RNN
                from torch.nn import LSTM as RNN
        if rnn_type == "GRU":
            # from .haste import GRU as RNN
            from torch.nn import GRU as RNN
        if rnn_type == "NBRC":
            from .haste import NBRC as RNN
    return RNN

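# Illustrative usage sketch (not part of the original module): resolve an RNN
# implementation and run it on a random batch. The sizes below are assumptions
# chosen for the example.
def _example_get_rnn_impl():
    rnn_cls = get_rnn_impl("CPU", "GRU")      # -> torch.nn.GRU (haste import is commented out)
    rnn = rnn_cls(80, 256, batch_first=True)  # input_size=80, hidden_size=256
    x = torch.randn(4, 100, 80)               # [batch, time, features]
    out, state = rnn(x)
    return out.shape                          # torch.Size([4, 100, 256])
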
def get_weight_attrs(rnn_type, layer_norm):
    # names of the (haste) weight tensors that have to be carried over
    # when converting between CPU and GPU implementations
    attrs = [
        "kernel",
        "recurrent_kernel",
        "bias",
    ]
    if rnn_type == "GRU" or rnn_type == "NBRC":
        attrs += [
            "recurrent_bias",
        ]
    if layer_norm:
        attrs += [
            "gamma",
            "gamma_h",
            "beta_h",
        ]
    return attrs

def copy_weights(_from, _to, attrs):
    # re-point the listed weight attributes of _to at the tensors of _from
    for attr in attrs:
        setattr(_to, attr, getattr(_from, attr))

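# Illustrative sketch (not part of the original module): copy_weights simply
# re-points the listed attributes from one object to another. A SimpleNamespace
# stands in for a haste RNN layer here, since the torch.nn fallbacks do not
# expose "kernel"/"recurrent_kernel" attributes.
def _example_copy_weights():
    from types import SimpleNamespace

    src = SimpleNamespace(
        kernel=torch.randn(80, 3 * 256),
        recurrent_kernel=torch.randn(256, 3 * 256),
        bias=torch.zeros(3 * 256),
        recurrent_bias=torch.zeros(3 * 256),
    )
    dst = SimpleNamespace()
    attrs = get_weight_attrs("GRU", layer_norm=False)
    # ["kernel", "recurrent_kernel", "bias", "recurrent_bias"]
    copy_weights(src, dst, attrs)
    return torch.equal(src.kernel, dst.kernel)  # True
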
def get_initial_state(rnn_type, hidden_size, init=torch.zeros):
    # learnable initial state; LSTMs need two tensors (h and c), GRU/NBRC one
    if rnn_type == "LSTM":
        h = nn.Parameter(init(2, 1, 1, hidden_size))
        tmp = init(2, 1, 1, hidden_size)
    else:
        h = nn.Parameter(init(1, 1, 1, hidden_size))
        tmp = init(1, 1, 1, hidden_size)
    return h, tmp

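# Illustrative sketch (not part of the original module): the learnable initial
# state has a leading dimension of 2 for LSTM (h and c) and 1 otherwise.
def _example_get_initial_state():
    h_lstm, _ = get_initial_state("LSTM", 256)
    h_gru, _ = get_initial_state("GRU", 256)
    return h_lstm.shape, h_gru.shape
    # (torch.Size([2, 1, 1, 256]), torch.Size([1, 1, 1, 256]))
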
class CustomRNN(nn.Module):
    def __init__(
        self,
        input_size,
        hidden_size,
        num_layers=1,
        batch_first=True,
        rnn_type="LSTM",
        reduction_indices=[],
        reduction_factors=[],
        reduction_drop=True,
        rezero=False,
        layer_norm=False,
        utsp=0.9,
    ):
        super().__init__()
        self.batch_first = batch_first
        self.hidden_size = hidden_size
        self._is = [input_size] + [hidden_size] * (num_layers - 1)
        self._os = [hidden_size] * num_layers
        self.rnn_type = rnn_type

        # reduction
        assert len(reduction_indices) == len(reduction_factors)
        self.reduction_indices = reduction_indices
        self.reduction_factors = reduction_factors

        # learnable state & temporary state
        self.hs = nn.ParameterList()
        for hidden_size in self._os:
            h, tmp = get_initial_state(rnn_type, hidden_size)
            self.hs.append(h)

        # state cache (key: bs, value: state)
        self.cache = {}

        # norm (BN or LN)
        self.bns = nn.ModuleList()
        for i, o in zip(self._is, self._os):
            norm = nn.BatchNorm1d(o)
            # norm = nn.LayerNorm(o)
            self.bns.append(norm)

        # rezero
        self.rezero = rezero

        # percentage of carrying over last state
        self.utsp = utsp

    def convert_to_cpu(self):
        return self

    def convert_to_gpu(self):
        return self

    def forward_one_rnn(
        self, x, i, state=None, should_use_tmp_state=False, lengths=None
    ):
        bs = x.size(0)
        if state is None:
            s = self.cache[bs][i] if self.cache.get(bs) is not None else None
            is_tmp_state_possible = self.training and s is not None
            if is_tmp_state_possible and should_use_tmp_state:
                # temporary state
                pass
            else:
                # learnable state
                if self.hs[i].size(0) == 2:
                    s = []
                    for h in self.hs[i]:
                        s.append(h.expand(1, bs, self._os[i]).contiguous())
                    s = tuple(s)
                else:
                    s = self.hs[i][0].expand(1, bs, self._os[i]).contiguous()
        else:
            s = state[i]
        if self.rnn_type == "LSTM" or self.rnn_type == "GRU":
            # PyTorch
            if lengths is not None:
                seq = pack_padded_sequence(
                    x, lengths, batch_first=True, enforce_sorted=False
                )
                seq, new_state = self.rnns[i](seq, s)
                x, _ = pad_packed_sequence(seq, batch_first=True)
                return (x, new_state)
            else:
                return self.rnns[i](x, s)
        else:
            # haste
            return self.rnns[i](x, s, lengths=lengths if lengths is not None else None)

    def forward(self, x, state=None, lengths=None):
        bs = x.size(0)
        residual = 0.0
        new_states = []
        suts = random.random() > (1.0 - self.utsp)
        for i, rnn in enumerate(self.rnns):

            # reduce if necessary
            if i in self.reduction_indices:
                idx = self.reduction_indices.index(i)
                r_f = self.reduction_factors[idx]

                # to [N, H, T]
                x = x.permute(0, 2, 1)
                x = x.unfold(-1, r_f, r_f)
                x = x.permute(0, 2, 1, 3)

                # keep last
                # x = x[:,:,:,-1]
                # or take the mean
                x = x.mean(-1)

                # also reduce lengths
                if lengths is not None:
                    lengths = lengths // r_f

            # apply rnn
            inp = x
            x, new_state = self.forward_one_rnn(
                inp,
                i,
                state=state,
                should_use_tmp_state=suts,
                lengths=lengths,
            )

            # apply norm
            x = x.permute(0, 2, 1)
            x = self.bns[i](x)
            x = x.permute(0, 2, 1)

            # apply residual
            if self.rezero:
                if torch.is_tensor(residual) and residual.shape == x.shape:
                    x = x + residual

            # store new residual
            residual = inp

            new_states.append(new_state)

        if self.training:
            if len(new_states[0]) == 2:
                self.cache[bs] = [
                    (h.detach().contiguous(), c.detach().contiguous())
                    for (h, c) in new_states
                ]
            else:
                self.cache[bs] = [h.detach() for h in new_states]
        return x, new_states

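# Illustrative sketch (not part of the original module): the time reduction
# inside forward() downsamples the sequence by averaging r_f adjacent frames,
# via the same permute/unfold/mean sequence used above.
def _example_time_reduction(r_f=2):
    x = torch.randn(4, 100, 80)          # [N, T, H]
    x = x.permute(0, 2, 1)               # [N, H, T]
    x = x.unfold(-1, r_f, r_f)           # [N, H, T // r_f, r_f]
    x = x.permute(0, 2, 1, 3).mean(-1)   # [N, T // r_f, H]
    return x.shape                       # torch.Size([4, 50, 80])
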
class CustomGPURNN(CustomRNN):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._args = args
        self._kwargs = kwargs
        RNN = get_rnn_impl("GPU", self.rnn_type, kwargs["layer_norm"])
        self.rnns = nn.ModuleList()
        for i, o in zip(self._is, self._os):
            # r = RNN(i, o, batch_first=self.batch_first, zoneout=ZONEOUT)
            r = RNN(i, o, batch_first=self.batch_first)
            self.rnns.append(r)

    def convert_to_cpu(self):
        if USE_PYTORCH:
            return self.to("cpu")
        dev = next(self.parameters()).device
        inst = CustomCPURNN(*self._args, **self._kwargs)
        attrs = get_weight_attrs(self.rnn_type, self._kwargs["layer_norm"])
        for i, rnn in enumerate(self.rnns):
            grabbed_rnn = inst.rnns[i]
            copy_weights(rnn, grabbed_rnn, attrs)
        return inst.to("cpu")

class CustomCPURNN(CustomRNN):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._args = args
        self._kwargs = kwargs
        RNN = get_rnn_impl("CPU", self.rnn_type, kwargs["layer_norm"])
        self.rnns = nn.ModuleList()
        for i, o in zip(self._is, self._os):
            # r = RNN(i, o, batch_first=self.batch_first, zoneout=ZONEOUT)
            r = RNN(i, o, batch_first=self.batch_first)
            self.rnns.append(r)

    def convert_to_gpu(self):
        dev = next(self.parameters()).device
        if USE_PYTORCH or self.rnn_type == "NBRC":
            return self.to(dev)
        inst = CustomGPURNN(*self._args, **self._kwargs)
        attrs = get_weight_attrs(self.rnn_type, self._kwargs["layer_norm"])
        for i, rnn in enumerate(self.rnns):
            grabbed_rnn = inst.rnns[i]
            copy_weights(rnn, grabbed_rnn, attrs)
        return inst.to(dev)
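

# Illustrative usage sketch (not part of the original module): a two-layer GRU
# stack that halves the frame rate before the second layer. The sizes are
# assumptions for the example; layer_norm is passed explicitly because
# __init__ reads kwargs["layer_norm"].
def _example_custom_cpu_rnn():
    rnn = CustomCPURNN(
        80,                      # input_size
        256,                     # hidden_size
        2,                       # num_layers
        rnn_type="GRU",
        reduction_indices=[1],   # reduce before the second layer
        reduction_factors=[2],
        layer_norm=False,
    )
    x = torch.randn(4, 100, 80)  # [batch, time, features]
    out, states = rnn(x)
    return out.shape             # torch.Size([4, 50, 256])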