import torch.nn as nn
from transformers import PretrainedConfig


class cceVAEConfig(PretrainedConfig):
    """Configuration for the cceVAE model, following the Hugging Face
    ``PretrainedConfig`` pattern so it can be saved, loaded, and shared via
    ``save_pretrained`` / ``from_pretrained``."""

    model_type = "cceVAE"

    def __init__(
        self,
        d=2,
        input_size=(1, 256, 256),
        z_dim=1024,
        fmap_sizes=(16, 64, 256, 1024),
        to_1x1=True,
        conv_params=None,
        tconv_params=None,
        normalization_op=None,
        normalization_params=None,
        activation_op="prelu",
        activation_params=None,
        block_op=None,
        block_params=None,
        **kwargs,
    ):
        # Dimensionality of the convolutions (e.g. 2 for 2D image data).
        self.d = d
        # Input tensor shape as (channels, height, width).
        self.input_size = input_size
        # Size of the latent code.
        self.z_dim = z_dim
        # Feature-map sizes of the successive encoder/decoder stages.
        self.fmap_sizes = fmap_sizes
        # Whether the encoder reduces the spatial resolution down to 1x1.
        self.to_1x1 = to_1x1
        # Keyword arguments forwarded to the conv / transposed-conv layers.
        self.conv_params = conv_params
        self.tconv_params = tconv_params
        # Normalization layer and its keyword arguments.
        self.normalization_op = normalization_op
        self.normalization_params = normalization_params
        # Activation layer and its keyword arguments.
        self.activation_op = activation_op
        self.activation_params = activation_params
        # Optional custom building block and its keyword arguments.
        self.block_op = block_op
        self.block_params = block_params
        super().__init__(**kwargs)
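

# Minimal usage sketch, assuming the standard Hugging Face config workflow:
# register the config under its "cceVAE" model type, instantiate it with
# non-default values, and round-trip it through save_pretrained /
# from_pretrained. The "./cceVAE-checkpoint" directory is a placeholder path.
if __name__ == "__main__":
    from transformers import AutoConfig

    # Register the custom config so AutoConfig can resolve model_type="cceVAE".
    AutoConfig.register("cceVAE", cceVAEConfig)

    # Build a config with a smaller latent space than the default.
    config = cceVAEConfig(z_dim=512)

    # Serialize to disk and load it back through the standard API
    # (tuples are stored as JSON lists and come back as lists).
    config.save_pretrained("./cceVAE-checkpoint")
    reloaded = cceVAEConfig.from_pretrained("./cceVAE-checkpoint")
    print(reloaded.z_dim, reloaded.fmap_sizes)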