import torch import random import torchvision.transforms as T import numpy as np class RandAug: """Randomly chosen image augmentations.""" def __init__(self): # Augmentation options self.trans = ['identity', 'color', 'sharpness', 'blur'] def __call__(self, img): self.choice = random.choices(self.trans, weights=(25, 25, 25, 25))[0] if self.choice == 'identity': return img elif self.choice == 'color': rand_brightness = random.uniform(0, 0.3) rand_hue = random.uniform(0, 0.5) rand_contrast = random.uniform(0, 0.5) rand_saturation = random.uniform(0, 0.5) trans = T.ColorJitter(brightness=rand_brightness, contrast=rand_contrast, saturation=rand_saturation, hue=rand_hue) img = trans(img) elif self.choice=='sharpness': sharpness = 1+(np.random.exponential()/2) trans = T.RandomAdjustSharpness(sharpness, p=1) img = trans(img) elif self.choice=='blur': kernel = random.choice([1,3,5]) trans = T.GaussianBlur(kernel, sigma=(0.1, 2.0)) img = trans(img) return img class RandRotate: """Randomly chosen image augmentations.""" def __init__(self, low = 0, high = 180): # Augmentation options self.rotation = torch.randint(low=low, high=high, size=(1,)).item() self.trans = ['identity', 'rotation'] def __call__(self, img, mask): self.choice = random.choices(self.trans, weights=(50, 50))[0] if self.choice == 'identity': return img, mask elif self.choice == 'rotation': rotated_img = T.functional.rotate(img=img, angle=self.rotation, expand=False) rotated_mask = T.functional.rotate(img=mask, angle=self.rotation, expand=False) return rotated_img, rotated_mask