import argparse
import json
import logging
import os
from typing import List, Tuple

import torch
from PIL import Image

from chameleon.inference.chameleon import ChameleonInferenceModel, Options
from constants import (
    MODEL_7B_PATH,
    TOKENIZER_IMAGE_CFG_PATH,
    TOKENIZER_IMAGE_PATH,
    TOKENIZER_TEXT_PATH,
)

# Set up the logging configuration
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')


def split_token_sequence(
    tokens: torch.LongTensor,
    boi: int,
    eoi: int,
) -> List[Tuple[str, torch.LongTensor]]:
    """
    Split a sequence of tokens into text and image segments.

    The boi/eoi marker tokens themselves are dropped; everything between a
    boi and its matching eoi becomes an "image_seg", everything else a
    "text_seg", preserving original order.

    Args:
        tokens (torch.LongTensor): The token sequence, shape (1, seq_len).
        boi (int): Begin of image token.
        eoi (int): End of image token.

    Returns:
        List[Tuple[str, torch.LongTensor]]: List of tuples indicating segment
        type ("text_seg" or "image_seg") and the segment tokens, each with
        shape (1, n) and the input's dtype/device.
    """
    batch_size, _ = tokens.shape
    assert batch_size == 1, "Batch size must be 1"
    device = tokens.device
    dtype = tokens.dtype
    # Work on plain Python ints: comparing 0-dim tensors per element is slow,
    # and torch warns when rebuilding a tensor from a list of tensors.
    token_ids = tokens[0].tolist()

    def _to_tensor(ids: List[int]) -> torch.LongTensor:
        """Wrap a list of token ids back into a (1, n) tensor on the input's device."""
        return torch.tensor(ids, dtype=dtype, device=device).reshape(1, -1)

    segments: List[Tuple[str, torch.LongTensor]] = []
    current_segment: List[int] = []
    in_image_seg = False
    for token in token_ids:
        if token == boi:
            # Entering an image segment: save the current text segment (if any).
            if current_segment:
                segments.append(("text_seg", _to_tensor(current_segment)))
                current_segment = []
            in_image_seg = True
        elif token == eoi and in_image_seg:
            # Exiting an image segment: save the collected image tokens.
            segments.append(("image_seg", _to_tensor(current_segment)))
            current_segment = []
            in_image_seg = False
        else:
            current_segment.append(token)
    # Save any remaining tokens (trailing text, or an unterminated image run).
    if current_segment:
        seg_type = "image_seg" if in_image_seg else "text_seg"
        segments.append((seg_type, _to_tensor(current_segment)))
    return segments


def main(args: argparse.Namespace) -> None:
    """Generate interleaved image-text output and write it to disk.

    Images are saved as PNGs under ``args.save_dir``; the segment manifest
    is written to ``./segments.jsonl`` (one JSON object per segment).
    """
    # Load Chameleon model.
    model = ChameleonInferenceModel(
        MODEL_7B_PATH.as_posix(),
        TOKENIZER_TEXT_PATH.as_posix(),
        TOKENIZER_IMAGE_CFG_PATH.as_posix(),
        TOKENIZER_IMAGE_PATH.as_posix(),
    )

    # Print model configuration (lazy %-style args avoid eager f-string formatting).
    logging.info("Model path: %s", MODEL_7B_PATH)
    logging.info("Text tokenizer path: %s", TOKENIZER_TEXT_PATH)
    logging.info("Image tokenizer config path: %s", TOKENIZER_IMAGE_CFG_PATH)
    logging.info("Image tokenizer path: %s", TOKENIZER_IMAGE_PATH)

    # Generation options.
    options = Options()

    # Prepare prompt. An instruction may be a plain string, or an
    # (instruction, image_path) pair for image-conditioned prompts.
    instructions = [args.instruction]
    batch_prompt_ui = []
    for instruction in instructions:
        # Fixed: isinstance() must use the builtin `tuple`, not typing.Tuple
        # (deprecated in isinstance checks since Python 3.9).
        if isinstance(instruction, tuple):
            inst, image_path = instruction
            batch_prompt_ui.append(
                [
                    {"type": "image", "value": f"file:{image_path}"},
                    {"type": "text", "value": inst},
                ]
            )
        else:
            batch_prompt_ui.append([{"type": "text", "value": instruction}])

    # Generate.
    tokens: torch.LongTensor = model.generate(
        batch_prompt_ui=batch_prompt_ui,
        options=options,
    )

    # Split into text/image segments. 8197(boi), 8196(eoi).
    boi, eoi = model.vocab.begin_image, model.vocab.end_image
    segments = split_token_sequence(tokens, boi, eoi)

    # Decode each segment; images go to save_dir, text is decoded in place.
    os.makedirs(args.save_dir, exist_ok=True)
    segments_data = []
    for seg_id, (seg_type, seg_tokens) in enumerate(segments):
        if seg_type == "image_seg":
            # Image segments are expected to be exactly 1024 VQ tokens.
            assert seg_tokens.shape[1] == 1024
            img = model.decode_image(seg_tokens)[0]
            image_path = os.path.join(args.save_dir, f"{seg_id}.png")
            img.save(image_path)
            segments_data.append({"type": "image", "content": image_path})
        else:
            assert seg_type == "text_seg"
            decoded_text = model.decode_text(seg_tokens)[0]
            segments_data.append({"type": "text", "content": decoded_text})

    # Fixed: single-argument os.path.join was a no-op; the path is unchanged.
    # NOTE(review): this writes next to the CWD, not into save_dir — confirm intended.
    jsonl_path = "./segments.jsonl"
    with open(jsonl_path, 'w') as jsonl_file:
        for segment in segments_data:
            jsonl_file.write(json.dumps(segment) + '\n')


def parse_arguments() -> argparse.Namespace:
    """Parse command line arguments."""
    parser = argparse.ArgumentParser(
        description="Generate interleaved image-text content based on text instructions."
    )
    parser.add_argument(
        "-i", "--instruction",
        type=str,
        required=True,
        help="The instruction for interleaved image-text generation.",
    )
    parser.add_argument(
        "-s", "--save_dir",
        type=str,
        default="./outputs/interleaved/",
        help="The directory to save the generated images.",
    )
    args: argparse.Namespace = parser.parse_args()
    return args


if __name__ == "__main__":
    args: argparse.Namespace = parse_arguments()
    main(args)