Spaces:
Runtime error
Runtime error
# This code is adapted from https://github.com/THUDM/CogView2/blob/4e55cce981eb94b9c8c1f19ba9f632fd3ee42ba8/cogview2_text2image.py | |
from __future__ import annotations | |
import argparse | |
import functools | |
import logging | |
import os | |
import pathlib | |
import random | |
import subprocess | |
import sys | |
import time | |
import zipfile | |
from typing import Any | |
if os.getenv('SYSTEM') == 'spaces':
    # Hugging Face Spaces bootstrap (enabled by SYSTEM=spaces): install the
    # pinned runtime packages, patch the CogView2 checkout, and download the
    # tokenizer/model archives before the imports further down need them.
    for requirement in (
            'icetk==0.0.3',
            'SwissArmyTransformer==0.2.4',
            'git+https://github.com/Sleepychord/Image-Local-Attention@43fee31',
    ):
        # ``python -m pip`` installs into the interpreter running this file
        # rather than whichever ``pip`` happens to be first on PATH.
        subprocess.run([sys.executable, '-m', 'pip', 'install', requirement])

    # Building NVIDIA apex from source is currently disabled:
    #subprocess.run('git clone https://github.com/NVIDIA/apex'.split())
    #subprocess.run('git checkout 1403c21'.split(), cwd='apex')
    #with open('patch.apex') as f:
    #    subprocess.run('patch -p1'.split(), cwd='apex', stdin=f)
    #subprocess.run(
    #    'pip install -v --disable-pip-version-check --no-cache-dir --global-option --cpp_ext --global-option --cuda_ext ./'
    #    .split(),
    #    cwd='apex')
    #subprocess.run('rm -rf apex'.split())

    # Apply the local ``patch`` file to the CogView2 submodule.
    with open('patch') as f:
        subprocess.run('patch -p1'.split(), cwd='CogView2', stdin=f)

    from huggingface_hub import hf_hub_download

    def download_and_extract_icetk_models() -> None:
        """Fetch ``models.zip`` from the THUDM/icetk repo and unpack it into
        ``/home/user/.icetk_models``."""
        icetk_model_dir = pathlib.Path('/home/user/.icetk_models')
        # Tolerate an existing directory (and create missing parents) so the
        # bootstrap does not die with FileExistsError when it runs again.
        icetk_model_dir.mkdir(parents=True, exist_ok=True)
        path = hf_hub_download('THUDM/icetk',
                               'models.zip',
                               use_auth_token=os.getenv('HF_TOKEN'))
        with zipfile.ZipFile(path) as f:
            f.extractall(path=icetk_model_dir.as_posix())

    def download_and_extract_cogview2_models(name: str) -> None:
        """Fetch archive ``name`` from the THUDM/CogView2 repo, unpack it into
        the current working directory, then delete the downloaded file."""
        path = hf_hub_download('THUDM/CogView2',
                               name,
                               use_auth_token=os.getenv('HF_TOKEN'))
        with zipfile.ZipFile(path) as f:
            f.extractall()
        os.remove(path)

    download_and_extract_icetk_models()
    names = [
        'coglm.zip',
        'cogview2-dsr.zip',
        'cogview2-itersr.zip',
    ]
    for name in names:
        download_and_extract_cogview2_models(name)
    # NOTE(review): presumably the archives above unpack under
    # ``sharefs/cogview-new`` inside the app directory — confirm.
    os.environ['SAT_HOME'] = '/home/user/app/sharefs/cogview-new'
import gradio as gr | |
import numpy as np | |
import torch | |
from icetk import icetk as tokenizer | |
from SwissArmyTransformer import get_args | |
from SwissArmyTransformer.arguments import set_random_seed | |
from SwissArmyTransformer.generation.autoregressive_sampling import \ | |
filling_sequence | |
from SwissArmyTransformer.model import CachedAutoregressiveModel | |
# Put the vendored CogView2 checkout (next to this file) at the front of the
# import path so ``coglm_strategy`` and ``sr_pipeline`` resolve from it.
app_dir = pathlib.Path(__file__).parent
submodule_dir = app_dir.joinpath('CogView2')
sys.path.insert(0, submodule_dir.as_posix())
from coglm_strategy import CoglmStrategy | |
from sr_pipeline import SRGroup | |
# Module logger: INFO and above go to stdout through a dedicated handler and
# are not forwarded to ancestor loggers.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
logger.propagate = False

formatter = logging.Formatter(
    '[%(asctime)s] %(name)s %(levelname)s: %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S')
stream_handler = logging.StreamHandler(stream=sys.stdout)
stream_handler.setLevel(logging.INFO)
stream_handler.setFormatter(formatter)
logger.addHandler(stream_handler)
# Register the control tokens used by the prompts built in this module
# (e.g. the ``<start_of_image>`` suffix of every query template).
tokenizer.add_special_tokens([
    '<start_of_image>',
    '<start_of_english>',
    '<start_of_chinese>',
])
def get_masks_and_position_ids_coglm(
    seq: torch.Tensor, context_length: int
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Build the CogLM model inputs for a single token sequence.

    Args:
        seq: token sequence (expected to be 1-D: prompt followed by image
            slots).
        context_length: number of leading context (text) tokens.

    Returns:
        ``(tokens, attention_mask, position_ids)``:
        ``tokens`` is ``seq`` with a leading batch axis of size 1;
        ``attention_mask`` is a ``(1, 1, L, L)`` causal mask in which every
        position may additionally see all context columns;
        ``position_ids`` is ``(1, L)`` long, counting from 0 over the context
        and from 512 over the remaining positions.
    """
    total = len(seq)
    device = seq.device

    # Lower-triangular (causal) mask, with the context columns fully visible.
    mask = torch.tril(torch.ones(total, total, device=device))
    mask[:, :context_length] = 1

    # Context positions start at 0; the remaining positions start at 512.
    positions = torch.cat((
        torch.arange(context_length, device=device, dtype=torch.long),
        torch.arange(512,
                     512 + total - context_length,
                     device=device,
                     dtype=torch.long),
    ))

    return seq.unsqueeze(0), mask[None, None], positions[None]
class InferenceModel(CachedAutoregressiveModel):
    """Cached autoregressive model whose output head covers only the first
    20000 rows of the word-embedding table."""

    def final_forward(self, logits, **kwargs):
        """Project the final hidden states onto the first 20000 embeddings.

        Both the input and the embedding slice are cast to float32 before the
        (tied-weight) linear projection; the resulting scores are returned.
        """
        # NOTE(review): 20000 presumably equals ``tokenizer.num_image_tokens``
        # (the cut-off used for ``invalid_slices`` in Model.load_strategy) —
        # confirm before changing either value.
        head_weight = self.transformer.word_embeddings.weight[:20000]
        return torch.nn.functional.linear(logits.float(), head_weight.float())
def get_recipe(name: str) -> dict[str, Any]:
    """Return the sampling hyper-parameters and prompt template for a style.

    Every recipe starts from the same base values; a style overrides only
    some of them.  ``'none'`` and any unrecognised name yield the base recipe
    unchanged.  A new dict is returned on every call.
    """
    recipe: dict[str, Any] = {
        'attn_plus': 1.4,
        'temp_all_gen': 1.15,
        'topk_gen': 16,
        'temp_cluster_gen': 1.,
        'temp_all_dsr': 1.5,
        'topk_dsr': 100,
        'temp_cluster_dsr': 0.89,
        'temp_all_itersr': 1.3,
        'topk_itersr': 16,
        'query_template': '{}<start_of_image>',
    }
    # Per-style overrides.  Every key already exists in ``recipe``, so
    # ``update`` keeps the base key order.  The Chinese suffixes are style
    # prompts (English glosses in the trailing comments).
    overrides: dict[str, dict[str, Any]] = {
        'none': {},
        'mainbody': {
            # "high-definition photography, isolated"
            'query_template': '{} 高清摄影 隔绝<start_of_image>',
        },
        'photo': {
            # "high-definition photography"
            'query_template': '{} 高清摄影<start_of_image>',
        },
        'flat': {
            # "flat style"
            'query_template': '{} 平面风格<start_of_image>',
            # Disabled alternatives: attn_plus=1.8, temp_cluster_gen=0.75.
            'temp_all_gen': 1.1,
            'topk_dsr': 5,
            'temp_cluster_dsr': 0.4,
            'temp_all_itersr': 1,
            'topk_itersr': 5,
        },
        'comics': {
            # "comics, isolated"
            'query_template': '{} 漫画 隔绝<start_of_image>',
            'topk_dsr': 5,
            'temp_cluster_dsr': 0.4,
            'temp_all_gen': 1.1,
            'temp_all_itersr': 1,
            'topk_itersr': 5,
        },
        'oil': {
            # "oil-painting style"
            'query_template': '{} 油画风格<start_of_image>',
        },
        'sketch': {
            # "sketch style"
            'query_template': '{} 素描风格<start_of_image>',
            'temp_all_gen': 1.1,
        },
        'isometric': {
            # "isometric vector illustration"
            'query_template': '{} 等距矢量图<start_of_image>',
            'temp_all_gen': 1.1,
        },
        'chinese': {
            # "Chinese ink-wash painting"
            'query_template': '{} 水墨国画<start_of_image>',
            'temp_all_gen': 1.12,
        },
        'watercolor': {
            # "watercolor style"
            'query_template': '{} 水彩画风格<start_of_image>',
        },
    }
    recipe.update(overrides.get(name, {}))
    return recipe
def get_default_args() -> argparse.Namespace:
    """Build the inference ``Namespace`` used to construct ``Model``.

    Merges three groups of settings:
    - SwissArmyTransformer's parsed defaults for ``--mode inference --fp16``;
    - a few app-level settings;
    - the ``'mainbody'`` recipe from ``get_recipe``.

    The groups are passed as separate ``**`` expansions, so a key present in
    more than one group raises ``TypeError`` rather than being silently
    overwritten.
    """
    default_style = 'mainbody'
    sat_args = get_args(['--mode', 'inference', '--fp16'])
    app_settings = {
        'img_size': 160,
        'only_first_stage': False,
        'inverse_prompt': False,
        'style': default_style,
    }
    return argparse.Namespace(**vars(sat_args), **app_settings,
                              **get_recipe(default_style))
class Model:
    """CogView2 text-to-image pipeline: CogLM sampling plus optional SR.

    Stage 1 samples the 400 image tokens that follow the text prompt with
    ``InferenceModel``.  Stage 2, available only when an ``SRGroup`` is
    loaded, refines them with its ``dsr`` and ``itersr`` models.  ``run``
    keeps only the active stage's weights on ``self.device`` and parks the
    other stage on CPU.
    """

    def __init__(self,
                 max_inference_batch_size: int,
                 only_first_stage: bool = False):
        """Load every model needed for inference.

        Args:
            max_inference_batch_size: maximum number of sequences sampled per
                ``filling_sequence`` call in ``generate_tokens``.
            only_first_stage: if True, the super-resolution group is never
                loaded (``self.srg`` is ``None``) and only stage 1 can run.
        """
        self.args = get_default_args()
        self.args.only_first_stage = only_first_stage
        self.args.max_inference_batch_size = max_inference_batch_size
        # ``load_model`` also returns the args namespace produced by
        # ``from_pretrained``; that namespace is used from here on.
        self.model, self.args = self.load_model()
        self.strategy = self.load_strategy()
        self.srg = self.load_srg()

        self.query_template = self.args.query_template
        self.style = self.args.style
        self.device = torch.device(self.args.device)
        self.fp16 = self.args.fp16
        self.max_batch_size = self.args.max_inference_batch_size
        self.only_first_stage = self.args.only_first_stage

    def load_model(self) -> tuple[InferenceModel, argparse.Namespace]:
        """Load the pretrained ``'coglm'`` checkpoint as an ``InferenceModel``."""
        logger.info('--- load_model ---')
        start = time.perf_counter()
        model, args = InferenceModel.from_pretrained(self.args, 'coglm')
        if not self.args.only_first_stage:
            # The SR models are loaded too; keep CogLM on CPU until ``run``
            # moves it to the device for stage 1.
            model.transformer.cpu()
        elapsed = time.perf_counter() - start
        logger.info(f'--- done ({elapsed=:.3f}) ---')
        return model, args

    def load_strategy(self) -> CoglmStrategy:
        """Create the stage-1 sampling strategy from the current recipe."""
        logger.info('--- load_strategy ---')
        start = time.perf_counter()
        # Ids from ``num_image_tokens`` upward are passed as invalid,
        # presumably so that only image tokens can be sampled.
        invalid_slices = [slice(tokenizer.num_image_tokens, None)]
        # NOTE(review): ``top_k_cluster`` receives ``temp_cluster_gen``
        # (a temperature-named recipe value) — confirm this is intended.
        strategy = CoglmStrategy(invalid_slices,
                                 temperature=self.args.temp_all_gen,
                                 top_k=self.args.topk_gen,
                                 top_k_cluster=self.args.temp_cluster_gen)
        elapsed = time.perf_counter() - start
        logger.info(f'--- done ({elapsed=:.3f}) ---')
        return strategy

    def load_srg(self) -> SRGroup | None:
        """Load the super-resolution group, or return ``None`` in
        first-stage-only mode."""
        logger.info('--- load_srg ---')
        start = time.perf_counter()
        srg = None if self.args.only_first_stage else SRGroup(self.args)
        if srg is not None:
            # NOTE(review): presumably caps the DSR batch size — confirm.
            srg.dsr.max_bz = 2
        elapsed = time.perf_counter() - start
        logger.info(f'--- done ({elapsed=:.3f}) ---')
        return srg

    def update_style(self, style: str) -> None:
        """Apply ``style``'s recipe to the prompt template and strategies.

        This is a no-op when ``style`` is already active.
        """
        if style == self.style:
            return
        logger.info('--- update_style ---')
        start = time.perf_counter()
        self.style = style
        self.args = argparse.Namespace(**(vars(self.args) | get_recipe(style)))
        self.query_template = self.args.query_template
        logger.debug(f'{self.query_template=}')

        self.strategy.temperature = self.args.temp_all_gen

        if self.srg is not None:
            self.srg.dsr.strategy.temperature = self.args.temp_all_dsr
            self.srg.dsr.strategy.topk = self.args.topk_dsr
            self.srg.dsr.strategy.temperature2 = self.args.temp_cluster_dsr

            self.srg.itersr.strategy.temperature = self.args.temp_all_itersr
            self.srg.itersr.strategy.topk = self.args.topk_itersr
        elapsed = time.perf_counter() - start
        logger.info(f'--- done ({elapsed=:.3f}) ---')

    def run(self, text: str, style: str, seed: int, only_first_stage: bool,
            num: int) -> list[np.ndarray] | None:
        """Generate ``num`` images for ``text`` in the given ``style``.

        Returns ``None`` when the prompt is too long (see
        ``preprocess_text``).  If no ``SRGroup`` was loaded
        (``self.srg is None``), only the first stage runs, whatever
        ``only_first_stage`` says.
        """
        logger.info('==================== run ====================')
        start = time.perf_counter()

        self.update_style(style)
        set_random_seed(seed)
        seq, txt_len = self.preprocess_text(text)
        if seq is None:
            return None
        # Stage 2 needs the SR models; without them fall back to stage-1
        # output instead of dereferencing ``self.srg`` (None).
        self.only_first_stage = only_first_stage or self.srg is None

        # Stage 1: park the SR models (if loaded) on CPU, bring CogLM in.
        if self.srg is not None:
            self.srg.dsr.model.cpu()
            self.srg.itersr.model.cpu()
            torch.cuda.empty_cache()
        self.model.transformer.to(self.device)
        tokens = self.generate_tokens(seq, txt_len, num)

        # Stage 2: swap CogLM out and the SR models in.
        if not self.only_first_stage:
            self.model.transformer.cpu()
            torch.cuda.empty_cache()
            self.srg.dsr.model.to(self.device)
            self.srg.itersr.model.to(self.device)
            torch.cuda.empty_cache()
        res = self.generate_images(seq, txt_len, tokens)

        elapsed = time.perf_counter() - start
        logger.info(f'Elapsed: {elapsed}')
        logger.info('==================== done ====================')
        return res

    def preprocess_text(
            self, text: str) -> tuple[torch.Tensor, int] | tuple[None, None]:
        """Tokenize the styled prompt and append 400 image-token slots.

        Returns:
            ``(seq, txt_len)``. ``seq`` holds the prompt tokens followed by
            400 ``-1`` placeholders, on ``self.device``. ``txt_len`` is the
            number of prompt tokens minus one.
            Returns ``(None, None)`` if the prompt encodes to more than 110
            tokens.
        """
        logger.info('--- preprocess_text ---')
        start = time.perf_counter()
        text = self.query_template.format(text)
        logger.debug(f'{text=}')
        seq = tokenizer.encode(text)
        logger.info(f'{len(seq)=}')
        if len(seq) > 110:
            logger.info('The input text is too long.')
            return None, None
        txt_len = len(seq) - 1
        seq = torch.tensor(seq + [-1] * 400, device=self.device)
        elapsed = time.perf_counter() - start
        logger.info(f'--- done ({elapsed=:.3f}) ---')
        return seq, txt_len

    def generate_tokens(self,
                        seq: torch.Tensor,
                        txt_len: int,
                        num: int = 8) -> torch.Tensor:
        """Sample ``num`` stage-1 sequences in batches of at most
        ``self.max_batch_size``.

        The first element returned by each ``filling_sequence`` call is
        collected, and the per-batch results are concatenated along dim 0.
        """
        logger.info('--- generate_tokens ---')
        start = time.perf_counter()
        # Calibrate text length: add ``attn_plus`` to the log attention
        # weight of every text-context column.
        log_attention_weights = torch.zeros(
            len(seq),
            len(seq),
            device=self.device,
            dtype=torch.half if self.fp16 else torch.float32)
        log_attention_weights[:, :txt_len] = self.args.attn_plus
        get_func = functools.partial(get_masks_and_position_ids_coglm,
                                     context_length=txt_len)

        output_list = []
        remaining = num
        for _ in range((num + self.max_batch_size - 1) // self.max_batch_size):
            # The strategy object is shared; reset its start before each batch.
            self.strategy.start_pos = txt_len + 1
            coarse_samples = filling_sequence(
                self.model,
                seq.clone(),
                batch_size=min(remaining, self.max_batch_size),
                strategy=self.strategy,
                log_attention_weights=log_attention_weights,
                get_masks_and_position_ids=get_func)[0]
            output_list.append(coarse_samples)
            remaining -= self.max_batch_size
        output_tokens = torch.cat(output_list, dim=0)
        logger.debug(f'{output_tokens.shape=}')
        elapsed = time.perf_counter() - start
        logger.info(f'--- done ({elapsed=:.3f}) ---')
        return output_tokens

    @staticmethod
    def postprocess(tensor: torch.Tensor) -> np.ndarray:
        """Convert a channel-first float image to a channel-last uint8 array.

        Values (assumed in ``[0, 1]``) are scaled by 255, rounded to nearest
        and clamped to ``[0, 255]``.
        """
        # Must be a staticmethod: it takes no ``self`` but is called as
        # ``self.postprocess(x)``, which otherwise passes two arguments.
        return tensor.cpu().mul(255).add_(0.5).clamp_(0, 255).permute(
            1, 2, 0).to(torch.uint8).numpy()

    def generate_images(self, seq: torch.Tensor, txt_len: int,
                        tokens: torch.Tensor) -> list[np.ndarray]:
        """Decode the sampled tokens to 480x480 images.

        The tokens are passed through super-resolution first unless running
        first-stage-only.
        """
        logger.info('--- generate_images ---')
        start = time.perf_counter()
        logger.debug(f'{self.only_first_stage=}')
        res = []
        if self.only_first_stage:
            for sample in tokens:
                # The trailing 400 tokens of each sample are the image.
                decoded_img = tokenizer.decode(image_ids=sample[-400:])
                decoded_img = torch.nn.functional.interpolate(decoded_img,
                                                              size=(480, 480))
                res.append(self.postprocess(decoded_img[0]))
        else:  # sr
            iter_tokens = self.srg.sr_base(tokens[:, -400:], seq[:txt_len])
            for sample in iter_tokens:
                # The trailing 3600 tokens of each refined sample are the image.
                decoded_img = tokenizer.decode(image_ids=sample[-3600:])
                decoded_img = torch.nn.functional.interpolate(decoded_img,
                                                              size=(480, 480))
                res.append(self.postprocess(decoded_img[0]))
        elapsed = time.perf_counter() - start
        logger.info(f'--- done ({elapsed=:.3f}) ---')
        return res
class AppModel(Model):
    """``Model`` plus prompt translation, random seeding and grid layout
    helpers for the application."""

    def __init__(self, max_inference_batch_size: int, only_first_stage: bool):
        super().__init__(max_inference_batch_size, only_first_stage)
        # Remote translation Space (English -> Chinese, judging by its name).
        self.translator = gr.Interface.load(
            'spaces/chinhon/translation_eng2ch')
        self.rng = random.Random()

    def make_grid(self, images: list[np.ndarray] | None) -> np.ndarray | None:
        """Tile ``images`` row-major into one uint8 RGB canvas.

        The column count is the smallest ``n`` with ``n * n >= len(images)``.
        Every tile has the size of the first image, and unused cells stay
        black.  Returns ``None`` for ``None`` or an empty list.
        """
        count = 0 if images is None else len(images)
        if count == 0:
            return None

        ncols = 1
        while ncols * ncols < count:
            ncols += 1
        nrows = -(-count // ncols)  # ceiling division

        h, w = images[0].shape[:2]
        grid = np.zeros((h * nrows, w * ncols, 3), dtype=np.uint8)
        for index, image in enumerate(images):
            row, col = divmod(index, ncols)
            grid[h * row:h * (row + 1), w * col:w * (col + 1)] = image
        return grid

    def run_advanced(
        self, text: str, translate: bool, style: str, seed: int,
        only_first_stage: bool, num: int
    ) -> tuple[str | None, np.ndarray | None, list[np.ndarray] | None]:
        """Run with explicit settings.

        Optionally translates the prompt first.  Returns
        ``(translated_text or None, grid image, individual images)``.
        """
        logger.info(
            f'{text=}, {translate=}, {style=}, {seed=}, {only_first_stage=}, {num=}'
        )
        translated_text = None
        if translate:
            translated_text = self.translator(text)
            text = translated_text
        results = self.run(text, style, seed, only_first_stage, num)
        return translated_text, self.make_grid(results), results

    def run_simple(self, text: str) -> np.ndarray | None:
        """Generate four ``'photo'``-style images with a random seed.

        The images go through the full pipeline and are returned as a grid.
        ASCII-only prompts are translated before generation.
        """
        logger.info(f'{text=}')
        prompt = self.translator(text) if text.isascii() else text
        results = self.run(prompt, 'photo', self.rng.randint(0, 100000),
                           False, 4)
        return self.make_grid(results)