from typing import Callable
import torch
from torch import Tensor, nn
from torch.nn import functional as F


def ensure_tuple(val: int | tuple[int, ...], n: int = 2) -> tuple[int, ...]:
    """Normalize an int or a tuple to a tuple of exactly ``n`` values."""
    if isinstance(val, int):
        return (val,) * n
    elif len(val) != n:
        raise ValueError(f"Expected a tuple of {n} values, but got {len(val)}: {val}")
    return val


def use_fused_attn() -> bool:
    """Return True when PyTorch exposes fused scaled dot-product attention."""
    return hasattr(F, "scaled_dot_product_attention")


class QuickGELU(nn.Module):
    """
    Applies a GELU approximation that is fast but somewhat inaccurate.
    See: https://github.com/hendrycks/GELUs
    """

    def forward(self, input: Tensor) -> Tensor:
        return input * torch.sigmoid(1.702 * input)
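
# Note (added for context): exact GELU is x * Phi(x), with Phi the standard normal
# CDF; the sigmoid form x * sigmoid(1.702 * x) used above approximates it closely
# while avoiding an erf evaluation, which is why it is "quick".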


def get_act_layer(name: str) -> Callable[[], nn.Module]:
    """Map an activation name ("gelu" or "quick_gelu") to its layer class."""
    match name:
        case "gelu":
            return nn.GELU
        case "quick_gelu":
            return QuickGELU
        case _:
            raise ValueError(f"Activation layer {name} not supported.")