import os
from dataclasses import dataclass

import torch
from einops import rearrange
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file as load_sft

from .model import Flux, FluxParams
from .controlnet import ControlNetFlux
from .modules.autoencoder import AutoEncoder, AutoEncoderParams
from .modules.conditioner import HFEmbedder
from safetensors import safe_open


def load_safetensors(path):
    """Read every tensor stored in the safetensors file at *path*.

    The file is opened with the PyTorch framework on the CPU device.

    Returns:
        A dict mapping each key found in the file to its tensor.
    """
    with safe_open(path, framework="pt", device="cpu") as handle:
        return {key: handle.get_tensor(key) for key in handle.keys()}


@dataclass
class ModelSpec:
    """Bundle of architecture parameters and checkpoint locations for one model."""

    # Constructor arguments for the Flux transformer.
    params: FluxParams
    # Constructor arguments for the autoencoder.
    ae_params: AutoEncoderParams
    # Local checkpoint paths; None means "not provided".
    ckpt_path: str | None
    ae_path: str | None
    # Hugging Face Hub repository and the file names used as a download fallback.
    repo_id: str | None
    repo_flow: str | None
    repo_ae: str | None


def _flux_params(guidance_embed: bool) -> FluxParams:
    """Build the transformer parameters shared by every entry in ``configs``."""
    return FluxParams(
        in_channels=64,
        vec_in_dim=768,
        context_in_dim=4096,
        hidden_size=3072,
        mlp_ratio=4.0,
        num_heads=24,
        depth=19,
        depth_single_blocks=38,
        axes_dim=[16, 56, 56],
        theta=10_000,
        qkv_bias=True,
        guidance_embed=guidance_embed,
    )


def _ae_params() -> AutoEncoderParams:
    """Build the autoencoder parameters shared by every entry in ``configs``."""
    return AutoEncoderParams(
        resolution=256,
        in_channels=3,
        ch=128,
        out_ch=3,
        ch_mult=[1, 2, 4, 4],
        num_res_blocks=2,
        z_channels=16,
        scale_factor=0.3611,
        shift_factor=0.1159,
    )


# Known model variants. Local checkpoint paths are read from the environment
# (FLUX_DEV / FLUX_SCHNELL / AE) once, at import time. The two variants differ
# only in their repository, checkpoint file and ``guidance_embed`` flag.
configs = {
    "flux-dev": ModelSpec(
        repo_id="black-forest-labs/FLUX.1-dev",
        repo_flow="flux1-dev.safetensors",
        repo_ae="ae.safetensors",
        ckpt_path=os.getenv("FLUX_DEV"),
        params=_flux_params(guidance_embed=True),
        ae_path=os.getenv("AE"),
        ae_params=_ae_params(),
    ),
    "flux-schnell": ModelSpec(
        repo_id="black-forest-labs/FLUX.1-schnell",
        repo_flow="flux1-schnell.safetensors",
        repo_ae="ae.safetensors",
        ckpt_path=os.getenv("FLUX_SCHNELL"),
        params=_flux_params(guidance_embed=False),
        ae_path=os.getenv("AE"),
        ae_params=_ae_params(),
    ),
}


def print_load_warning(missing: list[str], unexpected: list[str]) -> None:
    """Print the missing and unexpected state-dict keys, if there are any.

    When both lists are non-empty, a dashed separator line is printed between
    the two reports. Nothing is printed when both lists are empty.
    """
    has_missing = len(missing) > 0
    has_unexpected = len(unexpected) > 0

    if has_missing:
        print(f"Got {len(missing)} missing keys:\n\t" + "\n\t".join(missing))
    if has_missing and has_unexpected:
        print("\n" + "-" * 79 + "\n")
    if has_unexpected:
        print(f"Got {len(unexpected)} unexpected keys:\n\t" + "\n\t".join(unexpected))


def load_flow_model(name: str, device: str | torch.device = "cuda", hf_download: bool = True):
    """Build the Flux transformer for ``configs[name]`` in bfloat16 and load its weights.

    Args:
        name: Key into ``configs``.
        device: Device the checkpoint tensors are loaded onto.
        hf_download: When no local checkpoint path is configured, allow
            fetching the checkpoint from the Hugging Face Hub.

    Returns:
        The Flux model. If no checkpoint is available it is returned with
        freshly initialised weights on ``device``.
    """
    # Loading Flux
    print("Init model")
    spec = configs[name]

    ckpt_path = spec.ckpt_path
    hub_available = spec.repo_id is not None and spec.repo_flow is not None and hf_download
    if ckpt_path is None and hub_available:
        ckpt_path = hf_hub_download(spec.repo_id, spec.repo_flow)

    # With a checkpoint at hand the module is built on the meta device and the
    # real tensors are assigned from the checkpoint below.
    build_device = device if ckpt_path is None else "meta"
    with torch.device(build_device):
        model = Flux(spec.params).to(torch.bfloat16)

    if ckpt_path is None:
        return model

    print("Loading checkpoint")
    # load_sft doesn't support torch.device
    state = load_sft(ckpt_path, device=str(device))
    missing, unexpected = model.load_state_dict(state, strict=False, assign=True)
    print_load_warning(missing, unexpected)
    return model


def load_flow_model2(name: str, device: str | torch.device = "cuda", hf_download: bool = True):
    """Build the Flux transformer for ``configs[name]`` and load its weights.

    Unlike ``load_flow_model``, the freshly built module is not cast to
    bfloat16, and any "sft" substring in the Hub file name is replaced by
    "safetensors" before downloading.

    Args:
        name: Key into ``configs``.
        device: Device the checkpoint tensors are loaded onto.
        hf_download: When no local checkpoint path is configured, allow
            fetching the checkpoint from the Hugging Face Hub.

    Returns:
        The Flux model. If no checkpoint is available it is returned with
        freshly initialised weights on ``device``.
    """
    # Loading Flux
    print("Init model")
    spec = configs[name]

    ckpt_path = spec.ckpt_path
    hub_available = spec.repo_id is not None and spec.repo_flow is not None and hf_download
    if ckpt_path is None and hub_available:
        hub_filename = spec.repo_flow.replace("sft", "safetensors")
        ckpt_path = hf_hub_download(spec.repo_id, hub_filename)

    # With a checkpoint at hand the module is built on the meta device and the
    # real tensors are assigned from the checkpoint below.
    build_device = device if ckpt_path is None else "meta"
    with torch.device(build_device):
        model = Flux(spec.params)

    if ckpt_path is None:
        return model

    print("Loading checkpoint")
    # load_sft doesn't support torch.device
    state = load_sft(ckpt_path, device=str(device))
    missing, unexpected = model.load_state_dict(state, strict=False, assign=True)
    print_load_warning(missing, unexpected)
    return model


def load_controlnet(name, device, transformer=None):
    """Create a ControlNetFlux on ``device`` using the parameters of ``configs[name]``.

    If ``transformer`` is given, its state dict is copied into the new
    ControlNet non-strictly, so keys that do not match are tolerated.
    """
    model_params = configs[name].params
    with torch.device(device):
        controlnet = ControlNetFlux(model_params)

    if transformer is None:
        return controlnet

    controlnet.load_state_dict(transformer.state_dict(), strict=False)
    return controlnet


def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder:
    """Create the T5 text embedder in bfloat16 and move it to ``device``."""
    # max length 64, 128, 256 and 512 should work (if your sequence is short enough)
    embedder = HFEmbedder("google/t5-v1_1-xxl", max_length=max_length, torch_dtype=torch.bfloat16)
    return embedder.to(device)


def load_clip(device: str | torch.device = "cuda") -> HFEmbedder:
    """Create the CLIP text embedder (max length 77) in bfloat16 and move it to ``device``."""
    embedder = HFEmbedder("openai/clip-vit-large-patch14", max_length=77, torch_dtype=torch.bfloat16)
    return embedder.to(device)


def load_ae(name: str, device: str | torch.device = "cuda", hf_download: bool = True) -> AutoEncoder:
    """Build the autoencoder for ``configs[name]`` and load its weights.

    Args:
        name: Key into ``configs``.
        device: Device the checkpoint tensors are loaded onto.
        hf_download: When no local autoencoder path is configured, allow
            fetching the checkpoint from the Hugging Face Hub.

    Returns:
        The autoencoder. If no checkpoint is available it is returned with
        freshly initialised weights on ``device``.
    """
    spec = configs[name]

    ckpt_path = spec.ae_path
    hub_available = spec.repo_id is not None and spec.repo_ae is not None and hf_download
    if ckpt_path is None and hub_available:
        ckpt_path = hf_hub_download(spec.repo_id, spec.repo_ae)

    # Loading the autoencoder
    print("Init AE")
    build_device = device if ckpt_path is None else "meta"
    with torch.device(build_device):
        ae = AutoEncoder(spec.ae_params)

    if ckpt_path is None:
        return ae

    state = load_sft(ckpt_path, device=str(device))
    missing, unexpected = ae.load_state_dict(state, strict=False, assign=True)
    print_load_warning(missing, unexpected)
    return ae


class WatermarkEmbedder:
    """Callable that stamps a fixed bit pattern into image tensors."""

    def __init__(self, watermark):
        self.watermark = watermark
        # The bit count comes from the module-level WATERMARK_BITS, not from
        # the ``watermark`` argument.
        self.num_bits = len(WATERMARK_BITS)
        # NOTE(review): WatermarkEncoder is not imported in the visible source;
        # presumably it comes from the invisible-watermark (imwatermark)
        # package - confirm it is in scope, otherwise this raises NameError.
        self.encoder = WatermarkEncoder()
        self.encoder.set_watermark("bits", self.watermark)

    def __call__(self, image: torch.Tensor) -> torch.Tensor:
        """
        Adds a predefined watermark to the input image

        Args:
            image: ([N,] B, RGB, H, W) in range [-1, 1]

        Returns:
            same as input but watermarked
        """
        # Map [-1, 1] to [0, 1].
        image = 0.5 * image + 0.5

        # A 4-D input gets a temporary leading axis that is removed at the end.
        squeeze = image.ndim == 4
        if squeeze:
            image = image[None, ...]
        n = image.shape[0]

        # torch (b, c, h, w) in [0, 1] -> numpy (b, h, w, c) [0, 255]
        # watermarking library expects input as cv2 BGR format
        scaled = (255 * image).detach().cpu()
        flat = rearrange(scaled, "n b c h w -> (n b) h w c").numpy()
        image_np = flat[..., ::-1]

        for k in range(image_np.shape[0]):
            image_np[k] = self.encoder.encode(image_np[k], "dwtDct")

        # Flip the channels back and restore the (n, b, c, h, w) layout.
        restored = rearrange(image_np[..., ::-1], "(n b) h w c -> n b c h w", n=n)
        image = torch.from_numpy(restored).to(image.device)
        image = torch.clamp(image / 255, min=0.0, max=1.0)

        if squeeze:
            image = image[0]

        # Map [0, 1] back to [-1, 1].
        return 2 * image - 1


# A fixed 48-bit message that was chosen at random
WATERMARK_MESSAGE = 0b001010101111111010000111100111001111010100101110
# Binary digits of the message as a list of 0/1 ints.
# NOTE(review): the conversion drops the literal's two leading zero bits, so
# this list has 46 entries, not 48 - confirm that is intended.
WATERMARK_BITS = [int(digit) for digit in f"{WATERMARK_MESSAGE:b}"]