Splend1dchan committed on
Commit
cbf5f57
1 Parent(s): 65dd944

Create README.md

Files changed (1)
  1. README.md +100 -0
README.md ADDED
@@ -0,0 +1,100 @@
This model's weights are identical to laion/CLIP-ViT-H-14-laion2B-s32B-b79K, but only the ViT (vision) component is included.

This is to support loading the model as a CLIPModel, as I failed to load the original model using AutoModel (feedback appreciated).

With this distribution, I was finally able to load the weights via AutoModel, and to further support image classification tasks using my self-defined class CLIPViTForImageClassification, listed below.

However, there is still a small issue that I cannot resolve: I can only load the model after cloning this repo locally with git; loading directly from the Hub still fails.
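For concreteness, a minimal sketch of the local-clone workaround just described. The directory name is only a placeholder for wherever you clone this repo:

```python
from transformers import AutoModel

# Hypothetical local path: replace with the directory produced by `git clone`-ing this repo.
model = AutoModel.from_pretrained("./CLIP-ViT-H-14-vision-local-clone")
```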
```python
from typing import Optional, Union

import torch
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from transformers.modeling_outputs import ImageClassifierOutput
from transformers.models.clip.modeling_clip import (
    CLIPConfig,
    CLIPPreTrainedModel,
    CLIPVisionTransformer,
)


class CLIPViTForImageClassification(CLIPPreTrainedModel):
    def __init__(self, config: CLIPConfig) -> None:
        super().__init__(config)

        self.num_labels = config.num_labels
        vision_config = config.vision_config

        # Vision tower only; the text encoder of the original CLIP checkpoint is not used
        self.vision_model = CLIPVisionTransformer(vision_config)

        # Classifier head on top of the vision transformer's [CLS] token
        self.classifier = (
            nn.Linear(vision_config.hidden_size, config.num_labels)
            if config.num_labels > 0
            else nn.Identity()
        )

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        pixel_values: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple, ImageClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the image classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.vision_model(
            pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        # Classify from the [CLS] token (position 0 of the output sequence)
        logits = self.classifier(sequence_output[:, 0, :])

        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return ImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
```
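A rough usage sketch for the class above, assuming the repo has been cloned locally as described. The local path, `num_labels`, and the dummy input are placeholders; the classifier head is newly initialized and would need fine-tuning before the logits are meaningful:

```python
import torch
from transformers import CLIPConfig

# Hypothetical local clone path and a task-specific label count.
local_path = "./CLIP-ViT-H-14-vision-local-clone"
config = CLIPConfig.from_pretrained(local_path, num_labels=10)
model = CLIPViTForImageClassification.from_pretrained(local_path, config=config)

# Dummy batch of two 224x224 RGB images; use the matching CLIPImageProcessor for real inputs.
pixel_values = torch.randn(2, 3, 224, 224)
outputs = model(pixel_values=pixel_values, labels=torch.tensor([0, 1]))
print(outputs.logits.shape)  # torch.Size([2, 10])
```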