Spaces:
Running
on
A10G
Running
on
A10G
Linoy Tsaban
commited on
Commit
•
eec6d5e
1
Parent(s):
2aeea2e
Update modified_pipeline_semantic_stable_diffusion.py
Browse files
modified_pipeline_semantic_stable_diffusion.py
CHANGED
@@ -721,37 +721,37 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline):
|
|
721 |
callback(i, t, latents)
|
722 |
|
723 |
|
724 |
-
|
725 |
-
|
726 |
|
727 |
-
#
|
728 |
-
|
729 |
|
730 |
-
#
|
731 |
-
|
732 |
-
|
733 |
|
734 |
-
|
735 |
-
|
736 |
|
737 |
-
|
738 |
|
739 |
-
# 8. Post-processing
|
740 |
-
if not output_type == "latent":
|
741 |
-
|
742 |
-
|
743 |
-
else:
|
744 |
-
|
745 |
-
|
746 |
|
747 |
-
if has_nsfw_concept is None:
|
748 |
-
|
749 |
-
else:
|
750 |
-
|
751 |
|
752 |
-
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
|
753 |
|
754 |
-
if not return_dict:
|
755 |
-
|
756 |
|
757 |
-
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
|
|
721 |
callback(i, t, latents)
|
722 |
|
723 |
|
724 |
+
# 8. Post-processing
|
725 |
+
image = self.decode_latents(latents)
|
726 |
|
727 |
+
# 9. Run safety checker
|
728 |
+
image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
|
729 |
|
730 |
+
# 10. Convert to PIL
|
731 |
+
if output_type == "pil":
|
732 |
+
image = self.numpy_to_pil(image)
|
733 |
|
734 |
+
if not return_dict:
|
735 |
+
return (image, has_nsfw_concept)
|
736 |
|
737 |
+
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
738 |
|
739 |
+
# # 8. Post-processing
|
740 |
+
# if not output_type == "latent":
|
741 |
+
# image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
742 |
+
# image, has_nsfw_concept = self.run_safety_checker(image, self.device, text_embeddings.dtype)
|
743 |
+
# else:
|
744 |
+
# image = latents
|
745 |
+
# has_nsfw_concept = None
|
746 |
|
747 |
+
# if has_nsfw_concept is None:
|
748 |
+
# do_denormalize = [True] * image.shape[0]
|
749 |
+
# else:
|
750 |
+
# do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
|
751 |
|
752 |
+
# image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
|
753 |
|
754 |
+
# if not return_dict:
|
755 |
+
# return (image, has_nsfw_concept)
|
756 |
|
757 |
+
# return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|