radames commited on
Commit
cb92d2b
1 Parent(s): ee4d659
.gitignore CHANGED
@@ -1,2 +1,4 @@
1
  __pycache__/
2
- venv/
 
 
 
1
  __pycache__/
2
+ venv/
3
+ public/
4
+ *.pem
app.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+
3
+ from config import args
4
+ from device import device, torch_dtype
5
+ from app_init import init_app
6
+ from user_queue import user_queue_map
7
+ from util import get_pipeline_class
8
+
9
+
10
+ app = FastAPI()
11
+
12
+ pipeline_class = get_pipeline_class(args.pipeline)
13
+ pipeline = pipeline_class(args, device, torch_dtype)
14
+ init_app(app, user_queue_map, args, pipeline)
app_init.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, WebSocket, HTTPException, WebSocketDisconnect
2
+ from fastapi.responses import StreamingResponse, JSONResponse
3
+ from fastapi.middleware.cors import CORSMiddleware
4
+ from fastapi.staticfiles import StaticFiles
5
+
6
+ import logging
7
+ import traceback
8
+ from config import Args
9
+ from user_queue import UserQueueDict
10
+ import uuid
11
+ import asyncio
12
+ import time
13
+ from PIL import Image
14
+ import io
15
+
16
+
17
+ def init_app(app: FastAPI, user_queue_map: UserQueueDict, args: Args, pipeline):
18
+ app.add_middleware(
19
+ CORSMiddleware,
20
+ allow_origins=["*"],
21
+ allow_credentials=True,
22
+ allow_methods=["*"],
23
+ allow_headers=["*"],
24
+ )
25
+ print("Init app", app)
26
+
27
+ @app.websocket("/ws")
28
+ async def websocket_endpoint(websocket: WebSocket):
29
+ await websocket.accept()
30
+ if args.max_queue_size > 0 and len(user_queue_map) >= args.max_queue_size:
31
+ print("Server is full")
32
+ await websocket.send_json({"status": "error", "message": "Server is full"})
33
+ await websocket.close()
34
+ return
35
+
36
+ try:
37
+ uid = uuid.uuid4()
38
+ print(f"New user connected: {uid}")
39
+ await websocket.send_json(
40
+ {"status": "success", "message": "Connected", "userId": uid}
41
+ )
42
+ user_queue_map[uid] = {"queue": asyncio.Queue()}
43
+ await websocket.send_json(
44
+ {"status": "start", "message": "Start Streaming", "userId": uid}
45
+ )
46
+ await handle_websocket_data(websocket, uid)
47
+ except WebSocketDisconnect as e:
48
+ logging.error(f"WebSocket Error: {e}, {uid}")
49
+ traceback.print_exc()
50
+ finally:
51
+ print(f"User disconnected: {uid}")
52
+ queue_value = user_queue_map.pop(uid, None)
53
+ queue = queue_value.get("queue", None)
54
+ if queue:
55
+ while not queue.empty():
56
+ try:
57
+ queue.get_nowait()
58
+ except asyncio.QueueEmpty:
59
+ continue
60
+
61
+ @app.get("/queue_size")
62
+ async def get_queue_size():
63
+ queue_size = len(user_queue_map)
64
+ return JSONResponse({"queue_size": queue_size})
65
+
66
+ @app.get("/stream/{user_id}")
67
+ async def stream(user_id: uuid.UUID):
68
+ uid = user_id
69
+ try:
70
+ user_queue = user_queue_map[uid]
71
+ queue = user_queue["queue"]
72
+
73
+ async def generate():
74
+ last_prompt: str = None
75
+ while True:
76
+ data = await queue.get()
77
+ input_image = data["image"]
78
+ params = data["params"]
79
+ if input_image is None:
80
+ continue
81
+
82
+ image = pipeline.predict(
83
+ input_image,
84
+ params,
85
+ )
86
+ if image is None:
87
+ continue
88
+ frame_data = io.BytesIO()
89
+ image.save(frame_data, format="JPEG")
90
+ frame_data = frame_data.getvalue()
91
+ if frame_data is not None and len(frame_data) > 0:
92
+ yield b"--frame\r\nContent-Type: image/jpeg\r\n\r\n" + frame_data + b"\r\n"
93
+
94
+ await asyncio.sleep(1.0 / 120.0)
95
+
96
+ return StreamingResponse(
97
+ generate(), media_type="multipart/x-mixed-replace;boundary=frame"
98
+ )
99
+ except Exception as e:
100
+ logging.error(f"Streaming Error: {e}, {user_queue_map}")
101
+ traceback.print_exc()
102
+ return HTTPException(status_code=404, detail="User not found")
103
+
104
+ async def handle_websocket_data(websocket: WebSocket, user_id: uuid.UUID):
105
+ uid = user_id
106
+ user_queue = user_queue_map[uid]
107
+ queue = user_queue["queue"]
108
+ if not queue:
109
+ return HTTPException(status_code=404, detail="User not found")
110
+ last_time = time.time()
111
+ try:
112
+ while True:
113
+ data = await websocket.receive_bytes()
114
+ params = await websocket.receive_json()
115
+ params = pipeline.InputParams(**params)
116
+ pil_image = Image.open(io.BytesIO(data))
117
+
118
+ while not queue.empty():
119
+ try:
120
+ queue.get_nowait()
121
+ except asyncio.QueueEmpty:
122
+ continue
123
+ await queue.put({"image": pil_image, "params": params})
124
+ if args.timeout > 0 and time.time() - last_time > args.timeout:
125
+ await websocket.send_json(
126
+ {
127
+ "status": "timeout",
128
+ "message": "Your session has ended",
129
+ "userId": uid,
130
+ }
131
+ )
132
+ await websocket.close()
133
+ return
134
+
135
+ except Exception as e:
136
+ logging.error(f"Error: {e}")
137
+ traceback.print_exc()
138
+
139
+ # route to setup frontend
140
+ @app.get("/settings")
141
+ async def settings():
142
+ params = pipeline.InputParams()
143
+ return JSONResponse({"settings": params.dict()})
144
+
145
+ app.mount("/", StaticFiles(directory="public", html=True), name="public")
build-run.sh ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ cd frontend
3
+ npm install
4
+ npm run build
5
+ if [ $? -eq 0 ]; then
6
+ echo -e "\033[1;32m\nfrontend build success \033[0m"
7
+ else
8
+ echo -e "\033[1;31m\nfrontend build failed\n\033[0m" >&2 exit 1
9
+ fi
10
+ cd ../
11
+ python run.py --reload
12
+
config.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import NamedTuple
2
+ import argparse
3
+ import os
4
+
5
+
6
+ class Args(NamedTuple):
7
+ host: str
8
+ port: int
9
+ reload: bool
10
+ mode: str
11
+ max_queue_size: int
12
+ timeout: float
13
+ safety_checker: bool
14
+ torch_compile: bool
15
+ use_taesd: bool
16
+ pipeline: str
17
+
18
+
19
+ MAX_QUEUE_SIZE = int(os.environ.get("MAX_QUEUE_SIZE", 0))
20
+ TIMEOUT = float(os.environ.get("TIMEOUT", 0))
21
+ SAFETY_CHECKER = os.environ.get("SAFETY_CHECKER", None) == "True"
22
+ TORCH_COMPILE = os.environ.get("TORCH_COMPILE", None) == "True"
23
+ USE_TAESD = os.environ.get("USE_TAESD", None) == "True"
24
+ default_host = os.getenv("HOST", "0.0.0.0")
25
+ default_port = int(os.getenv("PORT", "7860"))
26
+ default_mode = os.getenv("MODE", "default")
27
+
28
+ parser = argparse.ArgumentParser(description="Run the app")
29
+ parser.add_argument("--host", type=str, default=default_host, help="Host address")
30
+ parser.add_argument("--port", type=int, default=default_port, help="Port number")
31
+ parser.add_argument("--reload", action="store_true", help="Reload code on change")
32
+ parser.add_argument(
33
+ "--mode", type=str, default=default_mode, help="App Inferece Mode: txt2img, img2img"
34
+ )
35
+ parser.add_argument(
36
+ "--max_queue_size", type=int, default=MAX_QUEUE_SIZE, help="Max Queue Size"
37
+ )
38
+ parser.add_argument("--timeout", type=float, default=TIMEOUT, help="Timeout")
39
+ parser.add_argument(
40
+ "--safety_checker", type=bool, default=SAFETY_CHECKER, help="Safety Checker"
41
+ )
42
+ parser.add_argument(
43
+ "--torch_compile", type=bool, default=TORCH_COMPILE, help="Torch Compile"
44
+ )
45
+ parser.add_argument(
46
+ "--use_taesd",
47
+ type=bool,
48
+ default=USE_TAESD,
49
+ help="Use Tiny Autoencoder",
50
+ )
51
+ parser.add_argument(
52
+ "--pipeline",
53
+ type=str,
54
+ default="txt2img",
55
+ help="Pipeline to use",
56
+ )
57
+
58
+ args = Args(**vars(parser.parse_args()))
device.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ # check if MPS is available OSX only M1/M2/M3 chips
4
+ mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
5
+ xpu_available = hasattr(torch, "xpu") and torch.xpu.is_available()
6
+ device = torch.device(
7
+ "cuda" if torch.cuda.is_available() else "xpu" if xpu_available else "cpu"
8
+ )
9
+ torch_dtype = torch.float16
10
+ if mps_available:
11
+ device = torch.device("mps")
12
+ torch_dtype = torch.float32
frontend/.eslintignore ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .DS_Store
2
+ node_modules
3
+ /build
4
+ /.svelte-kit
5
+ /package
6
+ .env
7
+ .env.*
8
+ !.env.example
9
+
10
+ # Ignore files for PNPM, NPM and YARN
11
+ pnpm-lock.yaml
12
+ package-lock.json
13
+ yarn.lock
frontend/.eslintrc.cjs ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ module.exports = {
2
+ root: true,
3
+ extends: [
4
+ 'eslint:recommended',
5
+ 'plugin:@typescript-eslint/recommended',
6
+ 'plugin:svelte/recommended',
7
+ 'prettier'
8
+ ],
9
+ parser: '@typescript-eslint/parser',
10
+ plugins: ['@typescript-eslint'],
11
+ parserOptions: {
12
+ sourceType: 'module',
13
+ ecmaVersion: 2020,
14
+ extraFileExtensions: ['.svelte']
15
+ },
16
+ env: {
17
+ browser: true,
18
+ es2017: true,
19
+ node: true
20
+ },
21
+ overrides: [
22
+ {
23
+ files: ['*.svelte'],
24
+ parser: 'svelte-eslint-parser',
25
+ parserOptions: {
26
+ parser: '@typescript-eslint/parser'
27
+ }
28
+ }
29
+ ]
30
+ };
frontend/.gitignore ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ .DS_Store
2
+ node_modules
3
+ /build
4
+ /.svelte-kit
5
+ /package
6
+ .env
7
+ .env.*
8
+ !.env.example
9
+ vite.config.js.timestamp-*
10
+ vite.config.ts.timestamp-*
frontend/.npmrc ADDED
@@ -0,0 +1 @@
 
 
1
+ engine-strict=true
frontend/.prettierignore ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .DS_Store
2
+ node_modules
3
+ /build
4
+ /.svelte-kit
5
+ /package
6
+ .env
7
+ .env.*
8
+ !.env.example
9
+
10
+ # Ignore files for PNPM, NPM and YARN
11
+ pnpm-lock.yaml
12
+ package-lock.json
13
+ yarn.lock
frontend/.prettierrc ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "useTabs": false,
3
+ "singleQuote": true,
4
+ "trailingComma": "none",
5
+ "printWidth": 100,
6
+ "plugins": [
7
+ "prettier-plugin-svelte",
8
+ "prettier-plugin-organize-imports",
9
+ "prettier-plugin-tailwindcss"
10
+ ],
11
+ "overrides": [
12
+ {
13
+ "files": "*.svelte",
14
+ "options": {
15
+ "parser": "svelte"
16
+ }
17
+ }
18
+ ]
19
+ }
frontend/README.md ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # create-svelte
2
+
3
+ Everything you need to build a Svelte project, powered by [`create-svelte`](https://github.com/sveltejs/kit/tree/master/packages/create-svelte).
4
+
5
+ ## Creating a project
6
+
7
+ If you're seeing this, you've probably already done this step. Congrats!
8
+
9
+ ```bash
10
+ # create a new project in the current directory
11
+ npm create svelte@latest
12
+
13
+ # create a new project in my-app
14
+ npm create svelte@latest my-app
15
+ ```
16
+
17
+ ## Developing
18
+
19
+ Once you've created a project and installed dependencies with `npm install` (or `pnpm install` or `yarn`), start a development server:
20
+
21
+ ```bash
22
+ npm run dev
23
+
24
+ # or start the server and open the app in a new browser tab
25
+ npm run dev -- --open
26
+ ```
27
+
28
+ ## Building
29
+
30
+ To create a production version of your app:
31
+
32
+ ```bash
33
+ npm run build
34
+ ```
35
+
36
+ You can preview the production build with `npm run preview`.
37
+
38
+ > To deploy your app, you may need to install an [adapter](https://kit.svelte.dev/docs/adapters) for your target environment.
frontend/package-lock.json ADDED
The diff for this file is too large to render. See raw diff
 
frontend/package.json ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "name": "frontend",
3
+ "version": "0.0.1",
4
+ "private": true,
5
+ "scripts": {
6
+ "dev": "vite dev",
7
+ "build": "vite build",
8
+ "preview": "vite preview",
9
+ "check": "svelte-kit sync && svelte-check --tsconfig ./tsconfig.json",
10
+ "check:watch": "svelte-kit sync && svelte-check --tsconfig ./tsconfig.json --watch",
11
+ "lint": "prettier --check . && eslint .",
12
+ "format": "prettier --write ."
13
+ },
14
+ "devDependencies": {
15
+ "@sveltejs/adapter-auto": "^2.0.0",
16
+ "@sveltejs/kit": "^1.20.4",
17
+ "@typescript-eslint/eslint-plugin": "^6.0.0",
18
+ "@typescript-eslint/parser": "^6.0.0",
19
+ "autoprefixer": "^10.4.16",
20
+ "eslint": "^8.28.0",
21
+ "eslint-config-prettier": "^9.0.0",
22
+ "eslint-plugin-svelte": "^2.30.0",
23
+ "postcss": "^8.4.31",
24
+ "prettier": "^3.1.0",
25
+ "prettier-plugin-organize-imports": "^3.2.4",
26
+ "prettier-plugin-svelte": "^3.1.0",
27
+ "prettier-plugin-tailwindcss": "^0.5.7",
28
+ "svelte": "^4.0.5",
29
+ "svelte-check": "^3.4.3",
30
+ "tailwindcss": "^3.3.5",
31
+ "tslib": "^2.4.1",
32
+ "typescript": "^5.0.0",
33
+ "vite": "^4.4.2"
34
+ },
35
+ "type": "module"
36
+ }
frontend/postcss.config.js ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ export default {
2
+ plugins: {
3
+ tailwindcss: {},
4
+ autoprefixer: {}
5
+ }
6
+ };
frontend/src/app.css ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ @tailwind base;
2
+ @tailwind components;
3
+ @tailwind utilities;
frontend/src/app.d.ts ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // See https://kit.svelte.dev/docs/types#app
2
+ // for information about these interfaces
3
+ declare global {
4
+ namespace App {
5
+ // interface Error {}
6
+ // interface Locals {}
7
+ // interface PageData {}
8
+ // interface Platform {}
9
+ }
10
+ }
11
+
12
+ export {};
frontend/src/app.html ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!doctype html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="utf-8" />
5
+ <link rel="icon" href="%sveltekit.assets%/favicon.png" />
6
+ <meta name="viewport" content="width=device-width, initial-scale=1" />
7
+ %sveltekit.head%
8
+ </head>
9
+ <body data-sveltekit-preload-data="hover">
10
+ <div style="display: contents">%sveltekit.body%</div>
11
+ </body>
12
+ </html>
frontend/src/lib/index.ts ADDED
@@ -0,0 +1 @@
 
 
1
+ // place files you want to import through the `$lib` alias in this folder.
frontend/src/lib/types.ts ADDED
File without changes
frontend/src/routes/+layout.svelte ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ <script>
2
+ import '../app.css';
3
+ </script>
4
+
5
+ <slot />
frontend/src/routes/+page.svelte ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <script lang="ts">
2
+ import { onMount } from 'svelte';
3
+ import { PUBLIC_BASE_URL } from '$env/static/public';
4
+
5
+ onMount(() => {
6
+ getSettings();
7
+ });
8
+ async function getSettings() {
9
+ const settings = await fetch(`${PUBLIC_BASE_URL}/settings`).then((r) => r.json());
10
+ console.log(settings);
11
+ }
12
+ </script>
13
+
14
+ <div class="fixed right-2 top-2 max-w-xs rounded-lg p-4 text-center text-sm font-bold" id="error" />
15
+ <main class="container mx-auto flex max-w-4xl flex-col gap-4 px-4 py-4">
16
+ <article class="mx-auto max-w-xl text-center">
17
+ <h1 class="text-3xl font-bold">Real-Time Latent Consistency Model</h1>
18
+ <h2 class="mb-4 text-2xl font-bold">Image to Image</h2>
19
+ <p class="text-sm">
20
+ This demo showcases
21
+ <a
22
+ href="https://huggingface.co/SimianLuo/LCM_Dreamshaper_v7"
23
+ target="_blank"
24
+ class="text-blue-500 underline hover:no-underline">LCM</a
25
+ >
26
+ Image to Image pipeline using
27
+ <a
28
+ href="https://github.com/huggingface/diffusers/tree/main/examples/community#latent-consistency-pipeline"
29
+ target="_blank"
30
+ class="text-blue-500 underline hover:no-underline">Diffusers</a
31
+ > with a MJPEG stream server.
32
+ </p>
33
+ <p class="text-sm">
34
+ There are <span id="queue_size" class="font-bold">0</span> user(s) sharing the same GPU,
35
+ affecting real-time performance. Maximum queue size is 4.
36
+ <a
37
+ href="https://huggingface.co/spaces/radames/Real-Time-Latent-Consistency-Model?duplicate=true"
38
+ target="_blank"
39
+ class="text-blue-500 underline hover:no-underline">Duplicate</a
40
+ > and run it on your own GPU.
41
+ </p>
42
+ </article>
43
+ <div>
44
+ <h2 class="font-medium">Prompt</h2>
45
+ <p class="text-sm text-gray-500">
46
+ Change the prompt to generate different images, accepts <a
47
+ href="https://github.com/damian0815/compel/blob/main/doc/syntax.md"
48
+ target="_blank"
49
+ class="text-blue-500 underline hover:no-underline">Compel</a
50
+ > syntax.
51
+ </p>
52
+ <div class="text-normal flex items-center rounded-md border border-gray-700 px-1 py-1">
53
+ <textarea
54
+ type="text"
55
+ id="prompt"
56
+ class="mx-1 w-full px-3 py-2 font-light outline-none dark:text-black"
57
+ title="Prompt, this is an example, feel free to modify"
58
+ placeholder="Add your prompt here..."
59
+ >Portrait of The Terminator with , glare pose, detailed, intricate, full of colour,
60
+ cinematic lighting, trending on artstation, 8k, hyperrealistic, focused, extreme details,
61
+ unreal engine 5, cinematic, masterpiece</textarea
62
+ >
63
+ </div>
64
+ </div>
65
+ <div class="">
66
+ <details>
67
+ <summary class="cursor-pointer font-medium">Advanced Options</summary>
68
+ <div class="grid max-w-md grid-cols-3 items-center gap-3 py-3">
69
+ <label class="text-sm font-medium" for="guidance-scale">Guidance Scale </label>
70
+ <input
71
+ type="range"
72
+ id="guidance-scale"
73
+ name="guidance-scale"
74
+ min="1"
75
+ max="30"
76
+ step="0.001"
77
+ value="8.0"
78
+ oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)"
79
+ />
80
+ <output
81
+ class="w-[50px] rounded-md border border-gray-700 px-1 py-1 text-center text-xs font-light"
82
+ >
83
+ 8.0</output
84
+ >
85
+ <label class="text-sm font-medium" for="strength">Strength</label>
86
+ <input
87
+ type="range"
88
+ id="strength"
89
+ name="strength"
90
+ min="0.20"
91
+ max="1"
92
+ step="0.001"
93
+ value="0.50"
94
+ oninput="this.nextElementSibling.value = Number(this.value).toFixed(2)"
95
+ />
96
+ <output
97
+ class="w-[50px] rounded-md border border-gray-700 px-1 py-1 text-center text-xs font-light"
98
+ >
99
+ 0.5</output
100
+ >
101
+ <label class="text-sm font-medium" for="seed">Seed</label>
102
+ <input
103
+ type="number"
104
+ id="seed"
105
+ name="seed"
106
+ value="299792458"
107
+ class="rounded-md border border-gray-700 p-2 text-right font-light dark:text-black"
108
+ />
109
+ <button
110
+ onclick="document.querySelector('#seed').value = Math.floor(Math.random() * Number.MAX_SAFE_INTEGER)"
111
+ class="button"
112
+ >
113
+ Rand
114
+ </button>
115
+ </div>
116
+ </details>
117
+ </div>
118
+ <div class="flex gap-3">
119
+ <button id="start" class="button"> Start </button>
120
+ <button id="stop" class="button"> Stop </button>
121
+ <button id="snap" disabled class="button ml-auto"> Snapshot </button>
122
+ </div>
123
+ <div class="relative overflow-hidden rounded-lg border border-slate-300">
124
+ <img
125
+ id="player"
126
+ class="aspect-square w-full rounded-lg"
127
+ src="data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNkYAAAAAYAAjCB0C8AAAAASUVORK5CYII="
128
+ />
129
+ <div class="absolute left-0 top-0 aspect-square w-1/4">
130
+ <video
131
+ id="webcam"
132
+ class="relative z-10 aspect-square w-full object-cover"
133
+ playsinline
134
+ autoplay
135
+ muted
136
+ loop
137
+ />
138
+ <svg
139
+ xmlns="http://www.w3.org/2000/svg"
140
+ viewBox="0 0 448 448"
141
+ width="100"
142
+ class="absolute top-0 z-0 w-full p-4 opacity-20"
143
+ >
144
+ <path
145
+ fill="currentColor"
146
+ d="M224 256a128 128 0 1 0 0-256 128 128 0 1 0 0 256zm-45.7 48A178.3 178.3 0 0 0 0 482.3 29.7 29.7 0 0 0 29.7 512h388.6a29.7 29.7 0 0 0 29.7-29.7c0-98.5-79.8-178.3-178.3-178.3h-91.4z"
147
+ />
148
+ </svg>
149
+ </div>
150
+ </div>
151
+ </main>
152
+
153
+ <style lang="postcss">
154
+ :global(html) {
155
+ @apply text-black dark:bg-gray-900 dark:text-white;
156
+ }
157
+ .button {
158
+ @apply rounded bg-gray-700 p-2 font-normal text-white hover:bg-gray-800 disabled:cursor-not-allowed disabled:bg-gray-300 dark:disabled:bg-gray-700 dark:disabled:text-black;
159
+ }
160
+ </style>
frontend/src/routes/+page.ts ADDED
@@ -0,0 +1 @@
 
 
1
+ export const prerender = true
frontend/static/favicon.png ADDED
frontend/svelte.config.js ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import adapter from '@sveltejs/adapter-static';
2
+ import { vitePreprocess } from '@sveltejs/kit/vite';
3
+
4
+ /** @type {import('@sveltejs/kit').Config} */
5
+ const config = {
6
+ preprocess: vitePreprocess(),
7
+
8
+ kit: {
9
+ adapter: adapter({
10
+ pages: '../public',
11
+ assets: '../public',
12
+ fallback: undefined,
13
+ precompress: false,
14
+ strict: true
15
+ })
16
+ }
17
+ };
18
+
19
+ export default config;
frontend/tailwind.config.js ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ /** @type {import('tailwindcss').Config} */
2
+ export default {
3
+ content: ['./src/**/*.{html,js,svelte,ts}'],
4
+ theme: {
5
+ extend: {}
6
+ },
7
+ plugins: []
8
+ };
frontend/tsconfig.json ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "extends": "./.svelte-kit/tsconfig.json",
3
+ "compilerOptions": {
4
+ "allowJs": true,
5
+ "checkJs": true,
6
+ "esModuleInterop": true,
7
+ "forceConsistentCasingInFileNames": true,
8
+ "resolveJsonModule": true,
9
+ "skipLibCheck": true,
10
+ "sourceMap": true,
11
+ "strict": true
12
+ }
13
+ // Path aliases are handled by https://kit.svelte.dev/docs/configuration#alias
14
+ //
15
+ // If you want to overwrite includes/excludes, make sure to copy over the relevant includes/excludes
16
+ // from the referenced tsconfig.json - TypeScript does not merge them in
17
+ }
frontend/vite.config.ts ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ import { sveltekit } from '@sveltejs/kit/vite';
2
+ import { defineConfig } from 'vite';
3
+
4
+ export default defineConfig({
5
+ plugins: [sveltekit()]
6
+ });
pipelines/__init__.py ADDED
File without changes
pipelines/controlnet.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from diffusers import DiffusionPipeline, AutoencoderTiny
2
+ from latent_consistency_controlnet import LatentConsistencyModelPipeline_controlnet
3
+
4
+ from compel import Compel
5
+ import torch
6
+
7
+ try:
8
+ import intel_extension_for_pytorch as ipex # type: ignore
9
+ except:
10
+ pass
11
+
12
+ import psutil
13
+ from config import Args
14
+ from pydantic import BaseModel
15
+ from PIL import Image
16
+ from typing import Callable
17
+
18
+ base_model = "SimianLuo/LCM_Dreamshaper_v7"
19
+ WIDTH = 512
20
+ HEIGHT = 512
21
+
22
+
23
+ class Pipeline:
24
+ class InputParams(BaseModel):
25
+ seed: int = 2159232
26
+ prompt: str
27
+ guidance_scale: float = 8.0
28
+ strength: float = 0.5
29
+ steps: int = 4
30
+ lcm_steps: int = 50
31
+ width: int = WIDTH
32
+ height: int = HEIGHT
33
+
34
+ @staticmethod
35
+ def create_pipeline(
36
+ args: Args, device: torch.device, torch_dtype: torch.dtype
37
+ ) -> Callable[["Pipeline.InputParams"], Image.Image]:
38
+ if args.safety_checker:
39
+ pipe = DiffusionPipeline.from_pretrained(base_model)
40
+ else:
41
+ pipe = DiffusionPipeline.from_pretrained(base_model, safety_checker=None)
42
+ if args.use_taesd:
43
+ pipe.vae = AutoencoderTiny.from_pretrained(
44
+ "madebyollin/taesd", torch_dtype=torch_dtype, use_safetensors=True
45
+ )
46
+
47
+ pipe.set_progress_bar_config(disable=True)
48
+ pipe.to(device=device, dtype=torch_dtype)
49
+ pipe.unet.to(memory_format=torch.channels_last)
50
+
51
+ # check if computer has less than 64GB of RAM using sys or os
52
+ if psutil.virtual_memory().total < 64 * 1024**3:
53
+ pipe.enable_attention_slicing()
54
+
55
+ if args.torch_compile:
56
+ pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
57
+ pipe.vae = torch.compile(pipe.vae, mode="reduce-overhead", fullgraph=True)
58
+
59
+ pipe(prompt="warmup", num_inference_steps=1, guidance_scale=8.0)
60
+
61
+ compel_proc = Compel(
62
+ tokenizer=pipe.tokenizer,
63
+ text_encoder=pipe.text_encoder,
64
+ truncate_long_prompts=False,
65
+ )
66
+
67
+ def predict(params: "Pipeline.InputParams") -> Image.Image:
68
+ generator = torch.manual_seed(params.seed)
69
+ prompt_embeds = compel_proc(params.prompt)
70
+ # Can be set to 1~50 steps. LCM support fast inference even <= 4 steps. Recommend: 1~8 steps.
71
+ results = pipe(
72
+ prompt_embeds=prompt_embeds,
73
+ generator=generator,
74
+ num_inference_steps=params.steps,
75
+ guidance_scale=params.guidance_scale,
76
+ width=params.width,
77
+ height=params.height,
78
+ original_inference_steps=params.lcm_steps,
79
+ output_type="pil",
80
+ )
81
+ nsfw_content_detected = (
82
+ results.nsfw_content_detected[0]
83
+ if "nsfw_content_detected" in results
84
+ else False
85
+ )
86
+ if nsfw_content_detected:
87
+ return None
88
+ return results.images[0]
89
+
90
+ return predict
pipelines/txt2img.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from diffusers import DiffusionPipeline, AutoencoderTiny
2
+ from compel import Compel
3
+ import torch
4
+
5
+ try:
6
+ import intel_extension_for_pytorch as ipex # type: ignore
7
+ except:
8
+ pass
9
+
10
+ import psutil
11
+ from config import Args
12
+ from pydantic import BaseModel
13
+ from PIL import Image
14
+ from typing import Callable
15
+
16
+ base_model = "SimianLuo/LCM_Dreamshaper_v7"
17
+ taesd_model = "madebyollin/taesd"
18
+
19
+
20
+ class Pipeline:
21
+ class InputParams(BaseModel):
22
+ seed: int = 2159232
23
+ prompt: str = ""
24
+ guidance_scale: float = 8.0
25
+ strength: float = 0.5
26
+ steps: int = 4
27
+ width: int = 512
28
+ height: int = 512
29
+
30
+ def __init__(self, args: Args, device: torch.device, torch_dtype: torch.dtype):
31
+ if args.safety_checker:
32
+ self.pipe = DiffusionPipeline.from_pretrained(base_model)
33
+ else:
34
+ self.pipe = DiffusionPipeline.from_pretrained(
35
+ base_model, safety_checker=None
36
+ )
37
+ if args.use_taesd:
38
+ self.pipe.vae = AutoencoderTiny.from_pretrained(
39
+ taesd_model, torch_dtype=torch_dtype, use_safetensors=True
40
+ )
41
+
42
+ self.pipe.set_progress_bar_config(disable=True)
43
+ self.pipe.to(device=device, dtype=torch_dtype)
44
+ self.pipe.unet.to(memory_format=torch.channels_last)
45
+
46
+ # check if computer has less than 64GB of RAM using sys or os
47
+ if psutil.virtual_memory().total < 64 * 1024**3:
48
+ self.pipe.enable_attention_slicing()
49
+
50
+ if args.torch_compile:
51
+ self.pipe.unet = torch.compile(
52
+ self.pipe.unet, mode="reduce-overhead", fullgraph=True
53
+ )
54
+ self.pipe.vae = torch.compile(
55
+ self.pipe.vae, mode="reduce-overhead", fullgraph=True
56
+ )
57
+
58
+ self.pipe(prompt="warmup", num_inference_steps=1, guidance_scale=8.0)
59
+
60
+ self.compel_proc = Compel(
61
+ tokenizer=self.pipe.tokenizer,
62
+ text_encoder=self.pipe.text_encoder,
63
+ truncate_long_prompts=False,
64
+ )
65
+
66
+ def predict(self, params: "Pipeline.InputParams") -> Image.Image:
67
+ generator = torch.manual_seed(params.seed)
68
+ prompt_embeds = self.compel_proc(params.prompt)
69
+ results = self.pipe(
70
+ prompt_embeds=prompt_embeds,
71
+ generator=generator,
72
+ num_inference_steps=params.steps,
73
+ guidance_scale=params.guidance_scale,
74
+ width=params.width,
75
+ height=params.height,
76
+ output_type="pil",
77
+ )
78
+ nsfw_content_detected = (
79
+ results.nsfw_content_detected[0]
80
+ if "nsfw_content_detected" in results
81
+ else False
82
+ )
83
+ if nsfw_content_detected:
84
+ return None
85
+ return results.images[0]
pipelines/txt2imglora.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from diffusers import DiffusionPipeline, AutoencoderTiny
2
+ from compel import Compel
3
+ import torch
4
+
5
+ try:
6
+ import intel_extension_for_pytorch as ipex # type: ignore
7
+ except:
8
+ pass
9
+
10
+ import psutil
11
+ from config import Args
12
+ from pydantic import BaseModel
13
+ from PIL import Image
14
+ from typing import Callable
15
+
16
+ base_model = "SimianLuo/LCM_Dreamshaper_v7"
17
+ WIDTH = 512
18
+ HEIGHT = 512
19
+
20
+ model_id = "wavymulder/Analog-Diffusion"
21
+ lcm_lora_id = "latent-consistency/lcm-lora-sdv1-5"
22
+
23
+
24
+ class Pipeline:
25
+ class InputParams(BaseModel):
26
+ seed: int = 2159232
27
+ prompt: str
28
+ guidance_scale: float = 8.0
29
+ strength: float = 0.5
30
+ steps: int = 4
31
+ lcm_steps: int = 50
32
+ width: int = WIDTH
33
+ height: int = HEIGHT
34
+
35
+ @staticmethod
36
+ def create_pipeline(
37
+ args: Args, device: torch.device, torch_dtype: torch.dtype
38
+ ) -> Callable[["Pipeline.InputParams"], Image.Image]:
39
+ if args.safety_checker:
40
+ pipe = DiffusionPipeline.from_pretrained(base_model)
41
+ else:
42
+ pipe = DiffusionPipeline.from_pretrained(base_model, safety_checker=None)
43
+ if args.use_taesd:
44
+ pipe.vae = AutoencoderTiny.from_pretrained(
45
+ "madebyollin/taesd", torch_dtype=torch_dtype, use_safetensors=True
46
+ )
47
+
48
+ pipe.set_progress_bar_config(disable=True)
49
+ pipe.to(device=device, dtype=torch_dtype)
50
+ pipe.unet.to(memory_format=torch.channels_last)
51
+
52
+ # Load LCM LoRA
53
+ pipe.load_lora_weights(lcm_lora_id, adapter_name="lcm")
54
+ # check if computer has less than 64GB of RAM using sys or os
55
+ if psutil.virtual_memory().total < 64 * 1024**3:
56
+ pipe.enable_attention_slicing()
57
+
58
+ if args.torch_compile:
59
+ pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
60
+ pipe.vae = torch.compile(pipe.vae, mode="reduce-overhead", fullgraph=True)
61
+
62
+ pipe(prompt="warmup", num_inference_steps=1, guidance_scale=8.0)
63
+
64
+ compel_proc = Compel(
65
+ tokenizer=pipe.tokenizer,
66
+ text_encoder=pipe.text_encoder,
67
+ truncate_long_prompts=False,
68
+ )
69
+
70
+ def predict(params: "Pipeline.InputParams") -> Image.Image:
71
+ generator = torch.manual_seed(params.seed)
72
+ prompt_embeds = compel_proc(params.prompt)
73
+ # Can be set to 1~50 steps. LCM support fast inference even <= 4 steps. Recommend: 1~8 steps.
74
+ results = pipe(
75
+ prompt_embeds=prompt_embeds,
76
+ generator=generator,
77
+ num_inference_steps=params.steps,
78
+ guidance_scale=params.guidance_scale,
79
+ width=params.width,
80
+ height=params.height,
81
+ original_inference_steps=params.lcm_steps,
82
+ output_type="pil",
83
+ )
84
+ nsfw_content_detected = (
85
+ results.nsfw_content_detected[0]
86
+ if "nsfw_content_detected" in results
87
+ else False
88
+ )
89
+ if nsfw_content_detected:
90
+ return None
91
+ return results.images[0]
92
+
93
+ return predict
requirements.txt CHANGED
@@ -3,8 +3,8 @@ transformers==4.34.1
3
  gradio==3.50.2
4
  --extra-index-url https://download.pytorch.org/whl/cu121;
5
  torch==2.1.0
6
- fastapi==0.104.0
7
- uvicorn==0.23.2
8
  Pillow==10.1.0
9
  accelerate==0.24.0
10
  compel==2.0.2
 
3
  gradio==3.50.2
4
  --extra-index-url https://download.pytorch.org/whl/cu121;
5
  torch==2.1.0
6
+ fastapi==0.104.1
7
+ uvicorn==0.24.0.post1
8
  Pillow==10.1.0
9
  accelerate==0.24.0
10
  compel==2.0.2
run.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ if __name__ == "__main__":
2
+ import uvicorn
3
+ from config import args
4
+
5
+ uvicorn.run("app:app", host=args.host, port=args.port, reload=args.reload)
user_queue.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Union
2
+ from uuid import UUID
3
+ from asyncio import Queue
4
+ from PIL import Image
5
+ from typing import Tuple, Union
6
+ from uuid import UUID
7
+ from asyncio import Queue
8
+ from PIL import Image
9
+
10
+ UserId = UUID
11
+
12
+ InputParams = dict
13
+
14
+ QueueContent = Dict[str, Union[Image.Image, InputParams]]
15
+
16
+ UserQueueDict = Dict[UserId, Queue[QueueContent]]
17
+
18
+ user_queue_map: UserQueueDict = {}
util.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from importlib import import_module
2
+ from types import ModuleType
3
+
4
+
5
+ def get_pipeline_class(pipeline_name: str) -> ModuleType:
6
+ try:
7
+ module = import_module(f"pipelines.{pipeline_name}")
8
+ except ModuleNotFoundError:
9
+ raise ValueError(f"Pipeline {pipeline_name} module not found")
10
+
11
+ pipeline_class = getattr(module, "Pipeline", None)
12
+
13
+ if pipeline_class is None:
14
+ raise ValueError(f"'Pipeline' class not found in module '{module_name}'.")
15
+
16
+ return pipeline_class