Source code for libreasr.lib.language

"""
Tokenization Utilities.
"""

import string

import torch
import youtokentome as yttm

# tasks
# - normal: just normal STT
# - fill: mask every character of each word with "x", keeping word boundaries
# - words: replace every word with a single "x" marker
TASK = "normal"
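# e.g. (illustrative) for text = "the cat":
#   TASK == "fill"  -> numericalize() sees "xxx xxx" (word lengths kept, letters masked)
#   TASK == "words" -> numericalize() sees "x x"     (one marker per word)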


class Language:
    def __init__(self, tokens):
        self.token_list = list(tokens.keys())
        self.idx_list = list(tokens.values())
        self.t2i = tokens
        self.i2t = {i: t for (t, i) in tokens.items()}

    def numericalize(self, text, sos=False):
        text = text.lower()
        text = text.strip()
        text = text.replace(self.SOS, "")
        text = text.replace(self.EOS, "")
        nummed = [self.iSOS] if sos else []
        if TASK == "fill":
            # mask every character of each word, keep word lengths
            text = " ".join("x" * len(word) for word in text.split(" "))
        elif TASK == "words":
            # replace every word with a single marker
            text = " ".join("x" for _ in text.split(" "))
        for c in text:
            try:
                nummed.append(self.get_idx(c))
            except KeyError:
                # skip characters that are not in the vocabulary
                continue
        return nummed + [self.iEOS]

    def denumericalize(self, nummed, strip_zeros=True):
        text = ""
        if not isinstance(nummed, list):
            nummed = [nummed]
        # drop SOS/EOS markers before decoding
        nummed = [x for x in nummed if x not in (self.iSOS, self.iEOS)]
        for n in nummed:
            try:
                text += self.get_token(n, strip_zeros=strip_zeros)
            except KeyError:
                continue
        return text

    def get_idx(self, tok):
        return self.t2i[tok]

    def get_token(self, num, strip_zeros=False):
        if strip_zeros and num == 0:
            return ""
        if isinstance(num, list):
            num = num[0]
        return self.i2t[num]

    @property
    def SOS(self):
        return self.token_list[1]

    @property
    def EOS(self):
        return self.token_list[2]

    @property
    def iSOS(self):
        return self.idx_list[1]

    @property
    def iEOS(self):
        return self.idx_list[2]

    @property
    def replaceable(self):
        # ids that may be randomly substituted
        # (everything past the special and punctuation tokens)
        return list(self.t2i.values())[11:]

    def randomize(self, t, p):
        # replace each token with probability p by a random replaceable id
        x = t.clone()
        rpl = self.replaceable
        mask = torch.zeros(t.shape).uniform_() > p
        vals = torch.randint(low=min(rpl), high=max(rpl) + 1, size=t.shape)
        return torch.where(mask, x, vals)

    def __repr__(self):
        toks = list(self.t2i.keys())
        return str((toks[:5], "...", toks[-5:], len(self)))

    def __len__(self):
        return len(self.t2i)

    def __getitem__(self, idx):
        return self.get_token(idx)
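

# Illustrative sketch, not part of the original module: a round trip through
# the character-level Language. It assumes the default token map built by
# get_language() (defined below) with cls=Language.
def _demo_char_language():
    lang, vocab_sz = get_language(cls=Language)
    ids = lang.numericalize("hello world")  # per-character ids, terminated by iEOS
    text = lang.denumericalize(ids)         # -> "hello world"
    return ids, text, vocab_sz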


class TokenizedLanguage(Language):
    def __init__(self, *args, model_file="tmp/tokenizer.yttm-model", **kwargs):
        super().__init__(*args, **kwargs)
        self.mf = model_file
        # load tokenizer
        self.tokenizer = yttm.BPE(model=model_file)

    def numericalize(self, text, sos=False, dropout=0):
        text = text.lower()
        text = text.strip()
        text = text.replace(self.SOS, "")
        text = text.replace(self.EOS, "")
        res = self.tokenizer.encode(
            [text], output_type=yttm.OutputType.ID, dropout_prob=dropout
        )
        # unpack (encode operates on a batch of sentences)
        return res[0]

    def denumericalize(self, nummed, strip_zeros=True):
        if not isinstance(nummed, list):
            nummed = [nummed]
        res = self.tokenizer.decode([nummed], ignore_ids=[0])
        # unpack (decode operates on a batch of id lists)
        return res[0]

    def get_idx(self, tok):
        return self.numericalize(tok)[0]

    def get_token(self, num, strip_zeros=False):
        return self.denumericalize(num)[0]

    def __len__(self):
        return self.tokenizer.vocab_size()

    def __repr__(self):
        bpe = self.tokenizer
        return str((bpe.vocab()[:5], "...", bpe.vocab()[-5:], len(self)))


def get_language(
    tokens=["<BLK>", "<s>", "</s>", "<UNK>", " ", ".", "!", "?", ",", "'", "-"],
    cls=TokenizedLanguage,
    **kwargs
):
    # create dictionary mapping token -> id
    tokens = dict(zip(tokens, range(len(tokens))))
    possible_chars = string.ascii_lowercase  # + string.digits + string.punctuation
    tokens.update(
        {_char: len(tokens) + _idx for (_idx, _char) in enumerate(possible_chars)}
    )
    # create a language
    lang = cls(tokens, **kwargs)
    vocab_sz = len(lang)
    return lang, vocab_sz


def get_tokenizer(*args, **kwargs):
    return get_language(*args, **kwargs)
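

# Minimal usage sketch (illustrative): build the BPE-backed language and
# encode/decode a transcript. This assumes a trained YouTokenToMe model
# exists at the default path "tmp/tokenizer.yttm-model".
if __name__ == "__main__":
    lang, vocab_sz = get_tokenizer()
    ids = lang.numericalize("hello world")
    print(vocab_sz, ids, lang.denumericalize(ids))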