File size: 2,254 Bytes
ddfc4d0
fa128ec
 
 
 
 
 
 
ddfc4d0
 
 
fa128ec
ddfc4d0
fa128ec
 
ddfc4d0
 
 
fa128ec
 
 
 
 
 
 
ddfc4d0
 
 
 
 
 
 
fa128ec
 
 
d58b310
ddfc4d0
d58b310
 
ddfc4d0
 
d58b310
fa128ec
 
d58b310
 
ddfc4d0
d58b310
fa128ec
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import math
import torch
import torchvision
import gradio as gr
from PIL import Image
from cli import iterative_refinement
from viz import grid_of_images_default
# Checkpoint file for each selectable model. The keys are the names listed in
# the model dropdown of the interface.
_CHECKPOINTS = {
    "ConvAE": "convae.th",
    "Deep ConvAE": "deep_convae.th",
    "Dense K-Sparse": "fc_sparse.th",
}
# Load every checkpoint onto the CPU once, at import time.
# NOTE(review): torch.load unpickles the file, so only ship trusted checkpoints.
models = {
    name: torch.load(path, map_location="cpu")
    for name, path in _CHECKPOINTS.items()
}
def gen(md, model_name, seed, nb_iter, nb_samples, width, height, nb_active, only_last, black_bg):
    """Sample images with the selected model and return them as one grid image.

    Args:
        md: Value of the Markdown description component. It is unused; the
            component is an input only so the description text is displayed.
        model_name: Key into ``models`` selecting the model to sample from.
        seed: Seed passed to ``torch.manual_seed`` so sampling is reproducible.
        nb_iter: Passed to ``iterative_refinement`` as ``nb_iter``.
        nb_samples: Number of examples to generate (``nb_examples``).
        width: Image width passed to ``iterative_refinement`` as ``w``.
        height: Image height passed to ``iterative_refinement`` as ``h``.
        nb_active: Number of activations to keep in the last layer. It is only
            applied to the "Dense K-Sparse" model.
        only_last: If true, show only the last iteration's samples in a
            square grid. Otherwise show every iteration of every sample.
        black_bg: If false, the grid values are inverted (``1 - grid``).

    Returns:
        A ``PIL.Image`` built from the grid scaled by 255 to ``uint8``.
    """
    torch.manual_seed(int(seed))
    # Gradio Number widgets may deliver floats; convert once.
    w, h = int(width), int(height)
    n = int(nb_samples)
    model = models[model_name]
    # Compare the *name*: `model` is the loaded object, so comparing it with
    # the string was always False and `nb_active` was silently ignored.
    if model_name == "Dense K-Sparse":
        model.nb_active = int(nb_active)
    samples = iterative_refinement(
        model,
        nb_iter=int(nb_iter),
        nb_examples=n,
        w=w, h=h, c=1,
        batch_size=64,
    )
    if only_last:
        # NOTE(review): when n is not a perfect square, only side*side cells
        # exist. Confirm how grid_of_images_default treats the extra samples.
        side = int(math.sqrt(n))
        grid = grid_of_images_default(samples[-1].numpy(), shape=(side, side))
    else:
        # Assumes samples is indexed (iteration, example, ...), as implied by
        # samples[-1] being the last iteration. TODO confirm.
        nb_steps, nb_examples = samples.shape[0], samples.shape[1]
        flat = samples.reshape((nb_steps * nb_examples, h, w, 1))
        grid = grid_of_images_default(flat.numpy(), shape=(nb_steps, nb_examples))
    if not black_bg:
        grid = 1 - grid
    # Presumably grid values lie in [0, 1]; scale them to 8-bit pixels.
    grid = (grid * 255).astype("uint8")
    return Image.fromarray(grid)

text = """
Interface with ConvAE model (from [here](https://arxiv.org/pdf/1606.04345.pdf)) and DeepConvAE model (from [here](https://tel.archives-ouvertes.fr/tel-01838272/file/75406_CHERTI_2018_diffusion.pdf), Section 10.1 with `L=3`), Dense K-Sparse model (from [here](https://openreview.net/forum?id=r1QXQkSYg))

These models were trained on MNIST only (digits), but were found to generate new kinds of symbols, see the references for more details.

NB: `nb_active` is only used for the Dense K-Sparse, specifying nb of activations to keep in the last layer.
"""
# Input widgets, listed in the same order as gen()'s positional parameters.
# Unlabelled widgets get their label from the matching parameter name.
input_components = [
    gr.Markdown(text),                                       # md
    gr.Dropdown(list(models.keys()), value="Deep ConvAE"),   # model_name
    gr.Number(value=0),                                      # seed
    gr.Number(value=25),                                     # nb_iter
    gr.Number(value=1),                                      # nb_samples
    gr.Number(value=28),                                     # width
    gr.Number(value=28),                                     # height
    gr.Slider(minimum=0, maximum=800, value=800, step=1),    # nb_active
    gr.Checkbox(value=False, label="Only show last iteration"),  # only_last
    gr.Checkbox(value=True, label="Black background"),       # black_bg
]
iface = gr.Interface(fn=gen, inputs=input_components, outputs="image")
iface.launch()