import torch
from transformers import PretrainedConfig


class FlashSTUConfig(PretrainedConfig):
    """Configuration class for the Flash STU model."""

    model_type = "FlashSTU"

    def __init__(
        self,
        bsz: int = 1,
        n_embd: int = 1536,
        n_heads: int = 8,
        n_layers: int = 26,
        seq_len: int = 8192,
        window_size: int = 1024,
        vocab_size: int = 200064,
        mlp_scale: int = 12,
        bias: bool = False,
        dropout: float = 0.0,
        num_eigh: int = 24,
        use_hankel_L: bool = False,
        use_flash_fft: bool = True,
        use_approx: bool = True,
        use_attn: bool = True,
        softcap: float = 50.0,
        torch_dtype: torch.dtype = torch.bfloat16,
        **kwargs,
    ):
        super().__init__(**kwargs)

        # Core model dimensions.
        self.bsz = bsz                  # batch size
        self.n_embd = n_embd            # embedding (model) dimension
        self.n_heads = n_heads          # number of attention heads
        self.n_layers = n_layers        # number of layers
        self.seq_len = seq_len          # maximum sequence length
        self.window_size = window_size  # attention window size
        self.vocab_size = vocab_size    # vocabulary size

        # Standard Hugging Face aliases derived from the sizes above.
        self.hidden_size = n_embd
        self.intermediate_size = n_embd * mlp_scale
        self.hidden_act = "swish"

        self.bias = bias
        self.dropout = dropout

        # STU (spectral transform unit) settings.
        self.num_eigh = num_eigh            # number of spectral filters (eigenpairs)
        self.use_hankel_L = use_hankel_L    # use the Hankel-L filter variant
        self.use_flash_fft = use_flash_fft  # use FlashFFTConv for the convolutions
        self.use_approx = use_approx        # use the approximate STU variant

        # Attention settings.
        self.use_attn = use_attn  # interleave attention layers with STU layers
        self.softcap = softcap    # attention logit soft-capping value

        self.torch_dtype = torch_dtype
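
# Illustrative usage sketch (the overridden values and the output directory below
# are arbitrary choices for the example, not part of the model's defaults): build
# a config and round-trip it through the standard PretrainedConfig save/load API.
if __name__ == "__main__":
    config = FlashSTUConfig(n_layers=12, num_eigh=16, use_flash_fft=False)
    print(config.model_type, config.hidden_size, config.intermediate_size)

    # save_pretrained writes a config.json; from_pretrained reconstructs the config.
    config.save_pretrained("./flash_stu_config")
    reloaded = FlashSTUConfig.from_pretrained("./flash_stu_config")
    assert reloaded.num_eigh == 16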