amitha committed
Commit 66b3372
Parent: 3e99979

Update README.md

Files changed (1):
  1. README.md +37 -3
README.md CHANGED
@@ -23,8 +23,42 @@ Visual Instruction Tuning Script: https://github.com/amith-ananthram/mLLaVA/blob
 
  Usage Example:
 
- from transformers import AutoProcessor, AutoTokenizer, AutoModelForVisualQuestionAnswering
+ import torch
+ from PIL import Image
+ from transformers import AutoTokenizer, AutoModelForVisualQuestionAnswering
+
+ # from constants.py, utils.py, included as files in this HF release
+ from constants import IMAGE_TOKEN_INDEX
+ from utils import tokenizer_image_token, process_images
 
- processor = AutoProcessor.from_pretrained('openai/clip-vit-large-patch14-336')
+ device = torch.device('cuda')
+
+ # load model and vision tower
+ model = AutoModelForVisualQuestionAnswering.from_pretrained('amitha/mllava.baichuan2-en', trust_remote_code=True)
+ model.model.vision_tower.load_model()
+ model = model.eval().to(device)
+
+ image_processor = model.get_vision_tower().image_processor
  tokenizer = AutoTokenizer.from_pretrained('baichuan-inc/Baichuan2-7B-Chat', trust_remote_code=True)
- model = AutoModelForVisualQuestionAnswering.from_pretrained('amitha/mllava.baichuan2-en', trust_remote_code=True)
+
+ prompt = '<reserved_106><image>\nPlease describe this image.<reserved_107>'
+
+ input_ids = tokenizer_image_token(
+     prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
+ )
+ with Image.open("path/to/image.png") as img:
+     images = process_images(
+         [img.convert('RGB')], image_processor, model.config
+     ).to(dtype=torch.float16)
+     image_sizes = [img.size]
+
+ with torch.no_grad():
+     output = model.generate(
+         inputs=input_ids.unsqueeze(dim=0).to(device),
+         attention_mask=torch.ones(input_ids.shape[0]).unsqueeze(dim=0).to(device),
+         images=images.to(device),
+         image_sizes=image_sizes
+     )
+
+ print(tokenizer.batch_decode(output, skip_special_tokens=True))
+
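
The comment near the top of the new snippet notes that constants.py and utils.py (providing IMAGE_TOKEN_INDEX, tokenizer_image_token, and process_images) ship as files in this HF release. If you are not working from a clone of the repo, a minimal sketch of fetching them with huggingface_hub before running the example, assuming both files sit at the repo root:

    import sys
    from pathlib import Path

    from huggingface_hub import hf_hub_download

    # fetch the helper modules bundled with the model repo
    # (assumption: constants.py and utils.py live at the root of amitha/mllava.baichuan2-en)
    for fname in ("constants.py", "utils.py"):
        local_path = hf_hub_download(repo_id="amitha/mllava.baichuan2-en", filename=fname)

    # both files land in the same cached snapshot directory, so adding its parent
    # to sys.path makes `from constants import ...` and `from utils import ...` resolve
    sys.path.append(str(Path(local_path).parent))

Cloning the repository and running the usage example from inside it achieves the same thing, since the helper files then sit in the working directory.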