File size: 3,425 Bytes
13531f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
from dataclasses import dataclass
from pathlib import Path
from typing import Any

import torch
from PIL import Image
from refiners.foundationals.clip.concepts import ConceptExtender
from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_upscaler import (
    MultiUpscaler,
    UpscalerCheckpoints,
)

from esrgan_model import UpscalerESRGAN


@dataclass(kw_only=True)
class ESRGANUpscalerCheckpoints(UpscalerCheckpoints):
    """Checkpoint paths for `ESRGANUpscaler`: the `UpscalerCheckpoints` fields plus an optional ESRGAN one."""
    # Path to the ESRGAN weights. When left as None, `ESRGANUpscaler.load_esrgan`
    # returns None and `ESRGANUpscaler.pre_upscale` defers to the parent implementation.
    esrgan: Path | None = None


class ESRGANUpscaler(MultiUpscaler):
    """`MultiUpscaler` variant whose pre-upscaling step can be run through an ESRGAN model.

    The ESRGAN model is optional: when no checkpoint path is supplied,
    `self.esrgan` is None and `pre_upscale` falls back to the parent class.
    """

    def __init__(
        self,
        checkpoints: ESRGANUpscalerCheckpoints,
        device: torch.device,
        dtype: torch.dtype,
    ) -> None:
        """Build the parent upscaler, then load the optional ESRGAN model.

        Args:
            checkpoints: Checkpoint paths; `checkpoints.esrgan` may be None.
            device: Device forwarded to the parent constructor.
            dtype: Dtype forwarded to the parent constructor.
        """
        super().__init__(checkpoints=checkpoints, device=device, dtype=dtype)
        # `load_esrgan` reads `self.device` / `self.dtype`; presumably these are
        # set by the parent constructor above -- verify against MultiUpscaler.
        self.esrgan: UpscalerESRGAN | None = self.load_esrgan(checkpoints.esrgan)

    def to(self, device: torch.device, dtype: torch.dtype) -> None:
        """Move the ESRGAN model (if any) and the diffusion model to `device` / `dtype`.

        Also records the new `device` and `dtype` on the instance.
        """
        # The ESRGAN model is optional (see `load_esrgan`); without this guard
        # the call raises AttributeError when no ESRGAN checkpoint was given.
        if self.esrgan is not None:
            self.esrgan.to(device=device, dtype=dtype)
        self.sd = self.sd.to(device=device, dtype=dtype)
        self.device = device
        self.dtype = dtype

    def load_esrgan(self, path: Path | None) -> UpscalerESRGAN | None:
        """Instantiate the ESRGAN upscaler from `path`, or return None when `path` is None."""
        if path is None:
            return None
        return UpscalerESRGAN(path, device=self.device, dtype=self.dtype)

    def load_negative_embedding(self, path: Path | None, key: str | None) -> str:
        """Load a negative embedding and register each of its rows as a CLIP concept.

        Args:
            path: File to read with `torch.load`. When None, nothing is loaded.
            key: Dot-separated sequence of keys used to reach the tensor when
                the loaded object is a dictionary (e.g. ``"a.b"`` reads
                ``loaded["a"]["b"]``). Required in that case.

        Returns:
            ``""`` when `path` is None; otherwise a string starting with
            ``", "`` followed by the tokens ``<0> <1> ...`` (one per embedding
            row, each followed by a space).

        Raises:
            AssertionError: If `key` is missing for a dictionary, a key part is
                not found, or the resolved object is not a 2D tensor.
        """
        if path is None:
            return ""

        embeddings: torch.Tensor | dict[str, Any] = torch.load(  # type: ignore
            path, weights_only=True, map_location=self.device
        )

        if isinstance(embeddings, dict):
            assert (
                key is not None
            ), "Key must be provided to access the negative embedding."
            # Use a distinct loop name so the `key` parameter is not shadowed.
            for part in key.split("."):
                # Guard each step: a key path longer than the nesting depth
                # would otherwise fail inside the membership test below.
                assert isinstance(
                    embeddings, dict
                ), f"Cannot look up key {part}: expected a dictionary, found {type(embeddings)}."
                assert (
                    part in embeddings
                ), f"Key {part} not found in the negative embedding dictionary. Available keys: {list(embeddings.keys())}"
                embeddings = embeddings[part]

        assert isinstance(
            embeddings, torch.Tensor
        ), f"The negative embedding must be a tensor, found {type(embeddings)}."
        assert (
            embeddings.ndim == 2
        ), f"The negative embedding must be a 2D tensor, found {embeddings.ndim}D tensor."

        # Register one concept token per row of the 2D tensor, then inject the
        # extender into the text encoder.
        extender = ConceptExtender(self.sd.clip_text_encoder)
        negative_embedding_token = ", "
        for i, embedding in enumerate(embeddings):
            embedding = embedding.to(device=self.device, dtype=self.dtype)
            extender.add_concept(token=f"<{i}>", embedding=embedding)
            negative_embedding_token += f"<{i}> "
        extender.inject()

        return negative_embedding_token

    def pre_upscale(
        self,
        image: Image.Image,
        upscale_factor: float,
        use_esrgan: bool = True,
        use_esrgan_tiling: bool = True,
        **_: Any,
    ) -> Image.Image:
        """Upscale `image` before the main pipeline, using ESRGAN when available.

        Args:
            image: Input image.
            upscale_factor: Target size multiplier relative to the input size.
            use_esrgan: When False, defer to the parent implementation.
            use_esrgan_tiling: Selects `upscale_with_tiling` over
                `upscale_without_tiling` on the ESRGAN model.
            **_: Extra keyword arguments are accepted and ignored.

        Returns:
            The parent's result when ESRGAN is unavailable or disabled;
            otherwise the ESRGAN output resized with Lanczos resampling to
            ``(int(width * upscale_factor), int(height * upscale_factor))``
            of the original image.
        """
        if self.esrgan is None or not use_esrgan:
            return super().pre_upscale(image=image, upscale_factor=upscale_factor)

        # Capture the input size first: the final resize targets the requested
        # factor of the original, independent of the ESRGAN output size.
        width, height = image.size

        if use_esrgan_tiling:
            image = self.esrgan.upscale_with_tiling(image)
        else:
            image = self.esrgan.upscale_without_tiling(image)

        return image.resize(
            size=(
                int(width * upscale_factor),
                int(height * upscale_factor),
            ),
            resample=Image.LANCZOS,
        )