File size: 1,370 Bytes
2a44926 aab9a40 2a44926 aab9a40 2a44926 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 |
from torch import nn
from transformers import PreTrainedModel, AutoModel, AutoConfig
from .rna_torsionbert_config import RNATorsionBertConfig
class RNATorsionBERTModel(PreTrainedModel):
    """Torsion-angle regression model built on a pretrained DNABERT encoder.

    A frozen-architecture DNABERT backbone (selected by the k-mer size
    ``config.k``) produces per-token hidden states, which a small MLP head
    maps to ``config.num_classes`` outputs squashed into [-1, 1] by a final
    ``tanh`` (angles are presumably encoded as sin/cos pairs — confirm
    against the training pipeline).
    """

    config_class = RNATorsionBertConfig

    # Pinned upstream checkpoint revisions, keyed by k-mer size. Pinning
    # protects against silent changes to the remote-code repositories.
    _DNABERT_REVISIONS = {3: "ed28178", 4: "c8499f0", 5: "c296157", 6: "a79a8fd"}

    def __init__(self, config):
        """Build the DNABERT backbone and regression head from ``config``.

        Args:
            config: an ``RNATorsionBertConfig`` providing ``k`` (k-mer size),
                ``hidden_size`` (head width) and ``num_classes`` (outputs).

        Raises:
            ValueError: if ``config.k`` has no pinned DNABERT checkpoint.
        """
        super().__init__(config)
        self.init_model(config.k)
        # trust_remote_code is required: DNABERT ships custom model code.
        self.dnabert = AutoModel.from_pretrained(
            self.model_name, config=self.dnabert_config, trust_remote_code=True
        )
        self.regressor = nn.Sequential(
            nn.LayerNorm(self.dnabert_config.hidden_size),
            nn.Linear(self.dnabert_config.hidden_size, config.hidden_size),
            nn.GELU(),
            nn.Linear(config.hidden_size, config.num_classes),
        )
        self.activation = nn.Tanh()

    def init_model(self, k: int) -> None:
        """Resolve the DNABERT checkpoint name and config for k-mer size ``k``.

        Sets ``self.model_name`` and ``self.dnabert_config`` as a side effect.

        Raises:
            ValueError: if ``k`` is not one of the supported k-mer sizes.
        """
        if k not in self._DNABERT_REVISIONS:
            raise ValueError(
                f"Unsupported k-mer size {k!r}; expected one of "
                f"{sorted(self._DNABERT_REVISIONS)}"
            )
        model_name = f"zhihan1996/DNA_bert_{k}"
        self.dnabert_config = AutoConfig.from_pretrained(
            model_name,
            revision=self._DNABERT_REVISIONS[k],
            trust_remote_code=True,
        )
        self.model_name = model_name

    def forward(self, tensor):
        """Run the backbone and regression head.

        Args:
            tensor: dict of tokenizer outputs (``input_ids`` etc.) unpacked
                as keyword arguments into the DNABERT backbone.

        Returns:
            dict with key ``"logits"``: tanh-squashed head output of shape
            ``(batch, seq_len, num_classes)``.
        """
        z = self.dnabert(**tensor).last_hidden_state
        output = self.regressor(z)
        output = self.activation(output)
        return {"logits": output}