Spaces • Runtime error
Ahsen Khaliq committed • Commit 9988281 • 1 Parent(s): b16b0a3
Update app.py

app.py CHANGED
@@ -11,18 +11,17 @@ from torchvision import transforms
 import torchtext
 
 
-# Images
-torch.hub.download_url_to_file('https://cdn.pixabay.com/photo/2021/08/04/14/16/tower-6521842_1280.jpg', 'tower.jpg')
-torch.hub.download_url_to_file('https://cdn.pixabay.com/photo/2017/08/31/05/36/buildings-2699520_1280.jpg', 'city.jpg')
 
 idx = 0
 
 torchtext.utils.download_from_url("https://drive.google.com/uc?id=1NDD54BLligyr8tzo8QGI5eihZisXK1nq", root=".")
 
 
-def save_img(img, output_path):
+def to_PIL_img(img):
     result = Image.fromarray((img.data.cpu().numpy().transpose((1, 2, 0)) * 255).astype(np.uint8))
-    result.save(output_path)
+    return result
+def save_img(img, output_path):
+    to_PIL_img(img).save(output_path)
 
 
 def param2stroke(param, H, W, meta_brushes):
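This hunk drops the sample-image downloads and splits the old `save_img` into two helpers: `to_PIL_img` converts a rendered canvas to a PIL image, and `save_img` reuses it for writing to disk. The split lets later code keep frames in memory for the animation instead of only writing JPEGs. A minimal, self-contained sketch of the conversion the helper assumes (a CHW float tensor with values in [0, 1]; the random tensor here is purely illustrative):

import numpy as np
import torch
from PIL import Image

img = torch.rand(3, 64, 64)  # stand-in for a rendered canvas, CHW, values in [0, 1]

# Same steps as to_PIL_img: CHW -> HWC, scale to [0, 255], cast to uint8.
array = (img.data.cpu().numpy().transpose((1, 2, 0)) * 255).astype(np.uint8)
Image.fromarray(array).save('frame.jpg')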
@@ -87,7 +86,7 @@ def param2stroke(param, H, W, meta_brushes):
 
 
 def param2img_serial(
-        param, decision, meta_brushes, cur_canvas, frame_dir, has_border=False, original_h=None, original_w=None):
+        param, decision, meta_brushes, cur_canvas, frame_dir, has_border=False, original_h=None, original_w=None, *, all_frames):
     """
     Input stroke parameters and decisions for each patch, meta brushes, current canvas, frame directory,
     and whether there is a border (if intermediate painting results are required).
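The bare `*` added to the signature makes `all_frames` keyword-only and required, so any caller not updated to pass the frame list by name fails with an explicit TypeError instead of silently binding it to a positional parameter. The same mechanism in isolation (names here are hypothetical):

def render(canvas, frame_dir=None, *, all_frames):
    all_frames.append(canvas)

frames = []
render('c0', all_frames=frames)  # OK: passed by keyword
# render('c0', frames)  # TypeError: missing 1 required keyword-only argument: 'all_frames'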
@@ -146,7 +145,8 @@ def param2img_serial(
             device=this_canvas.device)
         selected_alphas = torch.zeros(selected_param.shape[0], 3, patch_size_y, patch_size_x, device=this_canvas.device)
         if selected_param[selected_decision, :].shape[0] > 0:
-            selected_foregrounds[selected_decision, :, :, :], selected_alphas[selected_decision, :, :, :] = param2stroke(selected_param[selected_decision, :], patch_size_y, patch_size_x, meta_brushes)
+            selected_foregrounds[selected_decision, :, :, :], selected_alphas[selected_decision, :, :, :] = \
+                param2stroke(selected_param[selected_decision, :], patch_size_y, patch_size_x, meta_brushes)
         selected_foregrounds = selected_foregrounds.view(
             b, selected_h, selected_w, 3, patch_size_y, patch_size_x).contiguous()
         selected_alphas = selected_alphas.view(b, selected_h, selected_w, 3, patch_size_y, patch_size_x).contiguous()
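This change only wraps the line; the trailing backslash is a line continuation. The statement writes `param2stroke`'s two outputs into just the rows picked out by the boolean `selected_decision` mask. The masked-write pattern on its own (shapes illustrative, not the model's):

import torch

out = torch.zeros(5, 3)                          # preallocated buffer
mask = torch.tensor([True, False, True, False, True])
values = torch.ones(int(mask.sum()), 3)          # one row per selected entry

out[mask, :] = values                            # writes only the masked rows; others stay zero
assert out.sum().item() == 9.0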
@@ -164,6 +164,11 @@ def param2img_serial(
         factor = 2
     else:
         factor = 4
+
+    def store_frame(img):
+        all_frames.append(to_PIL_img(img))
+
+
     if even_idx_y.shape[0] > 0 and even_idx_x.shape[0] > 0:
         for i in range(s):
             canvas = partial_render(cur_canvas, even_y_even_x_coord_y, even_y_even_x_coord_x, i)
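`store_frame` is defined inside `param2img_serial`, so it closes over the caller's `all_frames` list; each of the four render passes below can then record a frame with a single call. The same closure pattern in isolation (all names hypothetical):

def make_recorder(all_frames):
    def store_frame(img):
        all_frames.append(img)  # appends to the list captured from the enclosing scope
    return store_frame

frames = []
record = make_recorder(frames)
record('frame-0')
record('frame-1')
assert len(frames) == 2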
@@ -177,6 +182,7 @@ def param2img_serial(
                 frame = crop(cur_canvas[:, :, patch_size_y // factor:-patch_size_y // factor,
                              patch_size_x // factor:-patch_size_x // factor], original_h, original_w)
                 save_img(frame[0], os.path.join(frame_dir, '%03d.jpg' % idx))
+                store_frame(frame[0])
 
     if odd_idx_y.shape[0] > 0 and odd_idx_x.shape[0] > 0:
         for i in range(s):
@@ -193,6 +199,7 @@ def param2img_serial(
                 frame = crop(cur_canvas[:, :, patch_size_y // factor:-patch_size_y // factor,
                              patch_size_x // factor:-patch_size_x // factor], original_h, original_w)
                 save_img(frame[0], os.path.join(frame_dir, '%03d.jpg' % idx))
+                store_frame(frame[0])
 
     if odd_idx_y.shape[0] > 0 and even_idx_x.shape[0] > 0:
         for i in range(s):
@@ -208,6 +215,7 @@ def param2img_serial(
                 frame = crop(cur_canvas[:, :, patch_size_y // factor:-patch_size_y // factor,
                              patch_size_x // factor:-patch_size_x // factor], original_h, original_w)
                 save_img(frame[0], os.path.join(frame_dir, '%03d.jpg' % idx))
+                store_frame(frame[0])
 
     if even_idx_y.shape[0] > 0 and odd_idx_x.shape[0] > 0:
         for i in range(s):
@@ -223,6 +231,7 @@ def param2img_serial(
                 frame = crop(cur_canvas[:, :, patch_size_y // factor:-patch_size_y // factor,
                              patch_size_x // factor:-patch_size_x // factor], original_h, original_w)
                 save_img(frame[0], os.path.join(frame_dir, '%03d.jpg' % idx))
+                store_frame(frame[0])
 
     cur_canvas = cur_canvas[:, :, patch_size_y // 4:-patch_size_y // 4, patch_size_x // 4:-patch_size_x // 4]
 
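The same `store_frame(frame[0])` tail is appended after each of the four parity passes (even/even, odd/odd, odd/even, even/odd), so the animation captures every intermediate canvas. A toy check of the schedule those four `if` blocks implement: partitioned by (row, column) parity, the passes together visit each patch exactly once:

# Each pass renders only patches whose (row, col) parities match.
passes = [(0, 0), (1, 1), (1, 0), (0, 1)]  # same order as the code above
grid = [[0] * 4 for _ in range(4)]
for py, px in passes:
    for y in range(py, 4, 2):
        for x in range(px, 4, 2):
            grid[y][x] += 1
assert all(cell == 1 for row in grid for cell in row)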
@@ -383,10 +392,10 @@ def crop(img, h, w):
 
 
 def main(input_path, model_path, output_dir, need_animation=False, resize_h=None, resize_w=None, serial=False):
-
-
+    if not os.path.exists(output_dir):
+        os.mkdir(output_dir)
     input_name = os.path.basename(input_path)
-
+    output_path = os.path.join(output_dir, input_name)
     frame_dir = None
     if need_animation:
         if not serial:
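This hunk restores the output-directory setup so the later `save_img(final_result[0], output_path)` call has somewhere to write. Note that `os.mkdir` creates only the last path component and would raise if parents are missing; an equivalent, slightly more robust one-liner (an alternative, not what the commit uses) is:

import os

os.makedirs('output/', exist_ok=True)  # creates parent dirs too; no error if it already exists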
@@ -416,6 +425,7 @@ def main(input_path, model_path, output_dir, need_animation=False, resize_h=None
     original_img_pad_size = patch_size * (2 ** K)
     original_img_pad = pad(original_img, original_img_pad_size, original_img_pad_size)
     final_result = torch.zeros_like(original_img_pad).to(device)
+    all_frames = []
     for layer in range(0, K + 1):
         layer_size = patch_size * (2 ** layer)
         img = F.interpolate(original_img_pad, (layer_size, layer_size))
@@ -449,7 +459,7 @@ def main(input_path, model_path, output_dir, need_animation=False, resize_h=None
         param[..., 2:4] = param[..., 2:4] / 2
         if serial:
             final_result = param2img_serial(param, decision, meta_brushes, final_result,
-                                            frame_dir, False, original_h, original_w)
+                                            frame_dir, False, original_h, original_w, all_frames = all_frames)
         else:
             final_result = param2img_parallel(param, decision, meta_brushes, final_result)
 
@@ -486,39 +496,39 @@ def main(input_path, model_path, output_dir, need_animation=False, resize_h=None
         param[..., 2:4] = param[..., 2:4] / 2
         if serial:
             final_result = param2img_serial(param, decision, meta_brushes, final_result,
-                                            frame_dir, True, original_h, original_w)
+                                            frame_dir, True, original_h, original_w, all_frames = all_frames)
         else:
             final_result = param2img_parallel(param, decision, meta_brushes, final_result)
     final_result = final_result[:, :, border_size:-border_size, border_size:-border_size]
 
     final_result = crop(final_result, original_h, original_w)
-
+    save_img(final_result[0], output_path)
     tensor_to_pil = transforms.ToPILImage()(final_result[0].squeeze_(0))
-    return tensor_to_pil
+    #return tensor_to_pil
+
+    all_frames[0].save(os.path.join(frame_dir, 'animation.gif'),
+                       save_all=True, append_images=all_frames[1:], optimize=False, duration=40, loop=0)
+    return os.path.join(frame_dir, "animation.gif")
 
 
 def gradio_inference(image):
     return main(input_path=image.name,
                 model_path='model.pth',
                 output_dir='output/',
-                need_animation=…
+                need_animation=True, # whether need intermediate results for animation.
                 resize_h=512, # resize original input to this size. None means do not resize.
                 resize_w=512, # resize original input to this size. None means do not resize.
-                serial=…
+                serial=True) # if need animation, serial must be True.
 
-title = "…
-description = "…
-article = "<p style='text-align: center'><a href='https://arxiv.org/abs/…
+title = "Anime2Sketch"
+description = "demo for Anime2Sketch. To use it, simply upload your image, or click one of the examples to load them. Read more at the links below."
+article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2104.05703'>Adversarial Open Domain Adaption for Sketch-to-Photo Synthesis</a> | <a href='https://github.com/Mukosame/Anime2Sketch'>Github Repo</a></p>"
 
 gr.Interface(
     gradio_inference,
     [gr.inputs.Image(type="file", label="Input")],
-    gr.outputs.Image(type="…
+    gr.outputs.Image(type="file", label="Output"),
     title=title,
     description=description,
-    article=article,
-    examples=[
-        ['city.jpg'],
-        ['tower.jpg']
-    ]
+    article=article
 ).launch(debug=True)
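The animation itself is assembled with Pillow's multi-frame save: calling `save` on the first frame with `save_all=True` writes the remaining `append_images` into one GIF, where `duration` is the per-frame delay in milliseconds and `loop=0` means repeat forever. A self-contained sketch with synthetic frames:

from PIL import Image

# Hypothetical frames: a small square fading from black to white.
frames = [Image.new('RGB', (64, 64), (v, v, v)) for v in range(0, 256, 32)]
frames[0].save('animation.gif', save_all=True, append_images=frames[1:],
               optimize=False, duration=40, loop=0)

Since `main` now returns the GIF's path and the output is declared `gr.outputs.Image(type="file", label="Output")`, Gradio serves the returned file directly, letting the browser play the animation rather than show a single flattened frame.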