# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/13_callback.core.ipynb (unless otherwise specified).
__all__ = ['CancelFitException', 'CancelEpochException', 'CancelTrainException', 'CancelValidException',
'CancelBatchException', 'event', 'Callback', 'TrainEvalCallback', 'GatherPredsCallback',
'FetchPredsCallback']
# Cell
from ..data.all import *
from ..optimizer import *
# Cell
#nbdev_comment _all_ = ['CancelFitException', 'CancelEpochException', 'CancelTrainException', 'CancelValidException', 'CancelBatchException']
# Cell
_events = L.split('before_fit before_epoch before_train before_batch after_pred after_loss \
before_backward after_backward after_step after_cancel_batch after_batch after_cancel_train \
after_train before_validate after_cancel_validate after_validate after_cancel_epoch \
after_epoch after_cancel_fit after_fit')
mk_class('event', **_events.map_dict(),
doc="All possible events as attributes to get tab-completion and typo-proofing")
# Cell
#nbdev_comment _all_ = ['event']
# Cell
_inner_loop = "before_batch after_pred after_loss before_backward after_backward after_step after_cancel_batch after_batch".split()
# Cell
@funcs_kwargs(as_method=True)
class Callback(GetAttr):
"Basic class handling tweaks of the training loop by changing a `Learner` in various events"
_default,learn,run,run_train,run_valid = 'learn',None,True,True,True
_methods = _events
def __init__(self, **kwargs): assert not kwargs, f'Passed unknown events: {kwargs}'
def __repr__(self): return type(self).__name__
def __call__(self, event_name):
"Call `self.{event_name}` if it's defined"
_run = (event_name not in _inner_loop or (self.run_train and getattr(self, 'training', True)) or
(self.run_valid and not getattr(self, 'training', False)))
res = None
if self.run and _run: res = getattr(self, event_name, noop)()
if event_name=='after_fit': self.run=True #Reset self.run to True at each end of fit
return res
def __setattr__(self, name, value):
if hasattr(self.learn,name):
warn(f"You are setting an attribute ({name}) that also exists in the learner, so you're not setting it in the learner but in the callback. Use `self.learn.{name}` otherwise.")
super().__setattr__(name, value)
@property
def name(self):
"Name of the `Callback`, camel-cased and with '*Callback*' removed"
return class2attr(self, 'Callback')
# Cell
class TrainEvalCallback(Callback):
"`Callback` that tracks the number of iterations done and properly sets training/eval mode"
run_valid = False
def before_fit(self):
"Set the iter and epoch counters to 0, put the model and the right device"
self.learn.train_iter,self.learn.pct_train = 0,0.
if hasattr(self.dls, 'device'): self.model.to(self.dls.device)
if hasattr(self.model, 'reset'): self.model.reset()
def after_batch(self):
"Update the iter counter (in training mode)"
self.learn.pct_train += 1./(self.n_iter*self.n_epoch)
self.learn.train_iter += 1
def before_train(self):
"Set the model in training mode"
self.learn.pct_train=self.epoch/self.n_epoch
self.model.train()
self.learn.training=True
def before_validate(self):
"Set the model in validation mode"
self.model.eval()
self.learn.training=False
# Cell
if not hasattr(defaults, 'callbacks'): defaults.callbacks = [TrainEvalCallback]
# Cell
#TODO: save_targs and save_preds only handle preds/targets that have one tensor, not tuples of tensors.
class GatherPredsCallback(Callback):
"`Callback` that saves the predictions and targets, optionally `with_loss`"
def __init__(self, with_input=False, with_loss=False, save_preds=None, save_targs=None, concat_dim=0):
store_attr(self, "with_input,with_loss,save_preds,save_targs,concat_dim")
def before_batch(self):
if self.with_input: self.inputs.append((to_detach(self.xb)))
def before_validate(self):
"Initialize containers"
self.preds,self.targets = [],[]
if self.with_input: self.inputs = []
if self.with_loss: self.losses = []
def after_batch(self):
"Save predictions, targets and potentially losses"
if not hasattr(self, 'pred'): return
preds,targs = to_detach(self.pred),to_detach(self.yb)
if self.save_preds is None: self.preds.append(preds)
else: (self.save_preds/str(self.iter)).save_array(preds)
if self.save_targs is None: self.targets.append(targs)
else: (self.save_targs/str(self.iter)).save_array(targs[0])
if self.with_loss:
bs = find_bs(self.yb)
loss = self.loss if self.loss.numel() == bs else self.loss.view(bs,-1).mean(1)
self.losses.append(to_detach(loss))
def after_validate(self):
"Concatenate all recorded tensors"
if not hasattr(self, 'preds'): return
if self.with_input: self.inputs = detuplify(to_concat(self.inputs, dim=self.concat_dim))
if not self.save_preds: self.preds = detuplify(to_concat(self.preds, dim=self.concat_dim))
if not self.save_targs: self.targets = detuplify(to_concat(self.targets, dim=self.concat_dim))
if self.with_loss: self.losses = to_concat(self.losses)
def all_tensors(self):
res = [None if self.save_preds else self.preds, None if self.save_targs else self.targets]
if self.with_input: res = [self.inputs] + res
if self.with_loss: res.append(self.losses)
return res
# Cell
class FetchPredsCallback(Callback):
"A callback to fetch predictions during the training loop"
remove_on_fetch = True
def __init__(self, ds_idx=1, dl=None, with_input=False, with_decoded=False, cbs=None):
self.cbs = L(cbs)
store_attr(self, 'ds_idx,dl,with_input,with_decoded')
def after_validate(self):
to_rm = L(cb for cb in self.learn.cbs if getattr(cb, 'remove_on_fetch', False))
with self.learn.removed_cbs(to_rm + self.cbs) as learn:
self.preds = learn.get_preds(ds_idx=self.ds_idx, dl=self.dl,
with_input=self.with_input, with_decoded=self.with_decoded, inner=True)
# Cell
_ex_docs = dict(
CancelFitException="Skip the rest of this batch and go to `after_batch`",
CancelEpochException="Skip the rest of the training part of the epoch and go to `after_train`",
CancelTrainException="Skip the rest of the validation part of the epoch and go to `after_validate`",
CancelValidException="Skip the rest of this epoch and go to `after_epoch`",
CancelBatchException="Interrupts training and go to `after_fit`")
for c,d in _ex_docs.items(): mk_class(c,sup=Exception,doc=d)