简介

pytorch的schedular,作用是:n个epoch没有精度提升后,调整学习率。

参数看文档去,注意这个schedular要在eval时step,这里主要记录一下“判断没有精度提升”的具体行为

可能的理解

我脑海里对“n个epoch没有精度提升”有两种理解

  1. 连续n个epoch发生:当前精度比前一个epoch精度差
  2. 在n个epoch中:最高精度出现在n个epoch之前

结论先行,理解2是对的

ReduceLROnPlateau源码

class ReduceLROnPlateau:
    ...
    def step(self, metrics, epoch=None):
        current = float(metrics)  # 这里是每次step时传进来的当前epoch精度
        if epoch is None:  # 不要自己传epoch,schedular会自己记录,否则会被警告
            epoch = self.last_epoch + 1
        else:
            warnings.warn(EPOCH_DEPRECATION_WARNING, UserWarning)
        self.last_epoch = epoch

        if self.is_better(current, self.best):  # 到这里就发现了,schedular记录了一个self.best,而不是self.prev
            self.best = current
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1

        if self.in_cooldown:
            self.cooldown_counter -= 1
            self.num_bad_epochs = 0  # ignore any bad epochs in cooldown

        if self.num_bad_epochs > self.patience:
            self._reduce_lr(epoch)
            self.cooldown_counter = self.cooldown
            self.num_bad_epochs = 0

        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
如果觉得我的文章对你有用,请随意赞赏