moondream1 / text_model.py
vikhyatk's picture
Upload Moondream
90f21ed verified
raw
history blame contribute delete
535 Bytes
from torch import nn
import transformers
from .modeling_phi import PhiForCausalLM
from .configuration_moondream import PhiConfig
# Import-time side effect: set the transformers library's logging verbosity to
# ERROR, so its info/warning messages are suppressed for every caller that
# imports this module (the setting is library-wide, not local to TextModel).
transformers.logging.set_verbosity_error()
class TextModel(nn.Module):
    """Text decoder wrapping a Phi causal language model.

    Args:
        config: Object exposing a ``phi_config`` attribute. That attribute is
            either a ``dict`` (or ``dict`` subclass) of keyword arguments for
            ``PhiConfig``, or an already-constructed config object that is
            passed to ``PhiForCausalLM`` unchanged.

    Attributes:
        model: The underlying ``PhiForCausalLM`` instance.
        text_emb: The module returned by ``self.model.get_input_embeddings()``.
    """

    def __init__(self, config) -> None:
        super().__init__()
        phi_config = config.phi_config
        # isinstance (rather than an exact ``type(...) == dict`` check) so that
        # dict subclasses such as OrderedDict are also converted instead of
        # being handed raw to PhiForCausalLM.
        if isinstance(phi_config, dict):
            phi_config = PhiConfig(**phi_config)
        self.model = PhiForCausalLM(phi_config)
        # Direct handle to the input embedding layer. Presumably this is the
        # same module object owned by self.model (transformers convention),
        # i.e. a shared reference rather than a copy -- TODO confirm.
        self.text_emb = self.model.get_input_embeddings()