yanze commited on
Commit
9084b6a
1 Parent(s): e6c5ff5

add safety checker

Browse files
Files changed (1) hide show
  1. app.py +11 -3
app.py CHANGED
@@ -6,6 +6,7 @@ import gradio as gr
6
  import torch
7
  from einops import rearrange
8
  from PIL import Image
 
9
 
10
  from flux.cli import SamplingOptions
11
  from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack
@@ -13,6 +14,7 @@ from flux.util import load_ae, load_clip, load_flow_model, load_t5
13
  from pulid.pipeline_flux import PuLIDPipeline
14
  from pulid.utils import resize_numpy_image_long
15
 
 
16
 
17
  def get_models(name: str, device: torch.device, offload: bool):
18
  t5 = load_t5(device, max_length=128)
@@ -20,7 +22,8 @@ def get_models(name: str, device: torch.device, offload: bool):
20
  model = load_flow_model(name, device="cpu" if offload else device)
21
  model.eval()
22
  ae = load_ae(name, device="cpu" if offload else device)
23
- return model, ae, t5, clip
 
24
 
25
 
26
  class FluxGenerator:
@@ -28,7 +31,7 @@ class FluxGenerator:
28
  self.device = torch.device('cuda')
29
  self.offload = False
30
  self.model_name = 'flux-dev'
31
- self.model, self.ae, self.t5, self.clip = get_models(
32
  self.model_name,
33
  device=self.device,
34
  offload=self.offload,
@@ -147,7 +150,12 @@ def generate_image(
147
  x = rearrange(x[0], "c h w -> h w c")
148
 
149
  img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
150
- return img, str(opts.seed), flux_generator.pulid_model.debug_img_list
 
 
 
 
 
151
 
152
  _HEADER_ = '''
153
  <div style="text-align: center; max-width: 650px; margin: 0 auto;">
 
6
  import torch
7
  from einops import rearrange
8
  from PIL import Image
9
+ from transformers import pipeline
10
 
11
  from flux.cli import SamplingOptions
12
  from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack
 
14
  from pulid.pipeline_flux import PuLIDPipeline
15
  from pulid.utils import resize_numpy_image_long
16
 
17
+ NSFW_THRESHOLD = 0.85
18
 
19
  def get_models(name: str, device: torch.device, offload: bool):
20
  t5 = load_t5(device, max_length=128)
 
22
  model = load_flow_model(name, device="cpu" if offload else device)
23
  model.eval()
24
  ae = load_ae(name, device="cpu" if offload else device)
25
+ nsfw_classifier = pipeline("image-classification", model="Falconsai/nsfw_image_detection", device=device)
26
+ return model, ae, t5, clip, nsfw_classifier
27
 
28
 
29
  class FluxGenerator:
 
31
  self.device = torch.device('cuda')
32
  self.offload = False
33
  self.model_name = 'flux-dev'
34
+ self.model, self.ae, self.t5, self.clip, self.nsfw_classifier = get_models(
35
  self.model_name,
36
  device=self.device,
37
  offload=self.offload,
 
150
  x = rearrange(x[0], "c h w -> h w c")
151
 
152
  img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy())
153
+ nsfw_score = [x["score"] for x in flux_generator.nsfw_classifier(img) if x["label"] == "nsfw"][0]
154
+ if nsfw_score < NSFW_THRESHOLD:
155
+ return img, str(opts.seed), flux_generator.pulid_model.debug_img_list
156
+ else:
157
+ return (None, f"Your generated image may contain NSFW (with nsfw_score: {nsfw_score}) content",
158
+ flux_generator.pulid_model.debug_img_list)
159
 
160
  _HEADER_ = '''
161
  <div style="text-align: center; max-width: 650px; margin: 0 auto;">