#!/usr/bin/env python3
# -*- coding:utf-8 -*-
###
# File: /home/richard/projects/DeepOrchestration/main.py
# Project: /home/richard/projects/DeepOrchestration
# Created Date: Wednesday, May 29th 2024, 5:05:59 pm
# Author: Ruochi Zhang
# Email: zrc720@gmail.com
# -----
# Last Modified: Mon Jun 10 2024
# Modified By: Ruochi Zhang
# -----
# Copyright (c) 2024 Bodkin World Domination Enterprises
#
# MIT License
#
# Copyright (c) 2024 Ruochi Zhang
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
# of the Software, and to permit persons to whom the Software is furnished to do
# so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# -----
###
import os
import sys

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import hydra
from omegaconf import DictConfig
import torch
import gc
from DeepOrchestration.utils.componants import deepor_componants
from DeepOrchestration.utils.nni_utils import update_cfg
import mlflow


@hydra.main(config_path="config", config_name="train.yaml")
def main(cfg: DictConfig) -> None:
    """Train a DeepOrchestration model from a Hydra-composed config.

    All configuration overrides (NNI trial parameters, debug epoch cap) are
    applied *before* the training components are built and before the
    hyper-parameters are logged. This way the trainer and the MLflow record
    both see the configuration that is actually used for the run.

    Args:
        cfg: Configuration composed by Hydra from ``config/train.yaml`` plus
            any command-line overrides. This function reads ``other.debug``,
            ``mode.nni``, ``logger.mlflow`` and the ``data``,
            ``model.deepor`` and ``train`` sections.
    """
    # Hydra may run the job from its own output directory; keep the launch
    # directory so components can resolve paths relative to it.
    orig_cwd = hydra.utils.get_original_cwd()
    cfg.orig_cwd = orig_cwd

    # --- config overrides: must happen before components are constructed ---
    if cfg.mode.nni:
        # Merge the current NNI trial's hyper-parameters into the config.
        cfg = update_cfg(cfg)

    if cfg.other.debug:
        # Applied after NNI so a debug run is always capped at 2 epochs.
        cfg.train.num_epoch = 2

    net, my_trainer, evaluator, logger = deepor_componants(cfg, orig_cwd)

    if cfg.logger.mlflow:
        # Log the effective (post-override) hyper-parameters.
        for section in (cfg.data, cfg.model.deepor, cfg.train):
            for name, value in section.items():
                logger.log_param(name, value)

    # training
    logger.std_print("training start......")
    my_trainer.run()

    # Evaluation is currently disabled. Re-enabling it requires importing
    # `load_weights` and `get_device`, which this module does not import.
    # best_model_path = my_trainer.best_model_path
    #
    # if best_model_path:
    #     logger.std_print("best checkpoint {}......".format(best_model_path))
    #     evaluator.net = load_weights(net, best_model_path, get_device(cfg))
    #     logger.std_print("running evaluation......")
    #     metrics = evaluator.run()
    #     for key, value in metrics.items():
    #         logger.std_print("test_{}: {}".format(key, value))

    # Release cached GPU memory and collect garbage before returning.
    torch.cuda.empty_cache()
    gc.collect()
    print("Done")


if __name__ == "__main__":
    # Script entry point: the @hydra.main decorator builds `cfg` from the
    # config files and command-line overrides, so no arguments are passed.
    main()
