Source code for libreasr.lib.callbacks

"""
Contains the Tensorboard callback, which logs metrics to Tensorboard
(and optionally Weights & Biases) and periodically runs tests while training.
"""

from fastai2.fp16_utils import convert_network
from fastai2.torch_core import to_float
from fastai2.learner import Callback

from torch.utils.tensorboard import SummaryWriter

FLUSH_SECS = 30  # how often the SummaryWriter flushes pending events to disk
LOG_EVERY_N_STEPS = 4  # only log scalars every N steps to reduce overhead


class Tensorboard(Callback):
    "Logs losses, hyperparameters and test results to Tensorboard while training."

    toward_end = True

    def __init__(
        self, name=None, wandb=True, test=True, tests_per_epoch=4, mp=False, *args, **kwargs
    ):
        # log to runs/<name> if a run name is given, else to the default log dir
        if name:
            self.writer = SummaryWriter(flush_secs=FLUSH_SECS, log_dir="runs/" + name)
        else:
            self.writer = SummaryWriter(flush_secs=FLUSH_SECS)
        self.tb_name = name
        self.tb_use_wandb = wandb
        self.is_fitting = False
        self.train_batch_count = 0
        self.valid_batch_count = 0
        self.test = test
        self.test_cnt = tests_per_epoch
        self.mp = mp
        self.steps = 0

    def before_fit(self):
        # mirror all Tensorboard logs to Weights & Biases if requested
        if self.tb_use_wandb:
            import wandb

            wandb.init(project="asr", sync_tensorboard=True)
            wandb.watch(self.learn.model)
        self.is_fitting = True
        self.train_batch_count = 0
        self.valid_batch_count = 0
        self.writer.add_scalar("epochs", self.learn.n_epoch, self.train_batch_count)

    def before_batch(self):
        if self.is_fitting:
            if self.training:
                self.train_batch_count += 1
            else:
                self.valid_batch_count += 1

    def after_pred(self):
        pass

    def after_loss(self):
        if not self.training:
            self.writer.add_scalar(
                "loss/valid", self.learn.loss, self.valid_batch_count
            )

    def after_backward(self):
        # throttle training logs to every LOG_EVERY_N_STEPS steps
        if self.training and self.steps % LOG_EVERY_N_STEPS == 0:
            try:
                loss = self.learn.smooth_loss
                self.writer.add_scalar("loss/train", loss, self.train_batch_count)
            except Exception:
                pass
            # sequence lengths and batch size of the current batch
            self.writer.add_scalar(
                "data/seqlen/x", self.learn.xb[0][0].size(1), self.train_batch_count
            )
            self.writer.add_scalar(
                "data/seqlen/y", self.learn.xb[0][1].size(1), self.train_batch_count
            )
            self.writer.add_scalar(
                "data/batch_size", self.learn.xb[0][0].size(0), self.train_batch_count
            )

    def after_step(self):
        if self.steps % LOG_EVERY_N_STEPS == 0:
            # log optimizer hyperparameters (lr, mom, ...)
            hyps = self.learn.opt.hypers[0]
            for k, v in hyps.items():
                self.writer.add_scalar("hypers/" + k, v, self.train_batch_count)
            # log any extra values the wrapped optimizer exposes
            if hasattr(self.learn.opt, "opt") and hasattr(self.learn.opt.opt, "extra"):
                extra = self.learn.opt.opt.extra
                for k, v in extra.items():
                    self.writer.add_scalar("extra/" + k, v, self.train_batch_count)
        self.steps += 1

    def after_fit(self):
        self.is_fitting = False

    def after_batch(self):
        # periodically run learn.test(...) and log the resulting metrics
        try:
            if self.training and self.test:
                a = self.learn.train_iter
                if self.test_cnt <= 0:
                    return
                b = int(self.n_iter / self.test_cnt)  # iterations between tests
                if b == 0:
                    # fewer iterations per epoch than requested tests
                    return
                if a % b == 0:
                    metrics = self.learn.test(mp=self.mp)
                    i = self.train_batch_count
                    for metric in metrics:
                        for k, v in metric.items():
                            # scalar metrics go to add_scalar, text (e.g. transcripts) to add_text
                            if "metrics" in k:
                                self.writer.add_scalar(k, v, i)
                            else:
                                self.writer.add_text(k, v, i)
                        i += 1
        except Exception as e:
            print("failed to execute learn.test(...)")
            print(e)
            import traceback

            traceback.print_exc()
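# A minimal usage sketch (an assumption, not part of this module: `dls` and
# `model` stand in for an existing fastai2 DataLoaders and torch model, and
# `Learner.test` is provided elsewhere in libreasr). The callback is attached
# like any other fastai2 callback via `cbs`:
#
#   from fastai2.learner import Learner
#
#   learn = Learner(dls, model, cbs=[Tensorboard(name="my-run", wandb=False)])
#   learn.fit(1)
#
# Scalars then land under runs/my-run and can be inspected with
# `tensorboard --logdir runs`.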