"""Streamlit front end for text-to-image generation with Stable Diffusion.

Renders a prompt box plus generation controls, screens the prompt against the
``stabilityai/word-list`` dataset, and runs the
``CompVis/stable-diffusion-v1-4`` pipeline on CUDA.
"""
import os
import re
import time

import streamlit as st
import torch
from datasets import load_dataset
from diffusers import StableDiffusionPipeline
from PIL import Image
from torch import autocast

st.title("Text-to-Image generation using Stable Diffusion")
st.subheader("Text Prompt")
text_prompt = st.text_area('Enter here:', height=100)

sl1, sl2, sl3, sl4 = st.columns(4)
num_samples = sl1.slider('Number of Images', 1, 4, 1)
num_steps = sl2.slider('Diffusion steps', 10, 150, 10)
scale = sl3.slider('Configuration scale', 0, 20, 10)
seed = sl4.number_input("Enter seed", 0, 25000, 47, 1)

model_id = "CompVis/stable-diffusion-v1-4"
device = "cuda"
# Falls back to ``True`` when the env var is unset; per the Hugging Face Hub
# docs, ``True`` means "use the locally cached login token".
auth_token = os.environ.get("StableDiffusion") or True

# Streamlit re-executes this whole script on every widget interaction, so the
# model and word-list loads must be cached or they would repeat each time.
# Prefer the current API and fall back for older Streamlit releases.
_cache_resource = (
    getattr(st, "cache_resource", None)
    or getattr(st, "experimental_singleton", None)
    or st.cache(allow_output_mutation=True)
)


@_cache_resource
def load_pipeline(model_name, target_device, token):
    """Load the fp16 Stable Diffusion pipeline once and move it to *target_device*."""
    pipeline = StableDiffusionPipeline.from_pretrained(
        model_name,
        use_auth_token=token,
        revision="fp16",
        torch_dtype=torch.float16,
    )
    return pipeline.to(target_device)


@_cache_resource
def load_word_list(token):
    """Return the ``text`` column of the ``stabilityai/word-list`` dataset."""
    dataset = load_dataset(
        "stabilityai/word-list", data_files="list.txt", use_auth_token=token
    )
    return dataset["train"]["text"]


def _build_filter_pattern(words):
    """Compile one case-insensitive whole-word regex matching any of *words*.

    Each word is stripped and regex-escaped so metacharacters are matched
    literally. Blank entries are dropped: an empty word would turn into a
    pattern that matches at any word boundary, i.e. almost every prompt.
    Returns ``None`` when no usable words remain.
    """
    cleaned = [re.escape(word.strip()) for word in words if word and word.strip()]
    if not cleaned:
        return None
    alternation = "|".join(cleaned)
    return re.compile(rf"\b(?:{alternation})\b", re.IGNORECASE)


pipe = load_pipeline(model_id, device, auth_token)
word_list = load_word_list(auth_token)
_unsafe_pattern = _build_filter_pattern(word_list)


class UnsafePromptError(Exception):
    """Raised by ``infer`` when the prompt contains a blocked word."""


def _output_field(output, *names):
    """Return the first non-None field among *names* from a pipeline result.

    Newer diffusers releases return an object exposing ``images`` and
    ``nsfw_content_detected`` attributes; early releases returned a plain
    dict keyed by ``"sample"``. Both shapes are accepted.
    """
    for name in names:
        value = getattr(output, name, None)
        if value is None and isinstance(output, dict):
            value = output.get(name)
        if value is not None:
            return value
    return None


def _load_safe_image():
    """Load the replacement image shown in place of flagged outputs."""
    with Image.open(r"unsafe.png") as img:
        # Copy into memory so the file handle is closed right away.
        return img.copy()


def infer(prompt, samples, steps, scale, seed):
    """Generate ``samples`` images for ``prompt`` with the cached pipeline.

    Args:
        prompt: Text prompt; checked against the blocked-word list first.
        samples: Number of images to generate in one batch.
        steps: Number of diffusion (inference) steps.
        scale: Classifier-free guidance scale passed to the pipeline.
        seed: Seed for the torch generator, for reproducible output.

    Returns:
        A list of images. Any image the pipeline flags as NSFW is replaced
        by the image loaded from ``unsafe.png``.

    Raises:
        UnsafePromptError: If the prompt contains a blocked word. It is an
            ``Exception`` subclass, so existing broad handlers still work.
    """
    if _unsafe_pattern is not None and _unsafe_pattern.search(prompt):
        raise UnsafePromptError(
            "Unsafe content found. Please try again with different prompts."
        )

    generator = torch.Generator(device=device).manual_seed(int(seed))
    with autocast(device):
        output = pipe(
            [prompt] * samples,
            num_inference_steps=steps,
            guidance_scale=scale,
            generator=generator,
        )

    images = _output_field(output, "images", "sample")
    # With the safety checker disabled the pipeline returns no flags at all.
    nsfw_flags = _output_field(output, "nsfw_content_detected") or [False] * len(images)

    safe_image = None
    results = []
    for image, flagged in zip(images, nsfw_flags):
        if flagged:
            if safe_image is None:
                # Only touch unsafe.png when something is actually flagged.
                safe_image = _load_safe_image()
            results.append(safe_image)
        else:
            results.append(image)
    return results


def check_and_infer():
    """Validate the current prompt, run generation, and render the results.

    Reads the widget values defined at module level in the current run.
    Short or blocked prompts produce an on-page message instead of a
    traceback.
    """
    prompt = text_prompt.strip()
    if len(prompt) < 5:
        st.warning("Prompt too small, enter some more detail")
        return
    try:
        with st.spinner('Wait for it...'):
            generated_images = infer(prompt, num_samples, num_steps, scale, seed)
    except UnsafePromptError as err:
        st.error(str(err))
        return
    for image in generated_images:
        st.image(image, caption=prompt)
    st.success('Image generated!')
    st.balloons()


button_clicked = st.button("Generate Image")
# Handle the click inline rather than via ``on_click``. A callback runs before
# the script re-executes, so it would see the previous run's widget values,
# and reruns requested from inside a callback are ignored.
if button_clicked:
    check_and_infer()

st.markdown("""---""")

col1, col2, col3 = st.columns([1, 6, 1])
with col1:
    st.write("")
with col2:
    placeholder = st.empty()
    placeholder.image("pl2.png")
with col3:
    st.write("")

st.markdown("""---""")

st.text("Number of Images: Number of samples(Images) to generate")
st.text("Diffusion steps: How many steps to spend generating (diffusing) your image.")
st.text("Configuration scale: Scale adjusts how close the image will be to your prompt. Higher values keep your image closer to your prompt.")
st.text("Enter seed: Seed value to use for the model.")