#!/usr/bin/env python3
# -*- coding:utf-8 -*-
###
# File: /home/richard/projects/DeepOrchestration/evaluator.py
# Project: /home/richard/projects/DeepOrchestration/utils
# Created Date: Wednesday, May 29th 2024, 4:59:38 pm
# Author: Ruochi Zhang
# Email: zrc720@gmail.com
# -----
# Last Modified: Thu May 30 2024
# Modified By: Ruochi Zhang
# -----
# Copyright (c) 2024 Bodkin World Domination Enterprises
#
# MIT License
#
# Copyright (c) 2024 Ruochi Zhang
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
# of the Software, and to permit persons to whom the Software is furnished to do
# so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# -----
###
import torch
import numpy as np
from collections import defaultdict
from tqdm import tqdm


class Evaluator:
    """Run a network over a test loader and compute metrics on its predictions.

    The forward pass and loss call depend on ``cfg.model.name``:

    * ``"baseline"`` -- ``net(X)`` returns the prediction logits and the
      criterion is called as ``criterion(pred_logit, y)``.
    * ``"deepor"`` -- ``net(X)`` returns a 5-tuple whose first element is the
      prediction logits; the criterion additionally receives the input, the
      reconstruction and the two latent representations.

    In both cases the criterion must return a mapping with a ``"total_loss"``
    tensor.
    """

    # Model names understood by ``_forward``.
    SUPPORTED_MODELS = ("baseline", "deepor")

    def __init__(self, net, test_loader, criterion, metric_func, device, cfg):
        """Store the evaluation dependencies.

        Args:
            net: Network to evaluate; must be callable on a batch and expose
                ``.eval()``.
            test_loader: Iterable yielding ``(X, y)`` tensor batches.
            criterion: Loss callable (signature depends on the model, see the
                class docstring); must return a mapping containing
                ``"total_loss"``.
            metric_func: Callable invoked as ``metric_func(y_pred, y_true)``
                with two NumPy arrays; its result is returned by ``run``.
            device: Target passed to ``Tensor.to`` for every batch.
            cfg: Config object; ``cfg.model.name`` selects the forward path.
        """
        self.net = net
        self.device = device
        self.criterion = criterion
        self.test_loader = test_loader
        self.cfg = cfg
        self.metrics_func = metric_func

    def _forward(self, X, y):
        """Run the model-specific forward pass and loss for one batch.

        Returns:
            Tuple ``(pred_logit, loss)`` where ``loss`` is the criterion's
            return value.

        Raises:
            ValueError: If ``cfg.model.name`` is not in ``SUPPORTED_MODELS``.
        """
        model_name = self.cfg.model.name
        if model_name == "baseline":
            pred_logit = self.net(X)
            loss = self.criterion(pred_logit, y)
        elif model_name == "deepor":
            # The selected features are produced by the network but are not
            # needed for evaluation.
            (pred_logit, _selected_features, reconstruct_repr, latent_ata,
             latent_dgfs) = self.net(X)
            loss = self.criterion(X, pred_logit, reconstruct_repr,
                                  latent_ata, latent_dgfs, y)
        else:
            # Without this branch ``pred_logit``/``loss`` would stay unbound
            # and the caller would fail later with an unhelpful NameError.
            raise ValueError(
                f"Unsupported model name {model_name!r}; "
                f"expected one of {self.SUPPORTED_MODELS}")
        return pred_logit, loss

    @staticmethod
    def _predictions_to_list(pred_logit):
        """Convert a batch of logits to a Python list with one entry per sample.

        A trailing singleton dimension is dropped, e.g. ``(B, 1) -> (B,)``.
        The result is kept at least 1-D. A 1-D tensor from a size-1 batch
        would otherwise squeeze to 0-d, whose ``tolist()`` is a bare float
        that ``list.extend`` cannot consume.
        """
        if pred_logit.dim() > 1:
            pred_logit = pred_logit.squeeze(-1)
        return np.atleast_1d(pred_logit.cpu().numpy()).tolist()

    def run(self):
        """Evaluate ``net`` on ``test_loader`` and return the metric result.

        The network is switched to eval mode (and left in it), and every
        forward pass runs under ``torch.no_grad()``. Targets and predictions
        from all batches are concatenated, then passed to the metric function
        as ``metric_func(np.array(y_pred), np.array(y_true))``.

        Raises:
            ValueError: If ``cfg.model.name`` is unsupported (raised on the
                first batch).
        """
        self.net.eval()

        y_true_list = []
        y_pred_list = []

        # NOTE(review): per-batch losses are collected but never read or
        # returned -- consider reporting their mean or dropping this.
        loss_dict = defaultdict(list)

        with torch.no_grad():
            # Iterate the loader directly (not through ``enumerate``) so tqdm
            # can read ``len(test_loader)`` and display a total.
            for X, y in tqdm(self.test_loader, desc="evaluating test "):
                X = X.to(self.device)
                y = y.to(self.device)

                pred_logit, loss = self._forward(X, y)
                loss_dict["loss"].append(loss["total_loss"].item())

                y_true_list.extend(y.cpu().numpy().tolist())
                y_pred_list.extend(self._predictions_to_list(pred_logit))

        return self.metrics_func(np.array(y_pred_list),
                                 np.array(y_true_list))
