#!/usr/bin/env python3
# -*- coding:utf-8 -*-
###
# File: /root/CAMP/std_logger.py
# Project: /home/richard/projects/pep_interaction/logger
# Created Date: Saturday, July 30th 2022, 3:51:24 pm
# Author: Ruochi Zhang
# Email: zrc720@gmail.com
# -----
# Last Modified: Thu Feb 22 2024
# Modified By: Ruochi Zhang
# -----
# Copyright (c) 2022 Bodkin World Domination Enterprises
#
# MIT License
#
# Copyright (c) 2022 Ruochi Zhang
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
# of the Software, and to permit persons to whom the Software is furnished to do
# so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# -----
###

import logging
import os
import sys
from pathlib import Path

import numpy as np
import torch

from ..utils.utils import is_parallel


class StdLogger():
    """Plain-text logger writing timestamped messages to a file and/or stderr.

    Every instance uses the same ``logging.Logger`` (obtained with
    ``logging.getLogger(__file__)``), so handlers attached by one instance
    are shared by all of them. To avoid each message being emitted several
    times when more than one ``StdLogger`` is created with the same
    destination, a handler is only added when an equivalent one is not
    already attached to that logger.
    """

    def __init__(self, file_path="", stream=False, level=logging.INFO):
        """Configure the shared logger.

        Args:
            file_path: Log file to write to (``logging.FileHandler`` opens it
                in append mode). A falsy value (the default ``""``) disables
                file logging.
            stream: If true, also log to ``sys.stderr``.
            level: Level set on the shared logger; it applies to every
                instance, not just this one.
        """
        formatter = logging.Formatter(fmt='%(asctime)s %(message)s',
                                      datefmt="%H:%M:%S")
        self.logger = logging.getLogger(__file__)
        self.logger.setLevel(level)
        self.file_path = file_path

        if file_path:
            # FileHandler stores os.path.abspath(filename) as baseFilename;
            # compare against the same normalisation to detect a duplicate.
            target = os.path.abspath(os.fspath(file_path))
            already_attached = any(
                isinstance(h, logging.FileHandler) and h.baseFilename == target
                for h in self.logger.handlers)
            if not already_attached:
                file_handler = logging.FileHandler(file_path)
                file_handler.setFormatter(formatter)
                self.logger.addHandler(file_handler)

        if stream:
            # FileHandler subclasses StreamHandler, so exclude it explicitly.
            already_attached = any(
                isinstance(h, logging.StreamHandler)
                and not isinstance(h, logging.FileHandler)
                and getattr(h, "stream", None) is sys.stderr
                for h in self.logger.handlers)
            if not already_attached:
                stream_handler = logging.StreamHandler(stream=sys.stderr)
                stream_handler.setFormatter(formatter)
                self.logger.addHandler(stream_handler)

    def log_param(self, k, v):
        """Log a single parameter as ``"<k>: <v>"`` at INFO level."""
        self.logger.info("{}: {}".format(k, v))

    def log_params(self, para_dict):
        """Log every ``key: value`` pair of ``para_dict``, one line each.

        Pairs are logged in the dict's iteration order.
        """
        for p, v in para_dict.items():
            self.log_param(p, v)

    def log_metric(self, key, value, step):
        """Log a metric value for a training step.

        Args:
            key: Metric name.
            value: Numeric values (Python ``int``/``float`` and NumPy
                integer/floating scalars) are converted to ``float`` and
                rendered with ``{:.4}``, i.e. 4 significant digits; this may
                switch to exponent notation for large magnitudes. Strings are
                logged verbatim on the line after the header. Values of any
                other type are ignored.
            step: Step identifier included in the message.
        """
        # np.floating / np.integer cover every NumPy scalar width (float16,
        # int8, uint64, ...), not just the 32/64-bit types.
        if isinstance(value, (float, int, np.floating, np.integer)):
            self.logger.info("step: {} | {}: {:.4}".format(
                step, key, float(value)))
        elif isinstance(value, str):
            self.logger.info("step: {} | {}: \n{}".format(step, key, value))

    def save_model(self, model, save_path: Path, code_path=None):
        """Serialize ``model`` to ``save_path`` with ``torch.save``.

        When ``is_parallel(model)`` is true, the wrapped ``model.module`` is
        saved instead of the wrapper object.

        Args:
            model: Model object to save (the whole object is pickled, not
                just a state dict).
            save_path: Destination path accepted by ``torch.save``.
            code_path: Currently unused; kept for interface compatibility.
                The default is ``None`` rather than a shared mutable list.
        """
        target = model.module if is_parallel(model) else model
        torch.save(target, save_path)

    def std_print(self, str_):
        """Log ``str_`` at INFO level.

        Despite the name, this does not write to stdout directly. The message
        goes through the logger's handlers (the log file and/or
        ``sys.stderr``, as configured in ``__init__``).
        """
        self.logger.info(str_)


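# Reference snippet (not executed): loads the training config, creates a
# timestamp-named log directory and builds the "train.log" path, presumably
# for StdLogger(file_path=...). Running it would need `os`, `time.time` and
# `omegaconf.OmegaConf`, none of which this module imports.
# cfg = OmegaConf.load(
#     os.path.join(os.path.dirname(os.path.abspath(__file__)), 'train_conf.yaml'))
# log_dir = os.path.join(cfg.logger.log_dir, str(int(time())))
# os.makedirs(log_dir)
# file_path = os.path.join(log_dir, "train.log")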
