# File size: 1,832 Bytes
# 3de264f
import matplotlib.pyplot as plt
from PIL import Image
# Global pyplot defaults applied to every figure this script creates.
plt.rcParams.update({
    "figure.figsize": (10, 5),
    "figure.facecolor": "white",
})
def render_figure(model_name, fn):
    """Render a grid comparing Diffusers and reference outputs for one model.

    For each image type two rows are drawn. The first row shows the control
    image followed by the Diffusers outputs for seeds 0-3. The second row shows
    the reference-implementation outputs for the same seeds; its first column
    is left blank. The figure is saved to ``fn`` and then closed.

    Args:
        model_name: Model identifier used in the input file names and the title.
        fn: Destination passed to ``Figure.savefig``.
    """
    image_types = ['bird', 'human', 'room', 'vermeer']
    num_cols = 5  # one control column + one column per seed (0-3)

    def show_png(ax, path):
        # Use a context manager so the file handle is released promptly;
        # imshow converts the PIL image to an array when it is called.
        with Image.open(path) as im:
            ax.imshow(im)

    def plot_row(axs, control_fn_prefix, output_fn_prefix, name, show_control=False):
        # Column 0 holds the control image (only if requested);
        # column i >= 1 holds the output generated with seed i - 1.
        for i, ax in enumerate(axs):
            if i == 0:
                if show_control:
                    ax.set_title('Control')
                    show_png(ax, f'{control_fn_prefix}.png')
            else:
                seed = i - 1
                ax.set_title(f'Seed={seed} ({name})')
                show_png(ax, f'{output_fn_prefix}_{seed}.png')

    fig, axs = plt.subplots(
        2 * len(image_types), num_cols, layout="constrained",
        figsize=(10, 5 * len(image_types)))
    try:
        for ax in axs.flat:
            ax.set_aspect('equal', 'box')
            ax.axis('off')
        # Consecutive rows form a (Diffusers, reference) pair per image type.
        for image_type, diffusers_row, ref_row in zip(image_types, axs[::2], axs[1::2]):
            control_prefix = f'./control_images/converted/control_{image_type}_{model_name}'
            plot_row(diffusers_row,
                     control_prefix,
                     f'./output_images/diffusers/output_{image_type}_{model_name}',
                     'Diffusers', show_control=True)
            plot_row(ref_row,
                     control_prefix,
                     f'./output_images/ref/output_{image_type}_{model_name}',
                     'ref impl.')
        fig.suptitle(f'Model: {model_name}', fontsize=16)
        fig.savefig(fn, dpi=144)
    finally:
        # Drop the figure from pyplot's registry. Without this, every call
        # leaks a figure, and pyplot warns once more than 20 are open.
        plt.close(fig)
if __name__ == '__main__':
    import os

    model_names = ["canny", "normal", "depth",
                   "openpose", "hed", "scribble", "mlsd", "seg"]
    # savefig does not create missing parent directories, so make sure the
    # output directory exists before rendering any figure.
    os.makedirs("plots", exist_ok=True)
    for model in model_names:
        fn = f"plots/figure_{model}.png"
        render_figure(model, fn)
        print(fn)
|