soumickmj's picture
Upload UNetMSS3D
debfac6 verified
raw
history blame contribute delete
799 Bytes
from transformers import PretrainedConfig
from typing import List
class UNet3DConfig(PretrainedConfig):
    """Configuration container for the ``"UNet"`` model type.

    Stores three hyper-parameters as plain instance attributes and forwards
    every other keyword argument, unchanged, to ``PretrainedConfig.__init__``.

    Args:
        in_ch: stored as ``self.in_ch`` (default 1); presumably the number of
            input channels — confirm against the model definition.
        out_ch: stored as ``self.out_ch`` (default 1); presumably the number
            of output channels — confirm against the model definition.
        init_features: stored as ``self.init_features`` (default 64);
            presumably the feature width of the first stage — confirm.
        **kwargs: passed straight through to ``PretrainedConfig``.
    """

    model_type = "UNet"

    def __init__(
        self,
        in_ch: int = 1,
        out_ch: int = 1,
        init_features: int = 64,
        **kwargs,
    ):
        # Custom fields are assigned (in this order) before the base-class
        # initialiser runs, mirroring the original statement order.
        self.in_ch, self.out_ch, self.init_features = (
            in_ch,
            out_ch,
            init_features,
        )
        super().__init__(**kwargs)
class UNetMSS3DConfig(PretrainedConfig):
    """Configuration container for the ``"UNetMSS"`` model type.

    Stores four values as plain instance attributes and forwards every other
    keyword argument, unchanged, to ``PretrainedConfig.__init__``.

    Args:
        in_ch: stored as ``self.in_ch`` (default 1); presumably the number of
            input channels — confirm against the model definition.
        out_ch: stored as ``self.out_ch`` (default 1); presumably the number
            of output channels — confirm against the model definition.
        output_dir: stored verbatim as ``self.output_dir`` (default None);
            how it is consumed is not visible here — NOTE(review): confirm.
        init_features: stored as ``self.init_features`` (default 64);
            presumably the feature width of the first stage — confirm.
        **kwargs: passed straight through to ``PretrainedConfig``.
    """

    model_type = "UNetMSS"

    def __init__(
        self,
        in_ch: int = 1,
        out_ch: int = 1,
        output_dir=None,
        init_features: int = 64,
        **kwargs,
    ):
        # Custom fields are assigned (in this order) before the base-class
        # initialiser runs, mirroring the original statement order.
        self.in_ch, self.out_ch, self.output_dir, self.init_features = (
            in_ch,
            out_ch,
            output_dir,
            init_features,
        )
        super().__init__(**kwargs)