#!/usr/bin/env python3
# -*- coding:utf-8 -*-
###
# File: /root/CAMP/scheduler.py
# Project: /home/richard/projects/syncorepeppi/utils
# Created Date: Sunday, August 14th 2022, 11:00:09 am
# Author: Ruochi Zhang
# Email: zrc720@gmail.com
# -----
# Last Modified: Tue Dec 20 2022
# Modified By: Ruochi Zhang
# -----
# Copyright (c) 2022 Bodkin World Domination Enterprises
#
# MIT License
#
# Copyright (c) 2022 Ruochi Zhang
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
# of the Software, and to permit persons to whom the Software is furnished to do
# so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# -----
###

from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, ReduceLROnPlateau, LambdaLR, CyclicLR
import torch

def lr_scheduler(cfg, optimizer):
    """Build the learning-rate scheduler selected by ``cfg.train.lr_scheduler.type``.

    Supported types and the config sub-sections they read:

    * ``"cosine"``  -> ``CosineAnnealingWarmRestarts`` built from
      ``cfg.train.lr_scheduler.cosine.{T_0, T_mult, eta_min}``.
    * ``"cycle"``   -> ``CyclicLR`` built from
      ``cfg.train.lr_scheduler.cycle.{base_lr, max_lr, step_size_up,
      step_size_down, max_lr_decrease_factor, gamma}``.
      The amplitude of cycle ``k`` (1-based) is scaled by
      ``max_lr_decrease_factor ** (k - 1)``. Momentum cycling is disabled.
    * ``"plateau"`` -> ``ReduceLROnPlateau`` built from
      ``cfg.train.lr_scheduler.plateau.{mode, factor, patience, min_lr}``.
      It uses a fixed relative threshold of 1e-4, no cooldown and eps=1e-8.

    Any other value yields a ``LambdaLR`` whose multiplier is always 1.
    That scheduler leaves the optimizer's learning rate unchanged.

    Args:
        cfg: Configuration object read via attribute access
            (``cfg.train.lr_scheduler...``).
        optimizer: The torch optimizer whose learning rate is scheduled.

    Returns:
        The constructed scheduler. Note that ``ReduceLROnPlateau.step`` expects
        a metric value, unlike the other schedulers returned here.
    """
    sched_cfg = cfg.train.lr_scheduler
    sched_type = sched_cfg.type

    # ``verbose`` is deliberately not passed to any scheduler. It has always
    # defaulted to False. Since PyTorch 2.2 any explicit value only triggers a
    # deprecation warning, and older releases lack the argument entirely.

    if sched_type == "cosine":
        cosine = sched_cfg.cosine
        return CosineAnnealingWarmRestarts(
            optimizer,
            T_0=cosine.T_0,
            T_mult=cosine.T_mult,
            eta_min=cosine.eta_min,
        )

    if sched_type == "cycle":
        cycle = sched_cfg.cycle
        decay = cycle.max_lr_decrease_factor

        def func(x):
            # With a custom scale_fn, CyclicLR evaluates it on the 1-based
            # cycle number (default scale_mode="cycle"). The first cycle
            # therefore reaches the full max_lr, and each later cycle's
            # amplitude shrinks by another factor of ``decay``.
            return decay ** (x - 1)

        return CyclicLR(
            optimizer=optimizer,
            base_lr=cycle.base_lr,
            max_lr=cycle.max_lr,
            step_size_up=cycle.step_size_up,
            step_size_down=cycle.step_size_down,
            scale_fn=func,
            cycle_momentum=False,
            # NOTE: gamma is only used by the built-in "exp_range" policy, which
            # a custom scale_fn overrides, so it has no effect here. It is still
            # read for config compatibility.
            gamma=cycle.gamma,
        )

    if sched_type == "plateau":
        plateau = sched_cfg.plateau
        return ReduceLROnPlateau(
            optimizer,
            mode=plateau.mode,
            factor=plateau.factor,
            patience=plateau.patience,
            threshold=0.0001,
            threshold_mode='rel',
            cooldown=0,
            min_lr=plateau.min_lr,
            eps=1e-08,
        )

    # Fallback: a constant multiplier of 1 keeps the LR fixed. It is created only
    # here because constructing any LRScheduler mutates the optimizer: it adds
    # ``initial_lr`` to the param groups and wraps ``optimizer.step``.
    return LambdaLR(optimizer, lr_lambda=lambda epoch: 1)