
Source code for gluoncv.utils.lr_scheduler

"""Popular Learning Rate Schedulers"""
# pylint: disable=missing-docstring
from __future__ import division

import warnings
from math import pi, cos
from mxnet import lr_scheduler

class LRSequential(lr_scheduler.LRScheduler):
    r"""Compose Learning Rate Schedulers

    Parameters
    ----------
    schedulers: list
        list of LRScheduler objects
    """
    def __init__(self, schedulers):
        super(LRSequential, self).__init__()
        assert(len(schedulers) > 0)

        self.update_sep = []
        self.count = 0
        self.learning_rate = 0
        self.schedulers = []
        for lr in schedulers:
            self.add(lr)

    def add(self, scheduler):
        assert(isinstance(scheduler, LRScheduler))

        scheduler.offset = self.count
        self.count += scheduler.niters
        self.update_sep.append(self.count)
        self.schedulers.append(scheduler)

    def __call__(self, num_update):
        self.update(num_update)
        return self.learning_rate

    def update(self, num_update):
        num_update = min(num_update, self.count - 1)

        ind = len(self.schedulers) - 1
        for i, sep in enumerate(self.update_sep):
            if sep > num_update:
                ind = i
                break

        lr = self.schedulers[ind]
        lr.update(num_update)
        self.learning_rate = lr.learning_rate
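As a usage sketch (not part of the library source; the iteration counts are illustrative): LRSequential chains schedulers back to back, since add() gives each child an offset equal to the iterations consumed by the schedulers before it, and update() dispatches to whichever segment the current iteration falls in. A common pattern is a linear warmup followed by cosine decay:

    from gluoncv.utils import LRSequential, LRScheduler

    # 500-iteration linear warmup from 0 up to 0.1, then cosine decay
    # back down to 0 over the remaining 9500 iterations.
    warmup = LRScheduler('linear', base_lr=0, target_lr=0.1, niters=500)
    decay = LRScheduler('cosine', base_lr=0.1, target_lr=0, niters=9500)
    schedule = LRSequential([warmup, decay])

    print(schedule(0))      # 0.0   (start of warmup)
    print(schedule(250))    # ~0.05 (halfway through warmup)
    print(schedule(500))    # 0.1   (cosine decay begins)
    print(schedule(10000))  # 0.0   (past the end: clamped to the final value)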
class LRScheduler(lr_scheduler.LRScheduler):
    r"""Learning Rate Scheduler

    Parameters
    ----------
    mode : str
        Modes for learning rate scheduler.
        Currently it supports 'constant', 'step', 'linear', 'poly' and 'cosine'.
    base_lr : float
        Base learning rate, i.e. the starting learning rate.
    target_lr : float
        Target learning rate, i.e. the ending learning rate.
        With constant mode target_lr is ignored.
    niters : int
        Number of iterations to be scheduled.
    nepochs : int
        Number of epochs to be scheduled.
    iters_per_epoch : int
        Number of iterations in each epoch.
    offset : int
        Number of iterations before this scheduler.
    power : float
        Power parameter of poly scheduler.
    step_iter : list
        A list of iterations to decay the learning rate.
    step_epoch : list
        A list of epochs to decay the learning rate.
    step_factor : float
        Learning rate decay factor.
    """
    def __init__(self, mode, base_lr=0.1, target_lr=0,
                 niters=0, nepochs=0, iters_per_epoch=0, offset=0,
                 power=2, step_iter=None, step_epoch=None, step_factor=0.1,
                 baselr=None, targetlr=None):
        super(LRScheduler, self).__init__()
        assert(mode in ['constant', 'step', 'linear', 'poly', 'cosine'])

        self.mode = mode
        if mode == 'step':
            assert(step_iter is not None or step_epoch is not None)
        if baselr is not None:
            warnings.warn("baselr is deprecated. Please use base_lr.")
            if base_lr == 0.1:
                base_lr = baselr
        self.base_lr = base_lr
        if targetlr is not None:
            warnings.warn("targetlr is deprecated. Please use target_lr.")
            if target_lr == 0:
                target_lr = targetlr
        self.target_lr = target_lr
        if self.mode == 'constant':
            self.target_lr = self.base_lr

        self.niters = niters
        self.step = step_iter
        epoch_iters = nepochs * iters_per_epoch
        if epoch_iters > 0:
            self.niters = epoch_iters
            if step_epoch is not None:
                self.step = [s * iters_per_epoch for s in step_epoch]

        self.offset = offset
        self.power = power
        self.step_factor = step_factor

    def __call__(self, num_update):
        self.update(num_update)
        return self.learning_rate

    def update(self, num_update):
        N = self.niters - 1
        T = num_update - self.offset
        T = min(max(0, T), N)

        if self.mode == 'constant':
            factor = 0
        elif self.mode == 'linear':
            factor = 1 - T / N
        elif self.mode == 'poly':
            factor = pow(1 - T / N, self.power)
        elif self.mode == 'cosine':
            factor = (1 + cos(pi * T / N)) / 2
        elif self.mode == 'step':
            if self.step is not None:
                count = sum([1 for s in self.step if s <= T])
                factor = pow(self.step_factor, count)
            else:
                factor = 1
        else:
            raise NotImplementedError

        if self.mode == 'step':
            self.learning_rate = self.base_lr * factor
        else:
            self.learning_rate = self.target_lr + \
                (self.base_lr - self.target_lr) * factor
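And a step-decay sketch (again not part of the library source; the epoch counts are illustrative): when nepochs and iters_per_epoch are given, the schedule is converted to iterations internally, and each step_epoch entry is scaled by iters_per_epoch. The resulting object can then be handed to a Gluon Trainer as the lr_scheduler optimizer parameter.

    from gluoncv.utils import LRScheduler

    # 90 epochs at 100 iterations each; the learning rate starts at 0.1
    # and is multiplied by step_factor=0.1 after epochs 30 and 60.
    schedule = LRScheduler('step', base_lr=0.1, nepochs=90, iters_per_epoch=100,
                           step_epoch=[30, 60], step_factor=0.1)

    print(schedule(0))     # 0.1    (before the first decay)
    print(schedule(3000))  # ~0.01  (after epoch 30: one decay applied)
    print(schedule(6000))  # ~0.001 (after epoch 60: two decays applied)

    # Typically passed to the optimizer, e.g.:
    # trainer = mxnet.gluon.Trainer(net.collect_params(), 'sgd',
    #                               {'learning_rate': 0.1,
    #                                'lr_scheduler': schedule})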