vikhyatk committed
Commit f10c8a2 (1 parent: 1e62d51)

Upload Moondream

Files changed (4)
  1. config.json +1 -1
  2. model.safetensors +3 -0
  3. moondream.py +17 -10
  4. vision_encoder.py +14 -28
config.json CHANGED
@@ -10,6 +10,6 @@
   "phi_config": {
     "model_type": "phi-msft"
   },
-  "torch_dtype": "float32",
+  "torch_dtype": "float16",
   "transformers_version": "4.36.2"
 }
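The only change here is the dtype recorded in config.json, from float32 to float16, matching the new float16 weight file below. A minimal loading sketch, assuming the model is consumed through transformers' trust_remote_code path (the auto class and repo id are illustrative, not taken from this commit):

import torch
from transformers import AutoModelForCausalLM

# Illustrative repo id; substitute the actual Hugging Face repo this commit belongs to.
model = AutoModelForCausalLM.from_pretrained(
    "vikhyatk/moondream1",
    trust_remote_code=True,      # Moondream ships its own modeling code (moondream.py)
    torch_dtype=torch.float16,   # matches the new "torch_dtype" in config.json
)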
model.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:892e51df302d98a83974761c4f386caddbad2edd0e84f228d9935b4aed33ee25
+size 3715037856
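A rough size check on the new weight file, assuming essentially all tensors are stored in the new float16 dtype: 3,715,037,856 bytes / 2 bytes per value ≈ 1.86 billion parameters.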
moondream.py CHANGED
@@ -1,10 +1,12 @@
 import torch
+from torch import nn
 from .vision_encoder import VisionEncoder
-from .text_model import TextModel
 from .configuration_moondream import MoondreamConfig
 from transformers import PreTrainedModel
 import re
 
+from .modeling_phi import PhiForCausalLM
+from .configuration_moondream import PhiConfig
 
 class Moondream(PreTrainedModel):
     config_class = MoondreamConfig
@@ -12,11 +14,16 @@ class Moondream(PreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.vision_encoder = VisionEncoder()
-        self.text_model = TextModel(config)
+
+        if type(config.phi_config) == dict:
+            phi_config = PhiConfig(**config.phi_config)
+        else:
+            phi_config = config.phi_config
+        self.text_model = PhiForCausalLM(phi_config)
 
     @property
     def device(self):
-        return self.text_model.model.device
+        return self.text_model.device
 
     def encode_image(self, image):
         return self.vision_encoder(image)
@@ -27,22 +34,22 @@ class Moondream(PreTrainedModel):
                 txt, return_tensors="pt", add_special_tokens=False
             ).input_ids.to(self.device)
 
+        text_emb = self.text_model.get_input_embeddings()
+
         # Add BOS token
         embeds = []
         embeds.append(
-            self.text_model.text_emb(
-                (torch.tensor([[tokenizer.bos_token_id]], device=self.device))
-            )
+            text_emb((torch.tensor([[tokenizer.bos_token_id]], device=self.device)))
         )
 
         if "<image>" not in prompt:
-            embeds.append(self.text_model.text_emb(_tokenize(prompt)))
+            embeds.append(text_emb(_tokenize(prompt)))
         else:
             assert prompt.count("<image>") == 1
             before, after = prompt.split("<image>")
-            embeds.append(self.text_model.text_emb(_tokenize(f"{before}<image>")))
+            embeds.append(text_emb(_tokenize(f"{before}<image>")))
             embeds.append(image_embeds.to(self.device))
-            embeds.append(self.text_model.text_emb(_tokenize(f"</image>{after}")))
+            embeds.append(text_emb(_tokenize(f"</image>{after}")))
 
         return torch.cat(embeds, dim=1)
 
@@ -67,7 +74,7 @@ class Moondream(PreTrainedModel):
 
         with torch.no_grad():
             inputs_embeds = self.input_embeds(prompt, image_embeds, tokenizer)
-            output_ids = self.text_model.model.generate(
+            output_ids = self.text_model.generate(
                 inputs_embeds=inputs_embeds, **generate_config
            )
 
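Putting the moondream.py changes together: the separate TextModel wrapper is replaced by a PhiForCausalLM built from config.phi_config, text embeddings come from get_input_embeddings(), and generation calls self.text_model.generate directly. A sketch of the resulting call path, continuing from the loading sketch above (the prompt template, tokenizer repo, and generation settings are assumptions, not taken from this diff):

from PIL import Image
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("vikhyatk/moondream1")  # assumed tokenizer location
image_embeds = model.encode_image(Image.open("example.jpg"))      # VisionEncoder forward pass

# input_embeds() splices the image embeddings into the text embeddings at the "<image>" marker.
prompt = "<image>\n\nQuestion: What is in this picture?\n\nAnswer:"
inputs_embeds = model.input_embeds(prompt, image_embeds, tokenizer)

# Generation now goes through PhiForCausalLM.generate directly.
output_ids = model.text_model.generate(
    inputs_embeds=inputs_embeds,
    max_new_tokens=128,
    eos_token_id=tokenizer.eos_token_id,
)
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0])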
vision_encoder.py CHANGED
@@ -80,23 +80,18 @@ class VisionProjection(nn.Module):
         model_dim = 2048
         hidden_dim = model_dim * 4
 
-        self.mlp1 = MLP(image_embedding_dim, hidden_dim, model_dim)
-        self.mlp2 = MLP(model_dim, hidden_dim, model_dim)
-        self.ln = nn.LayerNorm(model_dim)
+        self.mlp = MLP(image_embedding_dim, hidden_dim, model_dim)
 
     @property
     def device(self):
-        return self.mlp1.fc1.weight.device
+        return self.mlp.fc1.weight.device
 
     def forward(self, x):
-        x = self.mlp1(x)
-        x = self.ln(x)
-        x = x + self.mlp2(x)
-        return x
+        return self.mlp(x)
 
 
-class VisionTower(nn.Module):
-    def __init__(self):
+class VisionEncoder(nn.Module):
+    def __init__(self) -> None:
         super().__init__()
 
         self.encoder = ModelHolder(
@@ -109,17 +104,6 @@ class VisionTower(nn.Module):
 
         self.projection = VisionProjection()
 
-    def forward(self, x):
-        x = self.encoder(x)
-        x = self.projection(x)
-        return x
-
-
-class VisionEncoder(nn.Module):
-    def __init__(self) -> None:
-        super().__init__()
-
-        self.model = VisionTower()
         self.preprocess = Compose(
             [
                 Resize(size=(378, 378), interpolation=InterpolationMode.BICUBIC),
@@ -131,20 +115,22 @@ class VisionEncoder(nn.Module):
 
     @property
     def device(self):
-        return self.model.projection.mlp1.fc1.weight.device
+        return self.projection.mlp.fc1.weight.device
 
     @property
     def dtype(self):
-        return self.model.projection.mlp1.fc1.weight.dtype
+        return self.projection.mlp.fc1.weight.dtype
 
     def __call__(self, image: Image) -> torch.Tensor:
         with torch.no_grad():
-            image_vec = (
+            x = (
                 self.preprocess(image.convert("RGB"))
                 .unsqueeze(0)
                 .to(self.device, dtype=self.dtype)
             )
-            image_vec = rearrange(
-                image_vec, "b c (h p1) (w p2) -> b (h w) (c p1 p2)", p1=14, p2=14
-            )
-            return self.model(image_vec)
+            x = rearrange(x, "b c (h p1) (w p2) -> b (h w) (c p1 p2)", p1=14, p2=14)
+
+            x = self.encoder(x)
+            x = self.projection(x)
+
+            return x
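The vision_encoder.py change collapses the intermediate VisionTower into VisionEncoder and reduces VisionProjection to a single MLP; the preprocessing itself is unchanged. A small worked example of the patching step, using a random tensor in place of a real image:

import torch
from einops import rearrange

x = torch.randn(1, 3, 378, 378)  # shape after Resize((378, 378)) and normalization
x = rearrange(x, "b c (h p1) (w p2) -> b (h w) (c p1 p2)", p1=14, p2=14)
print(x.shape)                   # torch.Size([1, 729, 588]): 27 x 27 patches, each 3*14*14 = 588 values

# The encoder and the single-MLP VisionProjection (model_dim = 2048) then turn these patches
# into the image_embeds tensor consumed by Moondream.input_embeds; assuming the encoder
# preserves the 729-token sequence, that tensor has shape (1, 729, 2048).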