File size: 2,004 Bytes
f3b2c5b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from enum import Enum
import torch

from model_classes import Model200M, Model5M, SyntheticV2
from model_transforms import transform_200M, transform_5M, transform_synthetic 

class ModelType(str, Enum):
    """Closed set of detector identifiers understood by this module.

    Because the enum mixes in ``str``, every member compares equal to, and
    hashes like, its string value. A plain string such as ``"midjourney_5M"``
    therefore finds the same dict entry as ``ModelType.MIDJOURNEY_5M``.
    """

    MIDJOURNEY_200M = "midjourney_200M"
    DIFFUSIONS_200M = "diffusions_200M"
    MIDJOURNEY_5M = "midjourney_5M"
    DIFFUSIONS_5M = "diffusions_5M"
    SYNTHETIC_DETECTOR_V2 = "synthetic_detector_v2"

    def __str__(self) -> str:
        # Render as the bare value (e.g. "diffusions_5M") rather than the
        # default "ModelType.DIFFUSIONS_5M" form.
        return f"{self.value}"

    @staticmethod
    def get_list():
        """Return the string values of all members, in declaration order."""
        return [member.value for member in ModelType]

def load_model(value: ModelType):
    """Load the checkpoint for ``value`` into its registered model and return it.

    The model object comes from the module-level ``type_to_class`` table and
    is modified in place, not copied. The checkpoint file comes from
    ``type_to_path``. Both tables must already exist when this is called.

    Args:
        value: Key into both tables. A matching plain string also works,
            because ``ModelType`` is a ``str`` enum.

    Returns:
        The same object stored at ``type_to_class[value]``, after its
        state dict has been loaded and ``eval()`` has been called on it.

    Raises:
        KeyError: If ``value`` is missing from ``type_to_class`` or
            ``type_to_path``.
    """
    network = type_to_class[value]
    checkpoint_file = type_to_path[value]
    # map_location remaps any GPU-saved tensors onto the CPU, so loading
    # does not require CUDA. torch.load unpickles the file, so only point
    # type_to_path at trusted checkpoints.
    state = torch.load(checkpoint_file, map_location=torch.device('cpu'))
    network.load_state_dict(state)
    network.eval()
    return network

# Registry of model *instances* (not classes, despite the name), keyed by
# ModelType. Each entry calls its constructor separately, so entries that
# share an architecture (the two 200M and the two 5M models) still get
# distinct objects. load_model() loads checkpoint weights into these
# objects in place.
type_to_class = {
    model_type: factory()
    for model_type, factory in (
        (ModelType.MIDJOURNEY_200M, Model200M),
        (ModelType.DIFFUSIONS_200M, Model200M),
        (ModelType.MIDJOURNEY_5M, Model5M),
        (ModelType.DIFFUSIONS_5M, Model5M),
        (ModelType.SYNTHETIC_DETECTOR_V2, SyntheticV2),
    )
}

# Checkpoint file for each ModelType, read by load_model().
# NOTE(review): these are relative paths, so they resolve against the
# process's current working directory, not this file's directory. Confirm
# the service is always launched from the directory that contains models/.
type_to_path = dict([
    (ModelType.MIDJOURNEY_200M, 'models/midjourney200M.pt'),
    (ModelType.DIFFUSIONS_200M, 'models/diffusions200M.pt'),
    (ModelType.MIDJOURNEY_5M, 'models/midjourney5M.pt'),
    (ModelType.DIFFUSIONS_5M, 'models/diffusions5M.pt'),
    (ModelType.SYNTHETIC_DETECTOR_V2, 'models/synthetic_detector_v2.pt'),
])

# Eagerly load every checkpoint at import time: one load_model() call per
# ModelType member, in declaration order. load_model() returns the object
# held in type_to_class, so these values are those same objects, now with
# weights loaded and switched to eval mode.
type_to_loaded_model = {
    model_type: load_model(model_type) for model_type in ModelType
}

# Transform to apply for each ModelType. Models built on the same
# architecture share one transform, so the table is written as
# (transform, types-using-it) groups. The keys come out in the same order
# as the ModelType declarations.
type_to_transforms = {
    model_type: transform
    for transform, model_types in (
        (transform_200M, (ModelType.MIDJOURNEY_200M, ModelType.DIFFUSIONS_200M)),
        (transform_5M, (ModelType.MIDJOURNEY_5M, ModelType.DIFFUSIONS_5M)),
        (transform_synthetic, (ModelType.SYNTHETIC_DETECTOR_V2,)),
    )
    for model_type in model_types
}