|
import math |
|
import torch |
|
import torchvision |
|
import gradio as gr |
|
from PIL import Image |
|
from cli import iterative_refinement |
|
from viz import grid_of_images_default |
|
# Registry of pretrained models, keyed by the label shown in the UI dropdown.
# Every checkpoint is loaded once at import time and mapped onto the CPU.
# NOTE: torch.load unpickles the file contents, so these checkpoints must
# come from a trusted source.
models = {
    label: torch.load(checkpoint, map_location="cpu")
    for label, checkpoint in (
        ("ConvAE", "convae.th"),
        ("Deep ConvAE", "deep_convae.th"),
        ("Dense K-Sparse", "fc_sparse.th"),
    )
}
|
def gen(md, model_name, seed, nb_iter, nb_samples, width, height, nb_active, only_last, black_bg):
    """Generate samples with the selected model and render them as one image.

    Args:
        md: Value of the first interface input (the Markdown description);
            not used by the generation itself.
        model_name: Key of the model to use in the module-level ``models``.
        seed: Seed given to ``torch.manual_seed`` (cast to int).
        nb_iter: Number of refinement iterations (cast to int).
        nb_samples: Number of examples to generate (cast to int).
        width: Sample width passed to ``iterative_refinement`` (cast to int).
        height: Sample height passed to ``iterative_refinement`` (cast to int).
        nb_active: Value assigned to ``model.nb_active``; only applied when
            ``model_name`` is ``"Dense K-Sparse"``.
        only_last: If true, only ``samples[-1]`` is rendered; otherwise every
            entry along the first two axes of ``samples`` is rendered.
        black_bg: If false, the grid is inverted (``1 - grid``) before being
            converted to 8-bit.

    Returns:
        A ``PIL.Image.Image`` built from the uint8 grid.

    Raises:
        KeyError: If ``model_name`` is not a key of ``models``.
    """
    # UI number widgets may deliver floats; normalise once up front.
    seed = int(seed)
    nb_iter = int(nb_iter)
    nb_samples = int(nb_samples)
    width = int(width)
    height = int(height)

    torch.manual_seed(seed)
    batch_size = 64
    model = models[model_name]

    # Compare the selected *name*, not the model object: an object-vs-string
    # comparison never matches, which silently ignored the nb_active setting.
    if model_name == "Dense K-Sparse":
        model.nb_active = int(nb_active)

    samples = iterative_refinement(
        model,
        nb_iter=nb_iter,
        nb_examples=nb_samples,
        w=width,
        h=height,
        c=1,
        batch_size=batch_size,
    )

    if only_last:
        # NOTE(review): for a non-square nb_samples, side * side < nb_samples;
        # how grid_of_images_default treats the surplus samples is not
        # visible here -- confirm before changing the rounding.
        side = int(math.sqrt(nb_samples))
        grid = grid_of_images_default(samples[-1].numpy(), shape=(side, side))
    else:
        # presumably axis 0 is the iteration and axis 1 the sample index --
        # TODO confirm against iterative_refinement.
        nb_rows, nb_cols = samples.shape[0], samples.shape[1]
        flat = samples.reshape((nb_rows * nb_cols, height, width, 1))
        grid = grid_of_images_default(flat.numpy(), shape=(nb_rows, nb_cols))

    if not black_bg:
        # assumes grid values lie in [0, 1] -- TODO confirm
        grid = 1 - grid

    grid = (grid * 255).astype("uint8")
    return Image.fromarray(grid)
|
|
|
# Markdown description shown in the interface (passed as the first input).
text = """
Interface with ConvAE model (from [here](https://arxiv.org/pdf/1606.04345.pdf)) and DeepConvAE model (from [here](https://tel.archives-ouvertes.fr/tel-01838272/file/75406_CHERTI_2018_diffusion.pdf), Section 10.1 with `L=3`), Dense K-Sparse model (from [here](https://openreview.net/forum?id=r1QXQkSYg))

These models were trained on MNIST only (digits), but were found to generate new kinds of symbols, see the references for more details.

NB: `nb_active` is only used for the Dense K-Sparse, specifying nb of activations to keep in the last layer.
"""

# The input components are positional: each one feeds the parameter of
# `gen` at the same position.
iface = gr.Interface(
    fn=gen,
    inputs=[
        gr.Markdown(text),  # md
        gr.Dropdown(list(models), value="Deep ConvAE"),  # model_name
        gr.Number(value=0),  # seed
        gr.Number(value=25),  # nb_iter
        gr.Number(value=1),  # nb_samples
        gr.Number(value=28),  # width
        gr.Number(value=28),  # height
        gr.Slider(minimum=0, maximum=800, value=800, step=1),  # nb_active
        gr.Checkbox(value=False, label="Only show last iteration"),  # only_last
        gr.Checkbox(value=True, label="Black background"),  # black_bg
    ],
    outputs="image",
)

iface.launch()
|
|