#!/usr/bin/env python3
# -*- coding:utf-8 -*-
###
# File: /home/richard/projects/syncorepeppi/logger/mlflow_logger.py
# Project: /home/richard/projects/DeepOrchestration/logger
# Created Date: Saturday, May 6th 2023, 4:20:58 pm
# Author: Ruochi Zhang
# Email: zrc720@gmail.com
# -----
# Last Modified: Wed May 29 2024
# Modified By: Ruochi Zhang
# -----
# Copyright (c) 2023 Bodkin World Domination Enterprises
#
# MIT License
#
# Copyright (c) 2023 Ruochi Zhang
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
# of the Software, and to permit persons to whom the Software is furnished to do
# so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# -----
###
import os
import logging
import shutil

import mlflow
import numpy as np
from ..utils.utils import is_parallel
from .std_logger import StdLogger
from pathlib import Path


class MlflowLog():
    """Logger that mirrors params, metrics and models to MLflow and a StdLogger.

    Every MLflow call made by the ``log_*`` methods is best-effort: an
    exception raised by MLflow is printed through the standard logger
    instead of propagating. The standard logger is then called regardless
    of whether the MLflow call succeeded.
    """

    def __init__(self) -> None:
        """Point MLflow at the configured tracking server and experiment.

        Reads ``MLFLOW_TRACKING_URI`` and ``MLFLOW_EXPERIMENT_NAME`` from the
        environment and creates an INFO-level ``StdLogger``.

        Raises:
            KeyError: if either environment variable is not set.
        """
        mlflow.set_tracking_uri(os.environ["MLFLOW_TRACKING_URI"])
        mlflow.set_experiment(os.environ["MLFLOW_EXPERIMENT_NAME"])

        self.std_logger = StdLogger("", level=logging.INFO)

    @staticmethod
    def _is_scalar_metric(value) -> bool:
        """Return True if ``value`` is a numeric scalar accepted by ``log_metric``.

        Covers Python ``int``/``float`` (and therefore ``bool``) plus every
        NumPy integer and floating width, not only the 32/64-bit ones.
        """
        return isinstance(value, (int, float, np.integer, np.floating))

    def log_param(self, k, v):
        """Log a single parameter ``k = v`` to MLflow and the standard logger."""
        try:
            mlflow.log_param(k, v)
        except Exception as e:  # best-effort: MLflow failures must not abort the run
            self.std_logger.std_print(e)

        self.std_logger.log_param(k, v)

    def log_params(self, para_dict):
        """Log every ``key: value`` pair of ``para_dict``.

        Parameters are sent to MLflow one at a time, each in its own
        ``try``, so that one rejected value does not stop the remaining ones
        from being recorded. The whole dict is then passed to the standard
        logger once.
        """
        for p, v in para_dict.items():
            try:
                mlflow.log_param(p, v)
            except Exception as e:
                self.std_logger.std_print(e)

        self.std_logger.log_params(para_dict)

    def log_metric(self, key, value, step):
        """Log ``value`` under ``key`` at ``step``.

        * Numeric scalars (Python or NumPy ints/floats) are sent to
          ``mlflow.log_metric``.
        * Strings are stored with ``mlflow.log_text`` as a text artifact
          named ``key``.

        Both kinds are also forwarded to the standard logger. Values of any
        other type are not logged; a notice is printed instead of silently
        dropping them.
        """
        if self._is_scalar_metric(value):
            try:
                mlflow.log_metric(key, value, step=step)
            except Exception as e:
                self.std_logger.std_print(e)

            self.std_logger.log_metric(key, value, step)

        elif isinstance(value, str):
            try:
                # ``key`` is used as the artifact file name; ``step`` is not
                # passed to MLflow for text values.
                mlflow.log_text(value, key)
            except Exception as e:
                self.std_logger.std_print(e)

            self.std_logger.log_metric(key, value, step)

        else:
            self.std_logger.std_print(
                f"Skipping metric {key!r}: unsupported value type "
                f"{type(value).__name__}")

    def save_model(self, model, save_path: Path,
                   code_path: "list | None" = None):
        """Save ``model`` with ``mlflow.pytorch.save_model`` at ``save_path``.

        Anything already at ``save_path`` is removed first, presumably
        because MLflow will not save into an existing non-empty path
        (confirm against the MLflow version in use). If ``is_parallel(model)``
        is true, the wrapped ``model.module`` is saved instead of the wrapper.

        Args:
            model: model to save; possibly a parallel wrapper. Which wrappers
                are detected depends on ``is_parallel``.
            save_path: target directory. A ``str`` is accepted and converted
                to ``Path``.
            code_path: extra code paths forwarded to MLflow as
                ``code_paths``. ``None`` (the default) means none.
        """
        save_path = Path(save_path)
        if save_path.is_dir():
            shutil.rmtree(save_path)
        elif save_path.exists():
            # rmtree raises NotADirectoryError on a plain file.
            save_path.unlink()

        # Fresh list per call, which avoids the shared mutable-default pitfall.
        code_paths = [] if code_path is None else code_path
        target = model.module if is_parallel(model) else model

        mlflow.pytorch.save_model(target, save_path, code_paths=code_paths)

    def std_print(self, str_):
        """Forward ``str_`` to the underlying standard logger's ``std_print``."""
        self.std_logger.std_print(str_)
