kangaroo / data_utils.py
WEBing's picture
fix eva in data_utils
1f1db3a
raw
history blame contribute delete
No virus
5.43 kB
import decord
import random
import numpy as np
from PIL import Image
import torch
from torchvision.transforms import Normalize, Compose, InterpolationMode, ToTensor, Resize
def _convert_to_rgb(image):
return image.convert('RGB')
def image_transform(image_size: int):
    """Build the per-frame preprocessing pipeline.

    Steps, in order: bicubic resize to a square ``image_size`` x
    ``image_size``, conversion to RGB, conversion to a tensor, and per-channel
    normalisation with the mean/std constants below. The constants match the
    widely used OpenAI CLIP statistics. Presumably they were chosen to match
    the vision encoder; confirm.

    Args:
        image_size: Side length, in pixels, of the square output image.

    Returns:
        A ``torchvision.transforms.Compose`` applying the steps above.
    """
    return Compose([
        Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
        _convert_to_rgb,
        ToTensor(),
        Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711),
        ),
    ])
def preprocess_multimodal(sources, num_segments):
    """Expand each ``<video>`` placeholder into the video special-token string.

    Every message dict whose ``'content'`` contains ``<video>`` is rewritten in
    place. All occurrences are replaced with::

        <vi_start> + '<image><eof>' * (num_segments // 2 - 1) + '<image><eov>' + <vi_end>

    This yields ``max(num_segments // 2, 1)`` ``<image>`` tokens in total.

    Args:
        sources: List of conversations; each conversation is an iterable of
            message dicts with a ``'content'`` string.
        num_segments: Number of sampled video frames.

    Returns:
        ``sources`` itself, after in-place modification.
    """
    placeholder = '<video>'
    for conversation in sources:
        for message in conversation:
            text = message['content']
            if placeholder not in text:
                continue
            # Built only when needed, so an unused num_segments is never touched.
            pairs = num_segments // 2 - 1
            body = '<image><eof>' * pairs + '<image>' + '<eov>'
            message['content'] = text.replace(placeholder, '<vi_start>' + body + '<vi_end>')
    return sources
def preprocess(
    sources,
    tokenizer,
    s_id=None,
):
    """Prepend a bilingual system prompt to each conversation and tokenize.

    Args:
        sources: List of conversations; each conversation is an iterable of
            message dicts that are passed to the chat template unchanged.
        tokenizer: Object exposing ``apply_chat_template`` (presumably a
            Hugging Face tokenizer; confirm against callers).
        s_id: Index selecting the instruction template pair. Negative values
            index from the end, as with any list. When ``None``, a pair is
            drawn with the ``random`` module.

    Returns:
        The result of ``tokenizer.apply_chat_template`` for the batch, called
        with ``add_generation_prompt=True`` and ``return_tensors='pt'``.

    Raises:
        IndexError: If ``s_id`` is outside the range of the template lists.
    """
    en_qa_templates = [
        "Review the given video and answer the question associated with its visual elements.",
        "Watch the provided video and offer an accurate response to the related question.",
        "Scrutinize the video carefully, identifying relevant details in order to address the linked question.",
        "Take a close look at the presented visuals and deliver a precise answer to the corresponding question.",
        "Observe the video attentively and accurately respond to the associated question.",
        "View the video attentively and provide a suitable answer to the posed question.",
        "Examine the video and approach the connected question with an informed response.",
        "Assess the displayed video and answer the subsequent question with accuracy.",
        "Consider the video content and deliver a relevant answer to the corresponding question.",
        "Go through the video, taking into account key aspects, and respond to the question."
    ]
    # Chinese counterparts; same order, so one index picks a matching pair.
    ch_qa_templates = [
        "审阅所提供的视频,并回答与其视觉元素相关的问题。",
        "观看所提供的视频,对相关问题给出准确的回答。",
        "仔细审查视频,识别相关的细节,回答与之相关的问题。",
        "仔细观察所展示的视觉内容,并对相应的问题给出精确的回答。",
        "认真观察视频并准确回答相关的问题。",
        "详细观看视频,并且对提出的问题给出合适的回答。",
        "观察视频并用有依据的回答来解答相关的问题。",
        "评估展示的视频,并准确地回答随后的问题。",
        "根据视频内容,对相应的问题给出合理的答案。",
        "浏览视频,根据其中的关键内容回答问题。",
    ]
    # Compare to the None singleton by identity, not equality.
    if s_id is not None:
        index = s_id
    else:
        index = random.randrange(len(en_qa_templates))
    system_prompt = f"""You are a helpful assistant, {en_qa_templates[index]} 你是一个乐于助人的助手,{ch_qa_templates[index]}"""
    messages = [
        [{'role': 'system', 'content': system_prompt}, *source]
        for source in sources
    ]
    input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors='pt')
    return input_ids
def get_index(fps, max_frame, num_segments):
    """Choose ``num_segments`` frame indices spread over a video, plus timestamps.

    Args:
        fps: Frames per second; converts sampling positions to seconds.
        max_frame: Treated as the number of frames to sample from, so indices
            are drawn from ``[0, max_frame - 1]``. ``read_video`` passes
            ``len(vr) - 1`` here.
        num_segments: Number of indices to return.

    Returns:
        ``(indices, durations)``.
        ``indices`` is an ``int64`` NumPy array of length ``num_segments``.
        ``durations`` is a list of timestamps in seconds. It is computed from
        the sampling positions *before* they are truncated to integers, so in
        the evenly spaced branch it may be fractional.
    """
    num_frames = max_frame
    if num_frames <= 0:
        # No range to spread over: for example a single-frame video, for which
        # read_video passes 0. The modulo below would then raise
        # ZeroDivisionError, so repeat frame 0 instead.
        out_indices = np.zeros(num_segments, dtype=np.int64)
    elif num_frames <= num_segments:
        # Fewer frames than segments: cycle through every frame, then sort so
        # the repeated indices stay in temporal order.
        out_indices = np.sort(np.arange(num_segments) % num_frames)
    else:
        # Evenly spaced (float) positions from the first to the last index.
        out_indices = np.linspace(0, num_frames - 1, num_segments)
    durations = [idx.item() / fps for idx in out_indices]
    return out_indices.astype(np.int64), durations
def read_video(video_path, num_segments):
    """Decode ``num_segments`` frames from a video file and preprocess them.

    Args:
        video_path: Path handed to ``decord.VideoReader``.
        num_segments: Number of frames to sample (see ``get_index``).

    Returns:
        ``(video, durations)``.
        ``video`` concatenates the ``image_transform(448)`` output of each
        sampled frame along a new leading axis, one entry per frame.
        ``durations`` is a ``torch.Tensor`` of the sampled timestamps in
        seconds.
    """
    preprocess_frame = image_transform(image_size=448)
    reader = decord.VideoReader(video_path)
    avg_fps = float(reader.get_avg_fps())
    # NOTE(review): this passes the last valid frame *index*, which get_index
    # treats as a frame count, so the final frame is never sampled. It is kept
    # as-is to preserve current sampling; confirm whether this is intended.
    indices, timestamps = get_index(avg_fps, len(reader) - 1, num_segments)
    frames = [
        preprocess_frame(Image.fromarray(reader[i].asnumpy())).unsqueeze(0)
        for i in indices
    ]
    return torch.concat(frames), torch.Tensor(timestamps)
def get_input(video_path, num_segments, question, history, tokenizer, s_id):
    """Build the model inputs for one chat turn about a video.

    On the first turn (``history is None``), a new conversation is started and
    the user message is prefixed with the ``<video>`` placeholder. On later
    turns, ``question`` is appended to ``history`` in place. The placeholder is
    then expanded according to the number of decoded frames, and the
    conversation is tokenized.

    Args:
        video_path: Video file to decode.
        num_segments: Number of frames to sample.
        question: The user's text for this turn.
        history: Conversation returned by a previous call, or ``None``.
        tokenizer: Passed through to ``preprocess``.
        s_id: Prompt template index passed through to ``preprocess``.

    Returns:
        ``(video, durations, input_ids, conversations)``. ``conversations`` is
        the message list to pass back as ``history`` on the next turn.
    """
    video, durations = read_video(video_path, num_segments)
    # Compare to the None singleton by identity, not equality.
    if history is None:
        conversations = [{'role': 'user', 'content': f'<video>\n{question}'}]
    else:
        conversations = history
        conversations.append({'role': 'user', 'content': question})
    # preprocess_multimodal rewrites the message dicts in place, so the
    # returned conversations already contain the expanded video tokens.
    sources = preprocess_multimodal([conversations], video.shape[0])
    input_ids = preprocess(sources, tokenizer, s_id=s_id)
    return video, durations, input_ids, conversations
def add_pred_to_history(history, pred):
    """Record the model reply ``pred`` as an assistant message.

    ``history`` is extended in place and also returned for convenience.
    """
    reply = {'role': 'assistant', 'content': pred}
    history.append(reply)
    return history