from torch import nn
from transformers.modeling_utils import PreTrainedModel

from .configuration_my_model import MyModelConfig


class MyModelPretrainedModel(PreTrainedModel):
    """Common base for MyModel classes; binds them to ``MyModelConfig``."""

    # Configuration class associated with this model family.
    config_class = MyModelConfig


class MyModel(MyModelPretrainedModel):
    """Model holding a single square linear projection.

    The constructor records the configuration and two of its fields
    (``n_layers`` and ``hidden_dim``) on the instance and creates one
    ``nn.Linear`` layer mapping ``hidden_dim`` features to ``hidden_dim``
    features.
    """

    def __init__(self, config: MyModelConfig):
        """Build the model from *config*.

        Args:
            config: Configuration providing ``n_layers`` and ``hidden_dim``.
        """
        # The parent initialiser must run before any submodule is attached.
        super().__init__(config)

        hidden_dim = config.hidden_dim

        self.config = config
        self.hidden_dim = hidden_dim
        # NOTE(review): n_layers is only stored here; nothing visible in this
        # block consumes it.
        self.n_layers = config.n_layers

        # Square projection: input and output widths are both hidden_dim.
        self.linear = nn.Linear(hidden_dim, hidden_dim)