# deepfake/training/tools/schedulers.py
# Uploaded by thecho7 (revision c426e13).
from bisect import bisect_right
from torch.optim.lr_scheduler import _LRScheduler
class LRStepScheduler(_LRScheduler):
    """Piecewise-constant ("step") learning-rate schedule.

    ``steps`` is a sequence of ``(epoch, lr)`` pairs. Once ``last_epoch``
    reaches a step's ``epoch``, every parameter group uses that step's ``lr``
    (an absolute value, not a multiplier of the base LR). Before the first
    step is reached -- or when ``steps`` is empty -- each group keeps its
    base learning rate.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        steps: Iterable of ``(epoch, lr)`` pairs. They are sorted by epoch
            here, so callers need not pre-sort them.
        last_epoch (int): The index of last epoch. Default: -1.
    """

    def __init__(self, optimizer, steps, last_epoch=-1):
        # get_lr() bisects over the epochs, which requires ascending order;
        # a stable sort keeps the caller's order among duplicate epochs.
        self.lr_steps = sorted(steps, key=lambda step: step[0])
        # Set before super().__init__(): the base constructor performs an
        # initial step that calls get_lr().
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        epochs = [epoch for epoch, _ in self.lr_steps]
        # Number of steps whose epoch is <= last_epoch.
        reached = bisect_right(epochs, self.last_epoch)
        if reached == 0:
            # No step reached yet (or no steps at all): keep the base LRs.
            return list(self.base_lrs)
        lr = self.lr_steps[reached - 1][1]
        return [lr for _ in self.base_lrs]
class PolyLR(_LRScheduler):
    """Sets the learning rate of each parameter group according to the poly policy.

    ``lr = base_lr * (1 - t / max_iter) ** power`` with
    ``t = (last_epoch + 1) % max_iter``. Because of the modulo, the schedule
    restarts from ``base_lr`` every ``max_iter`` iterations.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        max_iter (int): Length of one decay cycle, in scheduler steps.
            Default: 90000.
        power (float): Exponent of the polynomial decay. Default: 0.9.
        last_epoch (int): The index of last epoch. Default: -1.
    """

    def __init__(self, optimizer, max_iter=90000, power=0.9, last_epoch=-1):
        self.max_iter = max_iter
        self.power = power
        super(PolyLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        # Do not assign to self.last_epoch here: the base-class step() has
        # already advanced it, and writing it again made every argument-less
        # step() move the schedule forward by two iterations. The
        # one-iteration lookahead (+1) is kept so that step(epoch) callers
        # get the same learning rates as before.
        t = (self.last_epoch + 1) % self.max_iter
        decay = (1 - float(t) / self.max_iter) ** self.power
        return [base_lr * decay for base_lr in self.base_lrs]
class ExponentialLRScheduler(_LRScheduler):
    """Exponentially decays each parameter group's learning rate.

    At epoch ``e > 0`` every group gets ``base_lr * gamma ** e``. For
    ``e <= 0`` -- including the initial step taken when ``last_epoch=-1`` --
    the base learning rates are returned unchanged.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        gamma (float): Multiplicative factor of learning rate decay.
        last_epoch (int): The index of last epoch. Default: -1.
    """

    def __init__(self, optimizer, gamma, last_epoch=-1):
        self.gamma = gamma
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        epoch = self.last_epoch
        if epoch > 0:
            # One shared decay factor for every parameter group.
            factor = self.gamma ** epoch
            return [lr * factor for lr in self.base_lrs]
        return self.base_lrs