Spaces:
Sleeping
Sleeping
Ruining Li
commited on
Commit
•
8cede4e
1
Parent(s):
42a369e
Updated lfs for checkpoints and changes to model
Browse files- .gitignore +1 -2
- ckpts/drag-a-part-final.pt +3 -0
- ckpts/sam_vit_h_4b8939.pth +3 -0
- ckpts/stable-diffusion-v1-5/unet/config.json +36 -0
- model.py +3 -36
.gitignore
CHANGED
@@ -1,2 +1 @@
|
|
1 |
-
__pycache__/
|
2 |
-
ckpts/
|
|
|
1 |
+
__pycache__/
|
|
ckpts/drag-a-part-final.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:162ba2040b59ed949fd9f57c861bb07eec56744d2e738e38ada8724de96d0d32
|
3 |
+
size 14265312095
|
ckpts/sam_vit_h_4b8939.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e
|
3 |
+
size 2564550879
|
ckpts/stable-diffusion-v1-5/unet/config.json
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"_class_name": "UNet2DConditionModel",
|
3 |
+
"_diffusers_version": "0.6.0",
|
4 |
+
"act_fn": "silu",
|
5 |
+
"attention_head_dim": 8,
|
6 |
+
"block_out_channels": [
|
7 |
+
320,
|
8 |
+
640,
|
9 |
+
1280,
|
10 |
+
1280
|
11 |
+
],
|
12 |
+
"center_input_sample": false,
|
13 |
+
"cross_attention_dim": 768,
|
14 |
+
"down_block_types": [
|
15 |
+
"CrossAttnDownBlock2D",
|
16 |
+
"CrossAttnDownBlock2D",
|
17 |
+
"CrossAttnDownBlock2D",
|
18 |
+
"DownBlock2D"
|
19 |
+
],
|
20 |
+
"downsample_padding": 1,
|
21 |
+
"flip_sin_to_cos": true,
|
22 |
+
"freq_shift": 0,
|
23 |
+
"in_channels": 4,
|
24 |
+
"layers_per_block": 2,
|
25 |
+
"mid_block_scale_factor": 1,
|
26 |
+
"norm_eps": 1e-05,
|
27 |
+
"norm_num_groups": 32,
|
28 |
+
"out_channels": 4,
|
29 |
+
"sample_size": 64,
|
30 |
+
"up_block_types": [
|
31 |
+
"UpBlock2D",
|
32 |
+
"CrossAttnUpBlock2D",
|
33 |
+
"CrossAttnUpBlock2D",
|
34 |
+
"CrossAttnUpBlock2D"
|
35 |
+
]
|
36 |
+
}
|
model.py
CHANGED
@@ -1255,20 +1255,6 @@ class UNet2DDragConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMi
|
|
1255 |
)
|
1256 |
elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
|
1257 |
raise NotImplementedError
|
1258 |
-
self.mid_block = UNetMidBlock2DSimpleCrossAttn(
|
1259 |
-
in_channels=block_out_channels[-1],
|
1260 |
-
temb_channels=blocks_time_embed_dim,
|
1261 |
-
resnet_eps=norm_eps,
|
1262 |
-
resnet_act_fn=act_fn,
|
1263 |
-
output_scale_factor=mid_block_scale_factor,
|
1264 |
-
cross_attention_dim=cross_attention_dim[-1],
|
1265 |
-
attention_head_dim=attention_head_dim[-1],
|
1266 |
-
resnet_groups=norm_num_groups,
|
1267 |
-
resnet_time_scale_shift=resnet_time_scale_shift,
|
1268 |
-
skip_time_act=resnet_skip_time_act,
|
1269 |
-
only_cross_attention=mid_block_only_cross_attention,
|
1270 |
-
cross_attention_norm=cross_attention_norm,
|
1271 |
-
)
|
1272 |
elif mid_block_type is None:
|
1273 |
self.mid_block = None
|
1274 |
else:
|
@@ -1512,11 +1498,6 @@ class UNet2DDragConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMi
|
|
1512 |
y1 = y1.unsqueeze(-1).unsqueeze(-1)
|
1513 |
y1 = torch.stack([torch.zeros_like(y1) - 1, torch.zeros_like(y1) - 1, y1, y1], dim=2).view(bsz, 4 * self.num_drags, 1, 1)
|
1514 |
|
1515 |
-
# assert torch.all(x_src >= 0) and torch.all(x_src <= 1)
|
1516 |
-
# assert torch.all(y_src >= 0) and torch.all(y_src <= 1)
|
1517 |
-
# assert torch.all(x_tgt >= 0) and torch.all(x_tgt <= 1)
|
1518 |
-
# assert torch.all(y_tgt >= 0) and torch.all(y_tgt <= 1)
|
1519 |
-
|
1520 |
value_image = torch.stack([x_src, y_src, x_tgt, y_tgt], dim=2).view(bsz, 4 * self.num_drags, 1, 1)
|
1521 |
value_image = value_image.expand(bsz, 4 * self.num_drags, current_resolution, current_resolution)
|
1522 |
|
@@ -1527,18 +1508,6 @@ class UNet2DDragConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMi
|
|
1527 |
|
1528 |
def forward(
|
1529 |
self,
|
1530 |
-
# sample: torch.FloatTensor,
|
1531 |
-
# timestep: Union[torch.Tensor, float, int],
|
1532 |
-
# encoder_hidden_states: torch.Tensor,
|
1533 |
-
# class_labels: Optional[torch.Tensor] = None,
|
1534 |
-
# timestep_cond: Optional[torch.Tensor] = None,
|
1535 |
-
# attention_mask: Optional[torch.Tensor] = None,
|
1536 |
-
# cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
1537 |
-
# added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
|
1538 |
-
# down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
|
1539 |
-
# mid_block_additional_residual: Optional[torch.Tensor] = None,
|
1540 |
-
# encoder_attention_mask: Optional[torch.Tensor] = None,
|
1541 |
-
# return_dict: bool = True,
|
1542 |
x: torch.FloatTensor,
|
1543 |
t: torch.Tensor,
|
1544 |
x_cond: torch.FloatTensor,
|
@@ -1546,7 +1515,6 @@ class UNet2DDragConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMi
|
|
1546 |
force_drop_ids: Optional[torch.Tensor] = None,
|
1547 |
hidden_cls: Optional[torch.Tensor] = None,
|
1548 |
drags: Optional[torch.Tensor] = None,
|
1549 |
-
save_features: bool = False,
|
1550 |
) -> torch.Tensor:
|
1551 |
r"""
|
1552 |
The [`UNet2DConditionModel`] forward method.
|
@@ -1941,11 +1909,10 @@ class UNet2DDragConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMi
|
|
1941 |
from diffusers.utils import WEIGHTS_NAME
|
1942 |
one_sided_attn = unet_additional_kwargs.pop("one_sided_attn", True) if unet_additional_kwargs is not None else True
|
1943 |
model = cls.from_config(config, **unet_additional_kwargs) if unet_additional_kwargs is not None else cls.from_config(config)
|
1944 |
-
model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
|
1945 |
-
if not os.path.isfile(model_file):
|
1946 |
-
raise RuntimeError(f"{model_file} does not exist")
|
1947 |
-
|
1948 |
if load:
|
|
|
|
|
|
|
1949 |
state_dict = torch.load(model_file, map_location="cpu")
|
1950 |
m, u = model.load_state_dict(state_dict, strict=False)
|
1951 |
|
|
|
1255 |
)
|
1256 |
elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
|
1257 |
raise NotImplementedError
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1258 |
elif mid_block_type is None:
|
1259 |
self.mid_block = None
|
1260 |
else:
|
|
|
1498 |
y1 = y1.unsqueeze(-1).unsqueeze(-1)
|
1499 |
y1 = torch.stack([torch.zeros_like(y1) - 1, torch.zeros_like(y1) - 1, y1, y1], dim=2).view(bsz, 4 * self.num_drags, 1, 1)
|
1500 |
|
|
|
|
|
|
|
|
|
|
|
1501 |
value_image = torch.stack([x_src, y_src, x_tgt, y_tgt], dim=2).view(bsz, 4 * self.num_drags, 1, 1)
|
1502 |
value_image = value_image.expand(bsz, 4 * self.num_drags, current_resolution, current_resolution)
|
1503 |
|
|
|
1508 |
|
1509 |
def forward(
|
1510 |
self,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1511 |
x: torch.FloatTensor,
|
1512 |
t: torch.Tensor,
|
1513 |
x_cond: torch.FloatTensor,
|
|
|
1515 |
force_drop_ids: Optional[torch.Tensor] = None,
|
1516 |
hidden_cls: Optional[torch.Tensor] = None,
|
1517 |
drags: Optional[torch.Tensor] = None,
|
|
|
1518 |
) -> torch.Tensor:
|
1519 |
r"""
|
1520 |
The [`UNet2DConditionModel`] forward method.
|
|
|
1909 |
from diffusers.utils import WEIGHTS_NAME
|
1910 |
one_sided_attn = unet_additional_kwargs.pop("one_sided_attn", True) if unet_additional_kwargs is not None else True
|
1911 |
model = cls.from_config(config, **unet_additional_kwargs) if unet_additional_kwargs is not None else cls.from_config(config)
|
|
|
|
|
|
|
|
|
1912 |
if load:
|
1913 |
+
model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
|
1914 |
+
if not os.path.isfile(model_file):
|
1915 |
+
raise RuntimeError(f"{model_file} does not exist")
|
1916 |
state_dict = torch.load(model_file, map_location="cpu")
|
1917 |
m, u = model.load_state_dict(state_dict, strict=False)
|
1918 |
|