from typing import Callable

import torch
from torch import Tensor, nn
from torch.nn import functional as F


def ensure_tuple(val: int | tuple[int, ...], n: int = 2) -> tuple[int, ...]:
    """Expand an int into an n-tuple, or validate the length of an existing tuple."""
    if isinstance(val, int):
        return (val,) * n
    elif len(val) != n:
        raise ValueError(f"Expected a tuple of {n} values, but got {len(val)}: {val}")
    return val


def use_fused_attn() -> bool:
    """Return True if this PyTorch build provides F.scaled_dot_product_attention."""
    return hasattr(F, "scaled_dot_product_attention")


class QuickGELU(nn.Module):
    """
    Applies GELU approximation that is fast but somewhat inaccurate.
    See: https://github.com/hendrycks/GELUs
    """

    def forward(self, input: Tensor) -> Tensor:
        return input * torch.sigmoid(1.702 * input)


def get_act_layer(name: str) -> Callable[[], nn.Module]:
    """Map an activation name to its module class."""
    match name:
        case "gelu":
            return nn.GELU
        case "quick_gelu":
            return QuickGELU
        case _:
            raise ValueError(f"Activation layer {name} not supported.")
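

# Illustrative usage sketch (not part of the original module): a minimal demo of how
# these helpers might be combined. The specific argument values below are assumptions
# chosen for demonstration only.
if __name__ == "__main__":
    # Normalize a kernel-size style argument to a 2-tuple.
    kernel_size = ensure_tuple(3)      # -> (3, 3)
    stride = ensure_tuple((2, 1))      # -> (2, 1)

    # Look up an activation class by name and instantiate it.
    act = get_act_layer("quick_gelu")()
    out = act(torch.randn(4, 8))

    print(kernel_size, stride, out.shape, use_fused_attn())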