Edit model card

This model is CogAgent-chat-18B fine-tuned on 160K WebSight examples, with LoRA adapters (rank 8) added to the language decoder.

The checkpoint is stored in SAT (SwissArmyTransformer) format.

Please refer to our paper and codebase for instructions on running inference.

Use of this model must comply with both the original model license and the original data license (CC-BY-4.0).

Example usage (based on SAT):

import sys
sys.path.insert(1, '/path/to/CogVLM')
from sat.model import AutoModel
import argparse
from utils.models import CogAgentModel, CogVLMModel, FineTuneTestCogAgentModel
import torch
from sat.model.mixins import CachedAutoregressiveMixin
from sat.quantization.kernels import quantize
from sat.model import AutoModel
from utils.utils import chat, llama2_tokenizer, llama2_text_processor_inference, get_image_processor
from utils.models import CogAgentModel, CogVLMModel
from tqdm import tqdm 
import os
import argparse

# --- Command-line options and fixed inference settings -----------------------
parser = argparse.ArgumentParser()
parser.add_argument('--temperature', type=float, default=0.5)
parser.add_argument('--repetition_penalty', type=float, default=1.1)
args = parser.parse_args()
# Settings not exposed on the command line; `args` is later handed to `chat`
# and `args.version` is the fallback text-processor version below.
args.bf16 = True
args.stream_chat = False
args.version = "chat"

# --- Input / output locations -------------------------------------------------
# You can download the testset from https://huggingface.co/datasets/SALT-NLP/Design2Code
test_data_dir = "/path/to/Design2Code"
predictions_dir = "/path/to/design2code_18b_v0_predictions"
# exist_ok=True tolerates an already-existing directory (including one created
# concurrently) while still surfacing real failures such as permission errors,
# which a bare `except: pass` would hide until the first write fails.
os.makedirs(predictions_dir, exist_ok=True)

# Only the .png screenshots are inputs. os.listdir returns entries in arbitrary
# order, so sort them to make the processing order reproducible.
filename_list = sorted(
    filename for filename in os.listdir(test_data_dir) if filename.endswith(".png")
)

# --- Model loading (single process, no model parallelism) --------------------
world_size = 1
model, model_args = FineTuneTestCogAgentModel.from_pretrained(
    "/path/to/design2code-18b-v0",
    args=argparse.Namespace(
        deepspeed=None,
        local_rank=0,
        rank=0,
        world_size=1,
        model_parallel_size=1,
        mode='inference',
        skip_init=True,
        use_gpu_initialization=True,
        device='cuda',
        bf16=True,
        fp16=None,
    ),
    # Only override the parallel size when running on more than one process.
    overwrite_args={'model_parallel_size': world_size} if world_size != 1 else {},
)
model = model.eval()
# Attach SAT's cached auto-regressive mixin for generation.
model.add_mixin('auto-regressive', CachedAutoregressiveMixin())

# --- Tokenizer and image/text processors --------------------------------------
# Prefer the text-processor version stored with the checkpoint; otherwise fall
# back to the version set on `args` above.
language_processor_version = model_args.text_processor_version if 'text_processor_version' in model_args else args.version
print("[Language processor version]:", language_processor_version)
tokenizer = llama2_tokenizer("lmsys/vicuna-7b-v1.5", signal_type=language_processor_version)
# NOTE(review): assumes eva_args["image_size"] is a sequence whose first entry
# is the input resolution; confirm against the checkpoint's model_args.
image_processor = get_image_processor(model_args.eva_args["image_size"][0])
# The cross-image processor is only built when the checkpoint defines cross_image_pix.
cross_image_processor = get_image_processor(model_args.cross_image_pix) if "cross_image_pix" in model_args else None
# NOTE(review): 2048 is presumably the maximum text length accepted by the
# processor; confirm against llama2_text_processor_inference.
text_processor_infer = llama2_text_processor_inference(tokenizer, 2048, model.image_length)

@torch.no_grad()
def get_html(image_path):
    """Return the model's generated HTML for the screenshot at ``image_path``.

    Performs one fresh ``chat`` call with an empty query string, no prior
    history and no cached image, with gradient tracking disabled. Only the
    text response is returned; the history and image cache that ``chat``
    hands back are discarded.
    """
    reply, _unused_history, _unused_image = chat(
        image_path,
        model,
        text_processor_infer,
        image_processor,
        '',  # empty query string
        history=None,
        cross_img_processor=cross_image_processor,
        image=None,
        max_length=4096,
        top_p=1.0,
        temperature=args.temperature,
        # NOTE(review): top_k=1 presumably makes decoding greedy, in which case
        # temperature/top_p may have no effect; confirm in `chat`'s sampler.
        top_k=1,
        invalid_slices=text_processor_infer.invalid_slices,
        repetition_penalty=args.repetition_penalty,
        args=args,
    )
    return reply

# Generate one HTML prediction per screenshot, saved as <stem>.html.
_PNG_SUFFIX = ".png"
for filename in tqdm(filename_list):
    image_path = os.path.join(test_data_dir, filename)
    generated_text = get_html(image_path)
    # filename_list only holds names ending in ".png", so strip exactly that
    # trailing suffix. str.replace(".png", ".html") would also rewrite any
    # ".png" appearing earlier in the name (e.g. "a.png_v2.png").
    stem = filename[:-len(_PNG_SUFFIX)] if filename.endswith(_PNG_SUFFIX) else filename
    output_path = os.path.join(predictions_dir, stem + ".html")
    with open(output_path, "w", encoding='utf-8') as f:
        f.write(generated_text)
Downloads last month

-

Downloads are not tracked for this model. How to track
Inference API
Unable to determine this model's library. Check the docs.

Dataset used to train SALT-NLP/Design2Code-18B-v0