#!/usr/bin/env python3
# -*- coding:utf-8 -*-
###
# File: /home/richard/projects/DeepOrchestration/models/deepor.py
# Project: /home/richard/projects/DeepOrchestration/models
# Created Date: Thursday, May 30th 2024, 12:01:19 am
# Author: Ruochi Zhang
# Email: zrc720@gmail.com
# -----
# Last Modified: Sat Jun 08 2024
# Modified By: Ruochi Zhang
# -----
# Copyright (c) 2024 Bodkin World Domination Enterprises
#
# MIT License
#
# Copyright (c) 2024 Ruochi Zhang
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
# of the Software, and to permit persons to whom the Software is furnished to do
# so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# -----
###

import torch
import torch.nn as nn
import torch.nn.functional as F
from .dgfs import DGFS
from .ata import ATA
from .rml import RML


class DeepOr(nn.Module):
    """Fuse a DGFS branch and an ATA branch into one sigmoid-scored output.

    Pipeline shared by ``forward`` and ``predict``:
      1. DGFS encodes ``x`` into ``latent_dgfs`` and ``sparse_weights``.
      2. ATA consumes ``x`` together with ``sparse_weights`` and returns a
         reconstruction (``reconstruct_repr``) and ``latent_ata``.
      3. RML merges ``latent_dgfs`` with ``latent_ata``.
      4. An MLP head (Linear -> ReLU -> Linear -> Sigmoid) maps the merged
         representation to a single value per sample.

    NOTE: the head ends in ``nn.Sigmoid``, so the first returned tensor —
    named ``logits`` throughout this module — is already squashed into
    (0, 1). Feed it to a probability-space loss such as ``nn.BCELoss``;
    ``nn.BCEWithLogitsLoss`` would apply a second sigmoid.

    Args:
        input_dim: Feature size of ``x``; forwarded to DGFS and ATA.
            (Assumes ``x`` is (batch, time_steps, input_dim), as in the
            module's ``__main__`` demo — TODO confirm for other callers.)
        time_steps: Forwarded to DGFS and ATA.
        latent_dim: Latent width passed to DGFS, ATA and RML. The head's
            first Linear expects the RML output to be ``2 * latent_dim`` wide.
        tranformer_layers: (sic) Depth used for BOTH the ATA encoder and
            decoder stacks. The misspelled name is part of the public API.
        decoder_nhead: Forwarded to ATA as ``decoder_nhead``.
        dim_feedforward: Forwarded to ATA as ``dim_feedforward``.
        setpoint, temperature, Kp, Ki, Kd, min_temp, max_temp: Forwarded
            unchanged to DGFS. The names suggest a PID-style temperature
            controller; see DGFS for their actual semantics.
        head_hidden: Hidden width of the MLP head.
    """

    def __init__(self,
                 input_dim=76,
                 time_steps=48,
                 latent_dim=128,
                 tranformer_layers=2,
                 decoder_nhead=2,
                 dim_feedforward=128,
                 setpoint=0.1,
                 temperature=1.0,
                 Kp=0.1,
                 Ki=0.001,
                 Kd=0.01,
                 min_temp=0.2,
                 max_temp=1.0,
                 head_hidden=64):
        super().__init__()

        # Registration order (dgfs, ata, rml, head) fixes parameter and
        # state_dict ordering, so it is kept stable.
        self.dgfs = DGFS(input_dim, time_steps, latent_dim,
                         temperature=temperature, setpoint=setpoint,
                         Kp=Kp, Ki=Ki, Kd=Kd,
                         min_temp=min_temp, max_temp=max_temp)

        # One depth value drives both the encoder and the decoder.
        depth = tranformer_layers
        self.ata = ATA(input_dim=input_dim,
                       time_steps=time_steps,
                       latent_dim=latent_dim,
                       decoder_nhead=decoder_nhead,
                       num_encoder_layers=depth,
                       num_decoder_layers=depth,
                       dim_feedforward=dim_feedforward)

        self.rml = RML(latent_dim)

        self.head = nn.Sequential(
            nn.Linear(2 * latent_dim, head_hidden),
            nn.ReLU(),
            nn.Linear(head_hidden, 1),
            nn.Sigmoid(),
        )

    def _decode_and_score(self, x, latent_dgfs, sparse_weights):
        """Run ATA -> RML -> head on top of an already-computed DGFS output.

        Returns:
            Tuple ``(logits, sparse_weights, reconstruct_repr, latent_ata,
            latent_dgfs)``. ``logits`` is the sigmoid output of the head.
        """
        reconstruct_repr, latent_ata = self.ata(x, sparse_weights)
        fused = self.rml(latent_dgfs, latent_ata)
        scores = self.head(fused)
        return scores, sparse_weights, reconstruct_repr, latent_ata, latent_dgfs

    def forward(self, x):
        """Score ``x`` using the DGFS module's regular forward pass.

        Returns:
            ``(logits, sparse_weights, reconstruct_repr, latent_ata,
            latent_dgfs)``. ``logits`` lies in (0, 1) because of the head's
            sigmoid.
        """
        latent_dgfs, sparse_weights = self.dgfs(x)
        return self._decode_and_score(x, latent_dgfs, sparse_weights)

    def predict(self, x, n_features=10):
        """Like ``forward``, but obtains DGFS outputs via ``DGFS.predict``.

        ``n_features`` is passed straight to ``self.dgfs.predict``. It is
        presumably the number of features to keep; see DGFS to confirm.
        This method does not switch to eval mode or disable autograd; the
        caller controls both.

        Returns:
            Same 5-tuple as ``forward``.
        """
        latent_dgfs, sparse_weights = self.dgfs.predict(x, n_features)
        return self._decode_and_score(x, latent_dgfs, sparse_weights)


if __name__ == "__main__":
    # Smoke test: push one random batch through a model built with the
    # default sizes and print the shape of each of the five outputs.
    batch_size, time_steps, input_dim, latent_dim = 32, 48, 76, 128

    # The input is drawn before the model is built, so the RNG sequence
    # (and therefore weight init) matches a plain run of this script.
    x = torch.randn(batch_size, time_steps, input_dim)
    model = DeepOr(input_dim=input_dim,
                   time_steps=time_steps,
                   latent_dim=latent_dim)

    # Output order: logits, sparse_weights, reconstruct_repr, latent_ata,
    # latent_dgfs.
    outputs = model(x)
    print(*(tensor.shape for tensor in outputs))
