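"""Streamlit demo app: text-to-image generation with Stable Diffusion.

Loads the CompVis/stable-diffusion-v1-4 pipeline from diffusers, filters unsafe
prompts against a word list, and displays the generated images on the page.
"""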
import streamlit as st
import time
import torch
import os
from torch import autocast
from diffusers import StableDiffusionPipeline
from datasets import load_dataset
from PIL import Image
import re

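# Page title and text prompt input.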
st.title("Text-to-Image generation using Stable Diffusion")
st.subheader("Text Prompt")
text_prompt = st.text_area('Enter here:', height=100)

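# Generation controls: number of images, diffusion steps, guidance ("configuration") scale, and seed.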
sl1, sl2, sl3, sl4 = st.columns(4)

num_samples = sl1.slider('Number of Images', 1, 4, 1)
num_steps = sl2.slider('Diffusion steps', 10, 150, 10)
scale = sl3.slider('Configuration scale', 0, 20, 10)
seed = sl4.number_input("Enter seed", 0, 25000, 47, 1)


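# Model to load and the device to run it on.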
model_id = "CompVis/stable-diffusion-v1-4"
device = "cuda"  # the fp16 weights used below require a CUDA GPU

# Hugging Face token: read from the "StableDiffusion" env var, or fall back to the
# credentials cached by `huggingface-cli login` when the variable is unset.
auth_token = os.environ.get("StableDiffusion") or True

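# Load the fp16 Stable Diffusion pipeline and move it to the GPU. Note that this runs on
# every Streamlit rerun; wrapping the load in a cached helper (e.g. st.cache_resource in
# newer Streamlit releases) would avoid reloading the weights on each interaction.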
pipe = StableDiffusionPipeline.from_pretrained(
    model_id, use_auth_token=auth_token, revision="fp16", torch_dtype=torch.float16)
pipe = pipe.to(device)
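# Download the word list used to filter unsafe prompts.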
word_list_dataset = load_dataset(
    "stabilityai/word-list", data_files="list.txt", use_auth_token=auth_token)
word_list = word_list_dataset["train"]["text"]


def infer(prompt, samples, steps, scale, seed):
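    """Generate `samples` images for `prompt` and return them as a list of PIL images.

    Prompts containing a word from the filter list are rejected, and any output the
    safety checker flags as NSFW is replaced with a placeholder image.
    """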
    # Basic safety filter: reject prompts that contain any word from the word list.
    for word in word_list:
        if re.search(rf"\b{word}\b", prompt):
            raise Exception(
                "Unsafe content found. Please try again with different prompts.")

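    # Seed the generator for reproducible results, then run the pipeline under autocast.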
    generator = torch.Generator(device=device).manual_seed(seed)
    with autocast("cuda"):
        images_list = pipe(
            [prompt] * samples,
            num_inference_steps=steps,
            guidance_scale=scale,
            generator=generator,
        )
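    # The dict-style output with "sample" / "nsfw_content_detected" keys matches older
    # diffusers releases; any image flagged by the safety checker is replaced with a placeholder.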
    images = []
    safe_image = Image.open(r"unsafe.png")
    for i, image in enumerate(images_list["sample"]):
        if (images_list["nsfw_content_detected"][i]):
            images.append(safe_image)
        else:
            images.append(image)
    return images


def check_and_infer():
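    """Button callback: validate the prompt, then generate and display the images."""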

    if len(text_prompt) < 5:
        # The script reruns automatically after the callback, so a warning is enough here.
        st.warning("Prompt too short, please enter a more detailed prompt.")
    else:
        with st.spinner('Wait for it...'):
            generated_images = infer(
                text_prompt, num_samples, num_steps, scale, seed)
            for image in generated_images:
                st.image(image, caption=text_prompt)
        st.success('Image generated!')
        st.balloons()


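# "Generate Image" button; image generation runs in the on_click callback above.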
button_clicked = st.button(
    "Generate Image", on_click=check_and_infer, disabled=False)

st.markdown("""---""")

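# Layout: a placeholder image centered in the wide middle column.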
col1, col2, col3 = st.columns([1, 6, 1])

with col1:
    col1.write("")

with col2:
    placeholder = col2.empty()

    placeholder.image("pl2.png")

with col3:
    col3.write("")



st.markdown("""---""")

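# Help text for each control.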
st.text("Number of Images: how many images (samples) to generate.")
st.text("Diffusion steps: how many denoising (diffusion) steps to spend generating each image.")
st.text("Configuration scale: guidance scale; higher values keep the image closer to your prompt.")
st.text("Enter seed: random seed, so the same settings reproduce the same images.")