• Angel 中的学习率Decay
    • 1. ConstantLearningRate
    • 2. StandardDecay
    • 3. CorrectionDecay
    • 4. WarmRestarts

    Angel 中的学习率Decay

    Angel参考TensorFlow实现了多种学习率Decay方案, 用户可以根据需要选用. 在描述具体Decay方案前, 先了解一下Angel中的Decay是怎样引入的, 在什么时候进行Decay.

    对于第一个问题, Decay是在GraphLearner类中引入的, 在初始化时有如下代码:

    1. val ssScheduler: StepSizeScheduler = StepSizeScheduler(SharedConf.getStepSizeScheduler, lr0)

    其中StepSizeScheduler是所有Decay的基类, 同名的object是所有Decay的工场, SharedConf.getStepSizeScheduler通过读取”ml.opt.decay.class.name”的值可获是指定的decay类型(默认是StandardDecay).

    第二个问题, Angel提供了两种方案:

    • 一个mini-batch一次Decay
    • 一个epoch一次Decay

    通过”ml.opt.decay.on.batch”参数进行控制, 当其为true时, 一个mini-batch一次Decay, 当其为flase(默认的方式)一个epoch一次Decay. 具体代码在GraphLearner类中的train方法与trainOneEpoch方法中.

    1. ConstantLearningRate

    这是最简单的Decay方式, 就是不Decay, 学习率在整个训练过程中不变

    配置样例:

    1. ml.opt.decay.class.name=ConstantLearningRate

    2. StandardDecay

    标准Decay方案, 公式如下:

    Angel中的学习率Decay - 图1

    配置样例:

    1. ml.opt.decay.class.name=StandardDecay
    2. ml.opt.decay.alpha=0.001

    3. CorrectionDecay

    修正Decay, 这种方案适合于Momentum, 它是专为Momentum设计的, 请不要用于Adam等其它优化器. 计算公式为:

    Angel中的学习率Decay - 图2

    第一部分就是StandardDecay, 它是正常的Decay, 延续二部分是修正项, 为Momentum设计, 它是运动量系数之和的倒数. 其中$\beta$必须与优化器中的momentum相等. 一般可设为0.9.

    这种Decay的使用有两个注意点:

    • 动量计算公式应为: velocity = momentum * velocity + gradient, Angel中Momentum的实南已是这种方式
    • 要求一个mini-batch一次Decay, 因为要与参数的update同步

    配置样例:

    1. ml.opt.decay.class.name=CorrectionDecay
    2. ml.opt.decay.alpha=0.001
    3. ml.opt.decay.beta=0.9

    4. WarmRestarts

    这是一种较为高级的Decay方案, 它是周期中Decay的代表. 标准计算公式如下:

    Angel中的学习率Decay - 图3}})

    对于标准计算公式, 我们做了如下改进.

    • Angel中的学习率Decay - 图4进行衰减
    • 遂步增大Angel中的学习率Decay - 图5

    配置样例:

    1. ml.opt.decay.class.name=WarmRestarts
    2. ml.opt.decay.alpha=0.001

    其中Angel中的学习率Decay - 图6通过ml.opt.decay.intervals设置, 具体如下:

    1. class WarmRestarts(var etaMax: Double, etaMin: Double, alpha: Double) extends StepSizeScheduler {
    2. var current: Double = 0
    3. var numRestart: Int = 0
    4. var interval: Int = SharedConf.get().getInt(MLConf.ML_OPT_DECAY_INTERVALS, 100)
    5. override def next(): Double = {
    6. current += 1
    7. val value = etaMin + 0.5 * (etaMax - etaMin) * (1 + math.cos(current / interval * math.Pi))
    8. if (current == interval) {
    9. current = 0
    10. interval *= 2
    11. numRestart += 1
    12. etaMax = etaMax / math.sqrt(1.0 + numRestart * alpha)
    13. }
    14. value
    15. }
    16. override def isIntervalBoundary: Boolean = {
    17. current == 0
    18. }
    19. }