# Source code for libreasr.lib.layers.mish

import torch
import torch.nn as nn
import torch.nn.functional as F


def _mish_fwd(x):
    """Return the Mish activation of ``x``, i.e. ``x * tanh(softplus(x))``.

    Applied elementwise; the result has the same shape as ``x``.
    """
    gate = F.softplus(x).tanh()
    return x * gate


def _mish_bwd(x, grad_output):
    """Return the gradient of Mish w.r.t. ``x``, scaled by ``grad_output``.

    With ``t = tanh(softplus(x))`` the derivative of ``x * t`` is
    ``t + x * sigmoid(x) * (1 - t**2)``, because
    ``d/dx softplus(x) = sigmoid(x)`` and ``d/du tanh(u) = 1 - tanh(u)**2``.
    """
    tanh_sp = torch.tanh(F.softplus(x))
    sig = x.sigmoid()
    sech_sq = 1 - tanh_sp * tanh_sp  # sech^2(softplus(x)) == 1 - tanh^2(softplus(x))
    return grad_output * (tanh_sp + x * sig * sech_sq)


class MishAutoFn(torch.autograd.Function):
    """Mish as a custom autograd function with a hand-written backward.

    Only the input ``x`` is saved for the backward pass; the gradient is
    recomputed from it by ``_mish_bwd`` instead of keeping the intermediate
    softplus/tanh activations alive.
    """

    @staticmethod
    def forward(ctx, x):
        """Save ``x`` for backward and return ``x * tanh(softplus(x))``."""
        ctx.save_for_backward(x)
        return _mish_fwd(x)

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradient w.r.t. ``x`` given the upstream ``grad_output``."""
        # ``saved_tensors`` is the supported accessor; ``saved_variables`` is a
        # deprecated alias from the pre-0.4 Variable API.
        (x,) = ctx.saved_tensors
        return _mish_bwd(x, grad_output)
class _Mish(nn.Module):
    """Mish activation module backed by the custom ``MishAutoFn`` autograd function."""

    def forward(self, x):
        return MishAutoFn.apply(x)


# Prefer fastai2's Mish when that package is importable; otherwise fall back to
# the local implementation above. Only ImportError is caught so that unrelated
# failures (and KeyboardInterrupt/SystemExit) are not silently swallowed.
try:
    from fastai2.layers import Mish
except ImportError:
    Mish = _Mish
    print("Using Mish activation from lib.layers.")
else:
    print("Using Mish activation from fastai2.")