from bisect import bisect_right

from torch.optim.lr_scheduler import _LRScheduler


class LRStepScheduler(_LRScheduler):
    """Piecewise-constant learning rate driven by an explicit step table.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        steps (list of (int, float)): ``(epoch, lr)`` pairs. From ``epoch``
            onward (until the next entry) every parameter group uses ``lr``.
            Entries must be sorted by epoch, because the lookup uses
            :func:`bisect.bisect_right`.
        last_epoch (int): The index of last epoch. Default: -1.

    Before the first listed epoch, or when ``steps`` is empty, each
    parameter group keeps its own initial learning rate.
    """

    def __init__(self, optimizer, steps, last_epoch=-1):
        self.lr_steps = steps
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        if not self.lr_steps:
            # Empty table: keep the initial rates instead of raising IndexError.
            return list(self.base_lrs)
        milestones = [epoch for epoch, _ in self.lr_steps]
        # Index of the last entry whose epoch is <= last_epoch; -1 if none.
        idx = bisect_right(milestones, self.last_epoch) - 1
        if idx < 0 or self.lr_steps[idx][0] > self.last_epoch:
            # No step has been reached yet.
            return list(self.base_lrs)
        lr = self.lr_steps[idx][1]
        return [lr for _ in self.base_lrs]


class PolyLR(_LRScheduler):
    """Sets the learning rate of each parameter group according to the poly
    learning rate policy::

        lr = base_lr * (1 - t / max_iter) ** power,   t = last_epoch % max_iter

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        max_iter (int): Length of one decay cycle, in scheduler steps.
            Must be > 0. Default: 90000.
        power (float): Exponent of the polynomial decay. Default: 0.9.
        last_epoch (int): The index of last epoch. Default: -1.

    The step counter wraps modulo ``max_iter``. After ``max_iter`` steps the
    schedule therefore restarts from ``base_lr`` instead of reaching zero.
    This also keeps the base of the power non-negative.
    """

    def __init__(self, optimizer, max_iter=90000, power=0.9, last_epoch=-1):
        self.max_iter = max_iter
        self.power = power
        super(PolyLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        # Do not modify last_epoch here: step() already advances it. Bumping
        # it again would move the schedule two iterations per step() call and
        # desynchronize the stored counter from the number of steps taken.
        t = self.last_epoch % self.max_iter
        factor = (1 - float(t) / self.max_iter) ** self.power
        return [base_lr * factor for base_lr in self.base_lrs]


class ExponentialLRScheduler(_LRScheduler):
    """Decays the learning rate of each parameter group by gamma every epoch.

    When last_epoch=-1, sets initial lr as lr, i.e.
    ``lr = base_lr * gamma ** last_epoch`` for ``last_epoch >= 0``.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        gamma (float): Multiplicative factor of learning rate decay.
        last_epoch (int): The index of last epoch. Default: -1.
    """

    def __init__(self, optimizer, gamma, last_epoch=-1):
        self.gamma = gamma
        super(ExponentialLRScheduler, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        if self.last_epoch <= 0:
            # Return a copy so callers cannot mutate the stored base rates.
            return list(self.base_lrs)
        return [base_lr * self.gamma ** self.last_epoch for base_lr in self.base_lrs]