Tarin Clanuwat commited on
Commit
3313343
0 Parent(s):

initial commit

Browse files
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: diffusers
3
+ license: apache-2.0
4
+ language:
5
+ - ja
6
+ pipeline_tag: text-to-image
7
+ tags:
8
+ - stable-diffusion
9
+ ---
10
+ # 🐟 Evo-Ukiyoe-v1
11
+
12
+ 🤗 [Models](https://huggingface.co/SakanaAI/Evo-Ukiyoe-v1/) | 📝 [Blog](https://sakana.ai/evo-ukiyoe/) | 🐦 [Twitter](https://twitter.com/SakanaAILabs)
13
+
14
+
15
+ **Evo-Ukiyoe-v1** is an experimental education-purpose Japanese woodblock print Ukiyoe style image generation model. The model was train based on Sakana AI's [Evo-SDXL-JP](https://huggingface.co/SakanaAI/EvoSDXL-JP-v1).
16
+ All the dataset used to train Evo-Ukiyoe comes from Ukiyoe images belonged to [Ritsumeikan University, Art Research Center](https://www.arc.ritsumei.ac.jp/).
17
+
18
+ Please refer to our [blog](https://sakana.ai/evo-ukiyoe/) for more details.
19
+
20
+ ## Usage
21
+
22
+ Use the code below to get started with the model.
23
+
24
+
25
+ <details>
26
+ <summary> Click to expand </summary>
27
+
28
+ 1. Git clone this model card
29
+ ```
30
+ git clone https://huggingface.co/SakanaAI/Evo-Ukiyoe-v1
31
+ ```
32
+ 2. Install git-lfs if you don't have it yet.
33
+ ```
34
+ sudo apt install git-lfs
35
+ git lfs install
36
+ ```
37
+ 3. Create conda env
38
+ ```
39
+ conda create -n evo-ukiyoe python=3.11
40
+ conda activate evo-ukiyoe
41
+ ```
42
+ 4. Install packages
43
+ ```
44
+ cd Evo-Ukiyoe-v1
45
+ pip install -r requirements.txt
46
+ ```
47
+ 5. Run
48
+ ```python
49
+ from evo_ukiyoe_v1 import load_evo_ukiyoe
50
+
51
+ prompt = "着物を着ている猫が庭でお茶を飲んでいる。"
52
+ pipe = load_evo_ukiyoe(device="cuda")
53
+ images = pipe(prompt + "輻の浮世絵。超詳細。", negative_prompt='', guidance_scale=8.0, num_inference_steps=40).images
54
+ images[0].save("image.png")
55
+ ```
56
+
57
+ </details>
58
+
59
+
60
+
61
+ ## Model Details
62
+
63
+ <!-- Provide a longer summary of what this model is. -->
64
+
65
+ - **Developed by:** [Sakana AI](https://sakana.ai/)
66
+ - **Model type:** Diffusion-based text-to-image generative model
67
+ - **Language(s):** Japanese
68
+ - **Blog:** https://sakana.ai/evo-ukiyoe/
69
+
70
+
71
+ ## License
72
+ The Python script included in this repository and Lora weight are licensed under the [Apache License 2.0](https://www.apache.org/licenses/LICENSE-2.0).
73
+ Please note that the license for the model/pipeline generated by this script is inherited from the source models.
74
+
75
+ ## Uses
76
+ This model is provided for research and development purposes only and should be considered as an experimental prototype.
77
+ It is not intended for commercial use or deployment in mission-critical environments.
78
+ Use of this model is at the user's own risk, and its performance and outcomes are not guaranteed.
79
+ Sakana AI shall not be liable for any direct, indirect, special, incidental, or consequential damages, or any loss arising from the use of this model, regardless of the results obtained.
80
+ Users must fully understand the risks associated with the use of this model and use it at their own discretion.
81
+
82
+
83
+ ## Acknowledgement
84
+
85
+ Evo-Ukiyoe was trained based on Evo-SDXL-JP. We would like to thank the developers of Evo-SDXL-JP source models for their contributions and for making their work available.
86
+ - [SDXL](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
87
+ - [Juggernaut-XL-v9](https://huggingface.co/RunDiffusion/Juggernaut-XL-v9)
88
+ - [SDXL-DPO](https://huggingface.co/mhdang/dpo-sdxl-text2image-v1)
89
+ - [JSDXL](https://huggingface.co/stabilityai/japanese-stable-diffusion-xl)
90
+
91
+ ## Citation
92
+
93
+ @misc{Evo-Ukiyoe,
94
+ url = {[https://huggingface.co/SakanaAI/Evo-Nishikie-v1](https://huggingface.co/SakanaAI/Evo-Nishikie-v1)},
95
+ title = {Evo-Ukiyoe},
96
+ author = {Clanuwat, Tarin and Shing, Makoto and Imajuku, Yuki and Kitamoto, Asanobu and Akama, Ryo}
97
+ }
evo_ukiyoe_v1.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gc
2
+ import os
3
+ from typing import Dict, List, Union
4
+
5
+ from diffusers import (
6
+ StableDiffusionXLPipeline,
7
+ UNet2DConditionModel,
8
+ )
9
+ from huggingface_hub import hf_hub_download
10
+ import safetensors
11
+ import torch
12
+ from tqdm import tqdm
13
+ from transformers import AutoTokenizer, CLIPTextModelWithProjection
14
+
15
+
16
+ # Base models
17
+ SDXL_REPO = "stabilityai/stable-diffusion-xl-base-1.0"
18
+ DPO_REPO = "mhdang/dpo-sdxl-text2image-v1"
19
+ JN_REPO = "RunDiffusion/Juggernaut-XL-v9"
20
+ JSDXL_REPO = "stabilityai/japanese-stable-diffusion-xl"
21
+
22
+ # Evo-Ukiyoe
23
+ UKIYOE_REPO = "SakanaAI/Evo-Ukiyoe-v1"
24
+
25
+
26
+ def load_state_dict(checkpoint_file: Union[str, os.PathLike], device: str = "cpu"):
27
+ file_extension = os.path.basename(checkpoint_file).split(".")[-1]
28
+ if file_extension == "safetensors":
29
+ return safetensors.torch.load_file(checkpoint_file, device=device)
30
+ else:
31
+ return torch.load(checkpoint_file, map_location=device)
32
+
33
+
34
+ def load_from_pretrained(
35
+ repo_id,
36
+ filename="diffusion_pytorch_model.fp16.safetensors",
37
+ subfolder="unet",
38
+ device="cuda",
39
+ ) -> Dict[str, torch.Tensor]:
40
+ return load_state_dict(
41
+ hf_hub_download(
42
+ repo_id=repo_id,
43
+ filename=filename,
44
+ subfolder=subfolder,
45
+ ),
46
+ device=device,
47
+ )
48
+
49
+
50
+ def reshape_weight_task_tensors(task_tensors, weights):
51
+ """
52
+ Reshapes `weights` to match the shape of `task_tensors` by unsqeezing in the remaining dimenions.
53
+
54
+ Args:
55
+ task_tensors (`torch.Tensor`): The tensors that will be used to reshape `weights`.
56
+ weights (`torch.Tensor`): The tensor to be reshaped.
57
+
58
+ Returns:
59
+ `torch.Tensor`: The reshaped tensor.
60
+ """
61
+ new_shape = weights.shape + (1,) * (task_tensors.dim() - weights.dim())
62
+ weights = weights.view(new_shape)
63
+ return weights
64
+
65
+
66
+ def linear(task_tensors: List[torch.Tensor], weights: torch.Tensor) -> torch.Tensor:
67
+ """
68
+ Merge the task tensors using `linear`.
69
+
70
+ Args:
71
+ task_tensors(`List[torch.Tensor]`):The task tensors to merge.
72
+ weights (`torch.Tensor`):The weights of the task tensors.
73
+
74
+ Returns:
75
+ `torch.Tensor`: The merged tensor.
76
+ """
77
+ task_tensors = torch.stack(task_tensors, dim=0)
78
+ # weighted task tensors
79
+ weights = reshape_weight_task_tensors(task_tensors, weights)
80
+ weighted_task_tensors = task_tensors * weights
81
+ mixed_task_tensors = weighted_task_tensors.sum(dim=0)
82
+ return mixed_task_tensors
83
+
84
+
85
+ def merge_models(task_tensors, weights):
86
+ keys = list(task_tensors[0].keys())
87
+ weights = torch.tensor(weights, device=task_tensors[0][keys[0]].device)
88
+ state_dict = {}
89
+ for key in tqdm(keys, desc="Merging"):
90
+ w_list = []
91
+ for i, sd in enumerate(task_tensors):
92
+ w = sd.pop(key)
93
+ w_list.append(w)
94
+ new_w = linear(task_tensors=w_list, weights=weights)
95
+ state_dict[key] = new_w
96
+ return state_dict
97
+
98
+
99
+ def split_conv_attn(weights):
100
+ attn_tensors = {}
101
+ conv_tensors = {}
102
+ for key in list(weights.keys()):
103
+ if any(k in key for k in ["to_k", "to_q", "to_v", "to_out.0"]):
104
+ attn_tensors[key] = weights.pop(key)
105
+ else:
106
+ conv_tensors[key] = weights.pop(key)
107
+ return {"conv": conv_tensors, "attn": attn_tensors}
108
+
109
+
110
+ def load_evo_ukiyoe(device="cuda") -> StableDiffusionXLPipeline:
111
+ # Load base models
112
+ sdxl_weights = split_conv_attn(load_from_pretrained(SDXL_REPO, device=device))
113
+ dpo_weights = split_conv_attn(
114
+ load_from_pretrained(
115
+ DPO_REPO, "diffusion_pytorch_model.safetensors", device=device
116
+ )
117
+ )
118
+ jn_weights = split_conv_attn(load_from_pretrained(JN_REPO, device=device))
119
+ jsdxl_weights = split_conv_attn(load_from_pretrained(JSDXL_REPO, device=device))
120
+
121
+ # Merge base models
122
+ tensors = [sdxl_weights, dpo_weights, jn_weights, jsdxl_weights]
123
+ new_conv = merge_models(
124
+ [sd["conv"] for sd in tensors],
125
+ [
126
+ 0.15928833971605916,
127
+ 0.1032449268871776,
128
+ 0.6503217149752791,
129
+ 0.08714501842148402,
130
+ ],
131
+ )
132
+ new_attn = merge_models(
133
+ [sd["attn"] for sd in tensors],
134
+ [
135
+ 0.1877279276437178,
136
+ 0.20014114603909822,
137
+ 0.3922685507065275,
138
+ 0.2198623756106564,
139
+ ],
140
+ )
141
+
142
+ # Delete no longer needed variables to free
143
+ del sdxl_weights, dpo_weights, jn_weights, jsdxl_weights
144
+ gc.collect()
145
+ if "cuda" in device:
146
+ torch.cuda.empty_cache()
147
+
148
+ # Instantiate UNet
149
+ unet_config = UNet2DConditionModel.load_config(SDXL_REPO, subfolder="unet")
150
+ unet = UNet2DConditionModel.from_config(unet_config).to(device=device)
151
+ unet.load_state_dict({**new_conv, **new_attn})
152
+
153
+ # Load other modules
154
+ text_encoder = CLIPTextModelWithProjection.from_pretrained(
155
+ JSDXL_REPO, subfolder="text_encoder", torch_dtype=torch.float16, variant="fp16",
156
+ )
157
+ tokenizer = AutoTokenizer.from_pretrained(
158
+ JSDXL_REPO, subfolder="tokenizer", use_fast=False,
159
+ )
160
+
161
+ # Load pipeline
162
+ pipe = StableDiffusionXLPipeline.from_pretrained(
163
+ SDXL_REPO,
164
+ unet=unet,
165
+ text_encoder=text_encoder,
166
+ tokenizer=tokenizer,
167
+ torch_dtype=torch.float16,
168
+ variant="fp16",
169
+ )
170
+
171
+ # Load Evo-Ukiyoe weights
172
+ pipe.load_lora_weights(UKIYOE_REPO)
173
+ pipe.fuse_lora(lora_scale=1.0)
174
+
175
+ pipe = pipe.to(device=torch.device(device), dtype=torch.float16)
176
+ return pipe
pytorch_lora_weights.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:02ea124249f2bc80db556f9b81d3c98ec3a256b00885b17b5450c0d7a7d0e9c0
3
+ size 59519264
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+
4
+ accelerate==0.32.0
5
+ diffusers==0.29.2
6
+ sentencepiece==0.2.0
7
+ transformers==4.42.3
8
+ peft==0.11.1