import pandas as pd
import numpy as np

import torch.nn as nn
from matplotlib import pyplot as plt
import seaborn as sn
import pickle as pkl
import traceback
import os
import cv2

from sklearn.metrics import roc_auc_score
from sklearn.metrics import confusion_matrix
from sklearn.metrics import precision_recall_fscore_support
from sklearn.decomposition import PCA
from pathlib import Path
from matplotlib import pyplot as plt
import cv2
import io
import torch


def get_device(cfg):
    """Pick the torch device to run on.

    Returns ``cuda:<first id>`` when CUDA is available and
    ``cfg.train.device_ids`` is non-empty; otherwise returns the CPU device.
    ``cfg.train.device_ids`` is only read when CUDA is available.
    """
    if not torch.cuda.is_available():
        return torch.device("cpu")
    device_ids = cfg.train.device_ids
    if len(device_ids) == 0:
        return torch.device("cpu")
    return torch.device(f"cuda:{device_ids[0]}")


def is_parallel(model):
    """Return True if ``model`` is wrapped in (Distributed)DataParallel.

    Uses ``isinstance`` rather than an exact type comparison so subclasses of
    the wrappers, which also expose the wrapped network as ``.module``, are
    recognised too.
    """
    return isinstance(model, (nn.parallel.DataParallel,
                              nn.parallel.DistributedDataParallel))


def save_cm(array, save_name):
    """Render ``array`` as an annotated heatmap and save it to ``save_name``.

    Args:
        array: 2-D data accepted by ``pd.DataFrame`` (one cell per matrix entry).
        save_name: output image path passed to ``Figure.savefig`` (400 dpi).
    """
    df_cm = pd.DataFrame(array)

    fig = plt.figure(figsize=(10, 10))
    try:
        # Draw on this figure's own axes rather than whatever is "current".
        ax = fig.add_subplot(111)
        sn.heatmap(df_cm,
                   annot=True,
                   cmap='coolwarm',
                   linecolor='white',
                   linewidths=1,
                   ax=ax)
        fig.savefig(save_name, dpi=400)
    finally:
        # Without this every call leaves an open figure registered in pyplot,
        # leaking memory across repeated saves.
        plt.close(fig)


def save_np(array, save_name):
    """Save ``array`` to ``save_name``, choosing the writer from the file name.

    If the part of the base name before its first ``.`` ends with ``"cm"``
    (e.g. ``"val_cm.png"``), the array is rendered as an annotated heatmap via
    ``save_cm``. Otherwise it is written directly with ``cv2.imwrite``.
    """
    stem, _, _ = os.path.basename(save_name).partition(".")
    if not stem.endswith("cm"):
        cv2.imwrite(save_name, array)
        return
    save_cm(array, save_name)


def read_data(key, feature_path, label_path):
    """Load the feature table and labels stored under ``key``.

    Args:
        key: key of the table in the HDF5 file, and the key looked up in the
            unpickled label object.
        feature_path: HDF5 file read with ``pd.read_hdf``.
        label_path: pickle file whose top-level object is indexed by ``key``.

    Returns:
        ``(features, labels)``, where ``labels`` is the looked-up entry
        converted with ``np.array``.

    Raises:
        AssertionError: (when assertions are enabled) if the number of feature
            rows differs from the number of labels.
    """
    features = pd.read_hdf(feature_path, key=key)

    # NOTE(review): pickle.load can execute arbitrary code - only point this
    # at trusted label files.
    with open(label_path, "rb") as label_file:
        label_store = pkl.load(label_file)
    labels = label_store[key]

    assert features.shape[0] == len(labels)
    return features, np.array(labels)


def normalize(pvals):
    """Min-max scale ``pvals`` into the range [0, 1].

    Args:
        pvals: a sized sequence of numbers (list, 1-D array, ...).

    Returns:
        A list with one entry per element ``x`` of ``pvals``, equal to
        ``(x - min) / (max - min)``. An empty input yields ``[]``. When all
        values are equal the range is zero, so every entry is ``0.0`` instead
        of the NaN that a 0/0 division would produce.
    """
    if len(pvals) == 0:
        # np.max/np.min raise on empty input; there is nothing to scale.
        return []

    max_ = np.max(pvals)
    min_ = np.min(pvals)
    span = max_ - min_
    if span == 0:
        # Every (x - min_) is already zero; return it as float zeros
        # (preserving each element's shape) rather than dividing by zero.
        return [0.0 * (x - min_) for x in pvals]
    return [(x - min_) / span for x in pvals]


def get_img_from_fig(fig, dpi=180):
    """Rasterise a Matplotlib figure into an RGB image array.

    The figure is encoded as PNG in memory at ``dpi``. The PNG is decoded
    with OpenCV (flag ``1``: colour decode, which yields BGR channel order),
    and the result is converted from BGR to RGB.
    """
    with io.BytesIO() as png_buffer:
        fig.savefig(png_buffer, format="png", dpi=dpi)
        png_buffer.seek(0)
        encoded = np.frombuffer(png_buffer.getvalue(), dtype=np.uint8)
    bgr_image = cv2.imdecode(encoded, 1)
    return cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)


def visualize_emb(emb, targets):
    """Scatter-plot embeddings in 2-D, coloured by binary label.

    Embeddings with more than two columns are projected with a 2-component
    PCA; otherwise they are plotted as they are.

    Args:
        emb: ``np.ndarray`` or ``torch.Tensor`` with one row per sample
            (assumed ``(n_samples, dim)`` with ``dim >= 2`` - TODO confirm).
            Tensors may live on any device and may require grad.
        targets: iterable of per-sample labels aligned with the rows of
            ``emb``. ``0`` is drawn red as "Negative"; anything else is drawn
            blue as "Positive".

    Returns:
        The rendered figure as an image array, as produced by
        ``get_img_from_fig``.
    """
    if isinstance(emb, np.ndarray):
        features = emb
    else:
        # detach() so tensors that require grad can be converted too;
        # calling .numpy() on them directly raises.
        features = emb.detach().cpu().numpy()

    if features.shape[-1] > 2:
        components = PCA(n_components=2).fit_transform(features)
    else:
        components = features

    # Same per-sample test as `if target == 0`; bool() makes it work for
    # lists, NumPy arrays and CPU/CUDA tensors alike.
    is_negative = np.array([bool(t == 0) for t in targets], dtype=bool)
    # One label per row: only the first len(targets) rows are plotted.
    points = components[: len(is_negative)]
    negatives = points[is_negative]
    positives = points[~is_negative]

    fig = plt.figure(figsize=(8, 8))
    try:
        ax = fig.add_subplot(111)
        ax.scatter(negatives[:, 0], negatives[:, 1], color="red",
                   label="Negative")
        ax.scatter(positives[:, 0], positives[:, 1], color="blue",
                   label="Positive")
        # The scatter labels are only shown once a legend is drawn.
        ax.legend()
        return get_img_from_fig(fig)
    finally:
        # Unregister the figure from pyplot so repeated calls don't
        # accumulate open figures.
        plt.close(fig)


def load_weights(model, best_model_path, device):
    """Copy matching parameters from a checkpoint into ``model``.

    Only checkpoint entries whose name also exists in the model's state dict
    are applied. Each applied name is printed.

    Args:
        model: ``nn.Module`` to update. If it is wrapped in
            ``DataParallel``/``DistributedDataParallel``, the inner ``.module``
            is loaded and returned instead of the wrapper.
        best_model_path: ``str``/path-like pointing to the checkpoint file, or
            to a directory that contains ``data/model.pth``.
        device: device the model is moved to after loading. If ``None``, it
            falls back to the current CUDA device when CUDA is available,
            else the CPU.

    Returns:
        The (unwrapped) model with the checkpoint parameters loaded.
    """
    best_model_path = Path(best_model_path)
    if best_model_path.is_dir():
        best_model_path = best_model_path / "data/model.pth"

    if is_parallel(model):
        model = model.module

    model_dict = model.state_dict()

    # NOTE(review): torch.load unpickles arbitrary objects - only load trusted
    # checkpoints. Newer PyTorch releases may default to weights_only=True,
    # which rejects whole pickled models; confirm against the installed version.
    checkpoint = torch.load(best_model_path, map_location="cpu")
    # Accept both a pickled nn.Module and a plain state dict.
    if hasattr(checkpoint, "state_dict"):
        checkpoint = checkpoint.state_dict()

    # Strip only the leading "module." that (Distributed)DataParallel adds;
    # str.replace would also mangle names like "enc.submodule.weight".
    prefix = "module."
    best_state_dict = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in checkpoint.items()
    }
    for k, v in best_state_dict.items():
        if k in model_dict:
            print("updating parameters for: {}".format(k))
            model_dict[k] = v

    model.load_state_dict(model_dict)

    # Honour the requested device instead of always using the current CUDA
    # device (which crashes on CPU-only hosts).
    if device is None:
        device = (torch.cuda.current_device()
                  if torch.cuda.is_available() else "cpu")
    model.to(device)

    return model
