# File size: 1,924 Bytes
# 8713ab2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import torch
import random
import torchvision.transforms as T
import numpy as np

class RandAug:
    """Apply one randomly selected photometric augmentation to an image.

    Every call draws one entry of ``self.trans`` with equal weights and
    applies the matching torchvision transform:

    * ``'identity'``  -- the image is returned untouched.
    * ``'color'``     -- ``ColorJitter`` with freshly drawn jitter strengths.
    * ``'sharpness'`` -- ``RandomAdjustSharpness`` with a factor of at least 1,
      always applied (``p=1``).
    * ``'blur'``      -- ``GaussianBlur`` with a kernel size of 1, 3 or 5.

    The name drawn most recently is kept in ``self.choice``.
    """

    def __init__(self):
        # Names of the augmentations ``__call__`` can draw from.
        self.trans = ['identity', 'color', 'sharpness', 'blur']

    def __call__(self, img):
        """Return ``img`` after applying the augmentation drawn for this call.

        Args:
            img: Image in any form the torchvision transforms accept
                (presumably a PIL image or tensor -- confirm with callers).

        Returns:
            The augmented image, or ``img`` itself when ``'identity'`` (or a
            name not handled below) was drawn.
        """
        picked = random.choices(self.trans, weights=(25, 25, 25, 25))[0]
        self.choice = picked

        if picked == 'color':
            # Dict entries evaluate left to right, so the draws happen in the
            # order brightness, hue, contrast, saturation.
            jitter_args = {
                'brightness': random.uniform(0, 0.3),
                'hue': random.uniform(0, 0.5),
                'contrast': random.uniform(0, 0.5),
                'saturation': random.uniform(0, 0.5),
            }
            return T.ColorJitter(**jitter_args)(img)

        if picked == 'sharpness':
            # 1 plus half of an exponential draw: the factor is never below 1.
            factor = 1 + (np.random.exponential() / 2)
            return T.RandomAdjustSharpness(factor, p=1)(img)

        if picked == 'blur':
            kernel_size = random.choice([1, 3, 5])
            return T.GaussianBlur(kernel_size, sigma=(0.1, 2.0))(img)

        # 'identity' and any unrecognised name leave the image as it is.
        return img


class RandRotate:
    """Rotate an image and its mask by the same random angle, half of the time.

    Every call picks between ``'identity'`` and ``'rotation'`` with equal
    weights.  A new integer angle is drawn from ``[low, high)`` on every call,
    so a single long-lived instance produces varied rotations instead of
    repeating the one angle drawn at construction.

    Args:
        low: Smallest angle that can be drawn (inclusive).
        high: Upper bound of the angle range (exclusive).

    Attributes:
        rotation: The angle drawn most recently.
        trans: Names of the operations ``__call__`` chooses between.
        choice: Name picked by the most recent call (set by ``__call__``).
    """

    def __init__(self, low=0, high=180):
        self.low = low
        self.high = high
        self.trans = ['identity', 'rotation']
        # Draw once up front so ``rotation`` is readable before the first call
        # and an invalid range fails here rather than on first use.
        self.rotation = self._sample_angle()

    def _sample_angle(self):
        """Draw one integer angle from ``[self.low, self.high)``."""
        return torch.randint(low=self.low, high=self.high, size=(1,)).item()

    def __call__(self, img, mask):
        """Return ``(img, mask)``, rotated together or left unchanged.

        Args:
            img: Image to transform.
            mask: Mask belonging to ``img``; it receives the same angle so the
                pair stays aligned.

        Returns:
            A ``(img, mask)`` tuple: both rotated by ``self.rotation`` when
            ``'rotation'`` was picked, otherwise the inputs as given.
        """
        self.choice = random.choices(self.trans, weights=(50, 50))[0]
        # Fresh angle per call; previously the angle was fixed for the
        # lifetime of the instance.
        self.rotation = self._sample_angle()

        if self.choice == 'rotation':
            # expand=False is passed through unchanged from the original code.
            rotated_img = T.functional.rotate(img=img, angle=self.rotation, expand=False)
            rotated_mask = T.functional.rotate(img=mask, angle=self.rotation, expand=False)
            return rotated_img, rotated_mask

        # 'identity' (or any unrecognised name) passes the pair through, so
        # callers can always unpack the result.
        return img, mask