Johannes committed on
Commit 2420b7f • 1 Parent(s): df20c94
.gitattributes CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.jpeg filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1 @@
+ weights/
README.md CHANGED
@@ -1,5 +1,5 @@
  ---
- title: Capdec 0 Shot Image Captioning
+ title: CapDec Image Captioning
  emoji: πŸ‘
  colorFrom: pink
  colorTo: blue
__init__.py ADDED
File without changes
__pycache__/model.cpython-39.pyc ADDED
Binary file (8.47 kB)
__pycache__/predict.cpython-39.pyc ADDED
Binary file (3.54 kB)
app.py ADDED
@@ -0,0 +1,87 @@
+ import gradio as gr
+ import clip
+ from model import ClipCaptionModel
+ from transformers import GPT2Tokenizer
+ import numpy as np
+ import torch
+ import PIL.Image
+ from predict import generate2, generate_beam
+ from huggingface_hub import hf_hub_download
+
+ D = torch.device
+ CPU = torch.device('cpu')
+ pretrained_model_variance = "0.015"
+ device = "cpu"
+ model_path = hf_hub_download('johko/capdec_015', 'model.pt')
+
+ clip_model, preprocess = clip.load("RN50x4", device=device, jit=False)
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+
+ # CapDec checkpoints trained with different noise injection levels.
+ model_0 = hf_hub_download('johko/capdec_0', 'model.pt')
+ model_001 = hf_hub_download('johko/capdec_001', 'model.pt')
+ model_005 = hf_hub_download('johko/capdec_005', 'model.pt')
+ model_015 = hf_hub_download('johko/capdec_015', 'model.pt')
+ model_025 = hf_hub_download('johko/capdec_025', 'model.pt')
+ model_05 = hf_hub_download('johko/capdec_05', 'model.pt')
+
+
+ def load_noise_level_model(noise_level):
+     # Pick the checkpoint that matches the selected noise level.
+     if noise_level == "0.0":
+         model_path = model_0
+     elif noise_level == "0.001":
+         model_path = model_001
+     elif noise_level == "0.005":
+         model_path = model_005
+     elif noise_level == "0.015":
+         model_path = model_015
+     elif noise_level == "0.025":
+         model_path = model_025
+     elif noise_level == "0.05":
+         model_path = model_05
+     else:
+         raise ValueError("Unknown Noise Level")
+
+     model = ClipCaptionModel()
+     model.load_state_dict(torch.load(model_path, map_location=CPU))
+     model = model.eval()
+     model = model.to(device)
+
+     return model
+
+
+ def infer(input_image: np.ndarray, noise_level: str):
+     use_beam_search = True
+
+     model = load_noise_level_model(noise_level)
+
+     pil_image = PIL.Image.fromarray(input_image)
+
+     # Encode the image with CLIP and project it to a 40-token GPT-2 prefix.
+     image = preprocess(pil_image).unsqueeze(0).to(device)
+     with torch.no_grad():
+         prefix = clip_model.encode_image(image).to(device, dtype=torch.float32)
+         prefix_embed = model.clip_project(prefix).reshape(1, 40, -1)
+     if use_beam_search:
+         generated_text_prefix = generate_beam(model, tokenizer, embed=prefix_embed)[0]
+     else:
+         generated_text_prefix = generate2(model, tokenizer, embed=prefix_embed)
+
+     return input_image, generated_text_prefix
+
+
+ description = """This space is a demo for the paper [*Text-Only Training for Image Captioning using Noise-Injected CLIP*](https://arxiv.org/pdf/2211.00575.pdf)
+ by David Nukrai, Ron Mokady and Amir Globerson.
+
+ The paper is about training an image captioning model using text only. It leverages noise injection at different noise levels,
+ which you can also experiment with in this demo. The generated caption will change depending on the noise level you choose."""
+
+ dropdown = gr.components.Dropdown(["0.0", "0.001", "0.005", "0.015", "0.025", "0.05"], value="0.015", label="Noise Level")
+ input_image = gr.components.Image(label="Input Image")
+ output_image = gr.components.Image(label="Image")
+ output_text = gr.components.Textbox(label="Generated Caption")
+
+ iface = gr.Interface(
+     title="CapDec Image Captioning",
+     description=description,
+     fn=infer,
+     inputs=[input_image, dropdown],
+     outputs=[output_image, output_text],
+     examples=[["examples/flickr_ex2.jpg", "0.015"], ["examples/web_ex3.jpeg", "0.015"]])
+ iface.launch()
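For reference, here is a minimal sketch of the same inference path outside the Gradio UI. It only reuses pieces already present in this commit (model.py, predict.py, the johko/capdec_015 checkpoint, and examples/flickr_ex2.jpg); it is an illustration, not part of the committed app:

    import clip
    import torch
    import PIL.Image
    from huggingface_hub import hf_hub_download
    from transformers import GPT2Tokenizer

    from model import ClipCaptionModel
    from predict import generate_beam

    device = "cpu"
    clip_model, preprocess = clip.load("RN50x4", device=device, jit=False)
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

    # Load the noise-level-0.015 checkpoint into the mapping network + GPT-2 wrapper.
    model = ClipCaptionModel()
    checkpoint = hf_hub_download("johko/capdec_015", "model.pt")
    model.load_state_dict(torch.load(checkpoint, map_location=torch.device("cpu")))
    model.eval()

    # Encode the image with CLIP, project it to a 40-token GPT-2 prefix, decode with beam search.
    image = preprocess(PIL.Image.open("examples/flickr_ex2.jpg")).unsqueeze(0).to(device)
    with torch.no_grad():
        prefix = clip_model.encode_image(image).to(device, dtype=torch.float32)
        prefix_embed = model.clip_project(prefix).reshape(1, 40, -1)
    print(generate_beam(model, tokenizer, embed=prefix_embed)[0])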
examples/flickr_ex1.jpg ADDED
examples/flickr_ex2.jpg ADDED
model.py ADDED
@@ -0,0 +1,199 @@
+ from torch import nn
+ import torch.nn.functional as nnf
+ from transformers import GPT2Tokenizer, GPT2LMHeadModel
+ import torch
+ from typing import Tuple, List, Union, Optional
+ import numpy as np
+
+
+ N = type(None)
+ V = np.array
+ ARRAY = np.ndarray
+ ARRAYS = Union[Tuple[ARRAY, ...], List[ARRAY]]
+ VS = Union[Tuple[V, ...], List[V]]
+ VN = Union[V, N]
+ VNS = Union[VS, N]
+ T = torch.Tensor
+ TS = Union[Tuple[T, ...], List[T]]
+ TN = Optional[T]
+ TNS = Union[Tuple[TN, ...], List[TN]]
+ TSN = Optional[TS]
+ TA = Union[T, ARRAY]
+
+
+ class ClipCaptionModel(nn.Module):
+     # Maps a CLIP image embedding to a 40-token prefix that conditions GPT-2.
+
+     def get_dummy_token(self, batch_size: int, device: torch.device) -> torch.Tensor:
+         return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)
+
+     def forward(self, tokens: torch.Tensor, prefix: torch.Tensor, mask: Optional[torch.Tensor] = None,
+                 labels: Optional[torch.Tensor] = None):
+         embedding_text = self.gpt.transformer.wte(tokens)
+         prefix_projections = self.clip_project(prefix).view(-1, self.prefix_length, self.gpt_embedding_size)
+         embedding_cat = torch.cat((prefix_projections, embedding_text), dim=1)
+         if labels is not None:
+             dummy_token = self.get_dummy_token(tokens.shape[0], tokens.device)
+             labels = torch.cat((dummy_token, tokens), dim=1)
+         out = self.gpt(inputs_embeds=embedding_cat, labels=labels, attention_mask=mask)
+         return out
+
+     def __init__(self):
+         super(ClipCaptionModel, self).__init__()
+         self.prefix_length = 40
+         self.gpt = GPT2LMHeadModel.from_pretrained('gpt2')
+         self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
+         self.clip_project = TransformerMapper(640, self.gpt_embedding_size, 40, 40, 8)
+
+
+ class MLP(nn.Module):
+
+     def forward(self, x: T) -> T:
+         return self.model(x)
+
+     def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh):
+         super(MLP, self).__init__()
+         layers = []
+         for i in range(len(sizes) - 1):
+             layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=bias))
+             if i < len(sizes) - 2:
+                 layers.append(act())
+         self.model = nn.Sequential(*layers)
+
+
+ class ClipCaptionPrefix(ClipCaptionModel):
+     # Variant that only exposes the mapping network's parameters and keeps GPT-2 frozen.
+
+     def parameters(self, recurse: bool = True):
+         return self.clip_project.parameters()
+
+     def train(self, mode: bool = True):
+         super(ClipCaptionPrefix, self).train(mode)
+         self.gpt.eval()
+         return self
+
+
+ class MlpTransformer(nn.Module):
+     def __init__(self, in_dim, h_dim, out_d: Optional[int] = None, act=nnf.relu, dropout=0.):
+         super().__init__()
+         out_d = out_d if out_d is not None else in_dim
+         self.fc1 = nn.Linear(in_dim, h_dim)
+         self.act = act
+         self.fc2 = nn.Linear(h_dim, out_d)
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, x):
+         x = self.fc1(x)
+         x = self.act(x)
+         x = self.dropout(x)
+         x = self.fc2(x)
+         x = self.dropout(x)
+         return x
+
+
+ class MultiHeadAttention(nn.Module):
+
+     def __init__(self, dim_self, dim_ref, num_heads, bias=True, dropout=0.):
+         super().__init__()
+         self.num_heads = num_heads
+         head_dim = dim_self // num_heads
+         self.scale = head_dim ** -0.5
+         self.to_queries = nn.Linear(dim_self, dim_self, bias=bias)
+         self.to_keys_values = nn.Linear(dim_ref, dim_self * 2, bias=bias)
+         self.project = nn.Linear(dim_self, dim_self)
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, x, y=None, mask=None):
+         y = y if y is not None else x
+         b, n, c = x.shape
+         _, m, d = y.shape
+         # b n h dh
+         queries = self.to_queries(x).reshape(b, n, self.num_heads, c // self.num_heads)
+         # b m 2 h dh
+         keys_values = self.to_keys_values(y).reshape(b, m, 2, self.num_heads, c // self.num_heads)
+         keys, values = keys_values[:, :, 0], keys_values[:, :, 1]
+         attention = torch.einsum('bnhd,bmhd->bnmh', queries, keys) * self.scale
+         if mask is not None:
+             if mask.dim() == 2:
+                 mask = mask.unsqueeze(1)
+             attention = attention.masked_fill(mask.unsqueeze(3), float("-inf"))
+         attention = attention.softmax(dim=2)
+         out = torch.einsum('bnmh,bmhd->bnhd', attention, values).reshape(b, n, c)
+         out = self.project(out)
+         return out, attention
+
+
+ class TransformerLayer(nn.Module):
+
+     def forward_with_attention(self, x, y=None, mask=None):
+         x_, attention = self.attn(self.norm1(x), y, mask)
+         x = x + x_
+         x = x + self.mlp(self.norm2(x))
+         return x, attention
+
+     def forward(self, x, y=None, mask=None):
+         x = x + self.attn(self.norm1(x), y, mask)[0]
+         x = x + self.mlp(self.norm2(x))
+         return x
+
+     def __init__(self, dim_self, dim_ref, num_heads, mlp_ratio=4., bias=False, dropout=0., act=nnf.relu,
+                  norm_layer: nn.Module = nn.LayerNorm):
+         super().__init__()
+         self.norm1 = norm_layer(dim_self)
+         self.attn = MultiHeadAttention(dim_self, dim_ref, num_heads, bias=bias, dropout=dropout)
+         self.norm2 = norm_layer(dim_self)
+         self.mlp = MlpTransformer(dim_self, int(dim_self * mlp_ratio), act=act, dropout=dropout)
+
+
+ class Transformer(nn.Module):
+
+     def forward_with_attention(self, x, y=None, mask=None):
+         attentions = []
+         for layer in self.layers:
+             x, att = layer.forward_with_attention(x, y, mask)
+             attentions.append(att)
+         return x, attentions
+
+     def forward(self, x, y=None, mask=None):
+         for i, layer in enumerate(self.layers):
+             if i % 2 == 0 and self.enc_dec:  # cross
+                 x = layer(x, y)
+             elif self.enc_dec:  # self
+                 x = layer(x, x, mask)
+             else:  # self or cross
+                 x = layer(x, y, mask)
+         return x
+
+     def __init__(self, dim_self: int, num_heads: int, num_layers: int, dim_ref: Optional[int] = None,
+                  mlp_ratio: float = 2., act=nnf.relu, norm_layer: nn.Module = nn.LayerNorm, enc_dec: bool = False):
+         super(Transformer, self).__init__()
+         dim_ref = dim_ref if dim_ref is not None else dim_self
+         self.enc_dec = enc_dec
+         if enc_dec:
+             num_layers = num_layers * 2
+         layers = []
+         for i in range(num_layers):
+             if i % 2 == 0 and enc_dec:  # cross
+                 layers.append(TransformerLayer(dim_self, dim_ref, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
+             elif enc_dec:  # self
+                 layers.append(TransformerLayer(dim_self, dim_self, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
+             else:  # self or cross
+                 layers.append(TransformerLayer(dim_self, dim_ref, num_heads, mlp_ratio, act=act, norm_layer=norm_layer))
+         self.layers = nn.ModuleList(layers)
+
+
+ class TransformerMapper(nn.Module):
+     # Projects a CLIP embedding to clip_length tokens, concatenates a learned prefix
+     # constant, runs the transformer, and returns only the prefix part of the output.
+
+     def forward(self, x):
+         x = self.linear(x).view(x.shape[0], self.clip_length, -1)
+         prefix = self.prefix_const.unsqueeze(0).expand(x.shape[0], *self.prefix_const.shape)
+         prefix = torch.cat((x, prefix), dim=1)
+         out = self.transformer(prefix)[:, self.clip_length:]
+         return out
+
+     def __init__(self, dim_clip: int, dim_embedding: int, prefix_length: int, clip_length: int, num_layers: int = 8):
+         super(TransformerMapper, self).__init__()
+         self.clip_length = clip_length
+         self.transformer = Transformer(dim_embedding, 8, num_layers)
+         self.linear = nn.Linear(dim_clip, clip_length * dim_embedding)
+         self.prefix_const = nn.Parameter(torch.randn(prefix_length, dim_embedding), requires_grad=True)
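As a quick sanity check of the mapping network defined above, the short sketch below (an illustration, not part of the commit) confirms that TransformerMapper turns a single 640-dimensional RN50x4 CLIP embedding into a 40 x 768 GPT-2 prefix, matching the prefix_embed = model.clip_project(prefix).reshape(1, 40, -1) call in app.py:

    import torch
    from model import TransformerMapper

    # Same hyperparameters ClipCaptionModel uses: 640-d CLIP input, 768-d GPT-2 embeddings,
    # 40 learned prefix tokens, 40 projected CLIP tokens, 8 transformer layers.
    mapper = TransformerMapper(dim_clip=640, dim_embedding=768, prefix_length=40,
                               clip_length=40, num_layers=8)
    fake_clip_embedding = torch.randn(1, 640)
    prefix = mapper(fake_clip_embedding)
    print(prefix.shape)  # expected: torch.Size([1, 40, 768])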
predict.py ADDED
@@ -0,0 +1,127 @@
+ import torch
+ import torch.nn.functional as nnf  # used for softmax in generate2
+ import numpy as np
+ from tqdm import trange  # progress bar over caption entries in generate2
+ from typing import Tuple, List, Union, Optional
+
+
+ def generate_beam(model, tokenizer, beam_size: int = 5, prompt=None, embed=None,
+                   entry_length=67, temperature=1., stop_token: str = '.'):
+
+     model.eval()
+     stop_token_index = tokenizer.encode(stop_token)[0]
+     tokens = None
+     scores = None
+     device = next(model.parameters()).device
+     seq_lengths = torch.ones(beam_size, device=device)
+     is_stopped = torch.zeros(beam_size, device=device, dtype=torch.bool)
+     with torch.no_grad():
+         if embed is not None:
+             generated = embed
+         else:
+             if tokens is None:
+                 tokens = torch.tensor(tokenizer.encode(prompt))
+                 tokens = tokens.unsqueeze(0).to(device)
+                 generated = model.gpt.transformer.wte(tokens)
+         for i in range(entry_length):
+             outputs = model.gpt(inputs_embeds=generated)
+             logits = outputs.logits
+             logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
+             logits = logits.softmax(-1).log()
+             if scores is None:
+                 # First step: expand the prefix into beam_size candidate sequences.
+                 scores, next_tokens = logits.topk(beam_size, -1)
+                 generated = generated.expand(beam_size, *generated.shape[1:])
+                 next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0)
+                 if tokens is None:
+                     tokens = next_tokens
+                 else:
+                     tokens = tokens.expand(beam_size, *tokens.shape[1:])
+                     tokens = torch.cat((tokens, next_tokens), dim=1)
+             else:
+                 logits[is_stopped] = -float(np.inf)
+                 logits[is_stopped, 0] = 0
+                 scores_sum = scores[:, None] + logits
+                 seq_lengths[~is_stopped] += 1
+                 scores_sum_average = scores_sum / seq_lengths[:, None]
+                 scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(beam_size, -1)
+                 next_tokens_source = next_tokens // scores_sum.shape[1]
+                 seq_lengths = seq_lengths[next_tokens_source]
+                 next_tokens = next_tokens % scores_sum.shape[1]
+                 next_tokens = next_tokens.unsqueeze(1)
+                 tokens = tokens[next_tokens_source]
+                 tokens = torch.cat((tokens, next_tokens), dim=1)
+                 generated = generated[next_tokens_source]
+                 scores = scores_sum_average * seq_lengths
+                 is_stopped = is_stopped[next_tokens_source]
+             next_token_embed = model.gpt.transformer.wte(next_tokens.squeeze()).view(generated.shape[0], 1, -1)
+             generated = torch.cat((generated, next_token_embed), dim=1)
+             is_stopped = is_stopped + next_tokens.eq(stop_token_index).squeeze()
+             if is_stopped.all():
+                 break
+     scores = scores / seq_lengths
+     output_list = tokens.cpu().numpy()
+     output_texts = [tokenizer.decode(output[:int(length)]) for output, length in zip(output_list, seq_lengths)]
+     order = scores.argsort(descending=True)
+     output_texts = [output_texts[i] for i in order]
+     return output_texts
+
+
+ def generate2(
+         model,
+         tokenizer,
+         tokens=None,
+         prompt=None,
+         embed=None,
+         entry_count=1,
+         entry_length=67,  # maximum number of words
+         top_p=0.8,
+         temperature=1.,
+         stop_token: str = '.',
+ ):
+     model.eval()
+     generated_num = 0
+     generated_list = []
+     stop_token_index = tokenizer.encode(stop_token)[0]
+     filter_value = -float("Inf")
+     device = next(model.parameters()).device
+
+     with torch.no_grad():
+
+         for entry_idx in trange(entry_count):
+             if embed is not None:
+                 generated = embed
+             else:
+                 if tokens is None:
+                     tokens = torch.tensor(tokenizer.encode(prompt))
+                     tokens = tokens.unsqueeze(0).to(device)
+
+                 generated = model.gpt.transformer.wte(tokens)
+
+             for i in range(entry_length):
+
+                 outputs = model.gpt(inputs_embeds=generated)
+                 logits = outputs.logits
+                 logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
+                 # Nucleus (top-p) filtering: mask out tokens outside the top-p probability mass.
+                 sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+                 cumulative_probs = torch.cumsum(nnf.softmax(sorted_logits, dim=-1), dim=-1)
+                 sorted_indices_to_remove = cumulative_probs > top_p
+                 sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
+                     ..., :-1
+                 ].clone()
+                 sorted_indices_to_remove[..., 0] = 0
+
+                 indices_to_remove = sorted_indices[sorted_indices_to_remove]
+                 logits[:, indices_to_remove] = filter_value
+                 next_token = torch.argmax(logits, -1).unsqueeze(0)
+                 next_token_embed = model.gpt.transformer.wte(next_token)
+                 if tokens is None:
+                     tokens = next_token
+                 else:
+                     tokens = torch.cat((tokens, next_token), dim=1)
+                 generated = torch.cat((generated, next_token_embed), dim=1)
+                 if stop_token_index == next_token.item():
+                     break
+
+             output_list = list(tokens.squeeze().cpu().numpy())
+             output_text = tokenizer.decode(output_list)
+             generated_list.append(output_text)
+
+     return generated_list[0]
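Both decoders consume the same prefix embedding that app.py builds from the CLIP image encoding. A brief usage sketch, assuming model, tokenizer and prefix_embed exist as in the infer function above:

    # Beam search over 5 beams; returns candidate captions sorted best-first.
    captions = generate_beam(model, tokenizer, embed=prefix_embed, beam_size=5)
    best_caption = captions[0]

    # Greedy decoding restricted by top-p (nucleus) filtering of the logits.
    caption = generate2(model, tokenizer, embed=prefix_embed, top_p=0.8)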
requirements.txt ADDED
@@ -0,0 +1,2 @@
+ git+https://github.com/openai/CLIP.git@main
+ transformers