简介
pytorch的schedular,作用是:n个epoch没有精度提升后,调整学习率。
参数看文档去,注意这个schedular要在eval时step,这里主要记录一下“判断没有精度提升”的具体行为
可能的理解
我脑海里对“n个epoch没有精度提升”有两种理解
- 连续n个epoch发生:当前精度比前一个epoch精度差
- 在n个epoch中:最高精度出现在n个epoch之前
结论先行,理解2是对的
ReduceLROnPlateau源码
class ReduceLROnPlateau:
...
def step(self, metrics, epoch=None):
current = float(metrics) # 这里是每次step时传进来的当前epoch精度
if epoch is None: # 不要自己传epoch,schedular会自己记录,否则会被警告
epoch = self.last_epoch + 1
else:
warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)
self.last_epoch = epoch
if self.is_better(current, self.best): # 到这里就发现了,schedular记录了一个self.best,而不是self.prev
self.best = current
self.num_bad_epochs = 0
else:
self.num_bad_epochs += 1
if self.in_cooldown:
self.cooldown_counter -= 1
self.num_bad_epochs = 0 # ignore any bad epochs in cooldown
if self.num_bad_epochs > self.patience:
self._reduce_lr(epoch)
self.cooldown_counter = self.cooldown
self.num_bad_epochs = 0
self._last_lr = [group['lr'] for group in self.optimizer.param_groups]