Spaces:
Sleeping
Sleeping
Ruining Li
commited on
Commit
•
d9fca68
1
Parent(s):
e78b488
Adapt to HF ZeroGPU
Browse files- app.py +5 -1
- requirements.txt +1 -0
app.py
CHANGED
@@ -16,6 +16,8 @@ from torchvision import transforms
|
|
16 |
from diffusion import create_diffusion
|
17 |
from model import UNet2DDragConditionModel
|
18 |
|
|
|
|
|
19 |
|
20 |
TITLE = '''DragAPart: Learning a Part-Level Motion Prior for Articulated Objects'''
|
21 |
DESCRIPTION = """
|
@@ -93,6 +95,7 @@ def model_init():
|
|
93 |
model = model.to("cuda")
|
94 |
return model
|
95 |
|
|
|
96 |
def sam_segment(predictor, input_image, drags, foreground_points=None):
|
97 |
image = np.asarray(input_image)
|
98 |
predictor.set_image(image)
|
@@ -169,6 +172,7 @@ def preprocess_image(SAM_predictor, img, chk_group, drags):
|
|
169 |
processed_img = image_pil.resize((256, 256), Image.LANCZOS)
|
170 |
return processed_img, new_drags
|
171 |
|
|
|
172 |
def single_image_sample(
|
173 |
model,
|
174 |
diffusion,
|
@@ -399,4 +403,4 @@ with gr.Blocks(title=TITLE) as demo:
|
|
399 |
outputs=[generated_image],
|
400 |
)
|
401 |
|
402 |
-
demo.launch(
|
|
|
16 |
from diffusion import create_diffusion
|
17 |
from model import UNet2DDragConditionModel
|
18 |
|
19 |
+
import spaces
|
20 |
+
|
21 |
|
22 |
TITLE = '''DragAPart: Learning a Part-Level Motion Prior for Articulated Objects'''
|
23 |
DESCRIPTION = """
|
|
|
95 |
model = model.to("cuda")
|
96 |
return model
|
97 |
|
98 |
+
@spaces.GPU
|
99 |
def sam_segment(predictor, input_image, drags, foreground_points=None):
|
100 |
image = np.asarray(input_image)
|
101 |
predictor.set_image(image)
|
|
|
172 |
processed_img = image_pil.resize((256, 256), Image.LANCZOS)
|
173 |
return processed_img, new_drags
|
174 |
|
175 |
+
@spaces.GPU
|
176 |
def single_image_sample(
|
177 |
model,
|
178 |
diffusion,
|
|
|
403 |
outputs=[generated_image],
|
404 |
)
|
405 |
|
406 |
+
demo.launch()
|
requirements.txt
CHANGED
@@ -9,3 +9,4 @@ tqdm
|
|
9 |
transformers
|
10 |
gradio
|
11 |
accelerate
|
|
|
|
9 |
transformers
|
10 |
gradio
|
11 |
accelerate
|
12 |
+
spaces
|