Update README.md
Browse files
README.md
CHANGED
@@ -23,8 +23,42 @@ Visual Instruction Tuning Script: https://github.com/amith-ananthram/mLLaVA/blob
|
|
23 |
|
24 |
Usage Example:
|
25 |
|
26 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
|
28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
29 |
tokenizer = AutoTokenizer.from_pretrained('baichuan-inc/Baichuan2-7B-Chat', trust_remote_code=True)
|
30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
|
24 |
Usage Example:
|
25 |
|
26 |
+
import torch
|
27 |
+
from PIL import Image
|
28 |
+
from transformers import AutoTokenizer, AutoModelForVisualQuestionAnswering
|
29 |
+
|
30 |
+
# from constants.py, utils.py, included as files in this HF release
|
31 |
+
from constants import IMAGE_TOKEN_INDEX
|
32 |
+
from utils import tokenizer_image_token, process_images
|
33 |
|
34 |
+
device = torch.device('cuda')
|
35 |
+
|
36 |
+
# load model and vision tower
|
37 |
+
model = AutoModelForVisualQuestionAnswering.from_pretrained('amitha/mllava.baichuan2-en', trust_remote_code=True)
|
38 |
+
model.model.vision_tower.load_model()
|
39 |
+
model = model.eval().to(device)
|
40 |
+
|
41 |
+
image_processor = model.get_vision_tower().image_processor
|
42 |
tokenizer = AutoTokenizer.from_pretrained('baichuan-inc/Baichuan2-7B-Chat', trust_remote_code=True)
|
43 |
+
|
44 |
+
prompt = '<reserved_106><image>\nPlease describe this image.<reserved_107>'
|
45 |
+
|
46 |
+
input_ids = tokenizer_image_token(
|
47 |
+
prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
|
48 |
+
)
|
49 |
+
with Image.open("path/to/image.png") as img:
|
50 |
+
images = process_images(
|
51 |
+
[img.convert('RGB')], image_processor, model.config
|
52 |
+
).to(dtype=torch.float16)
|
53 |
+
image_sizes = [img.size]
|
54 |
+
|
55 |
+
with torch.no_grad():
|
56 |
+
output = model.generate(
|
57 |
+
inputs=input_ids.unsqueeze(dim=0).to(device),
|
58 |
+
attention_mask=torch.ones(input_ids.shape[0]).unsqueeze(dim=0).to(device),
|
59 |
+
images=images.to(device),
|
60 |
+
image_sizes=image_sizes
|
61 |
+
)
|
62 |
+
|
63 |
+
print(tokenizer.batch_decode(output, skip_special_tokens=True))
|
64 |
+
|