import os
import math
import ctypes
import json
import matplotlib.pyplot as plt

from pymatgen.core import Structure, Composition, Molecule
from pymatgen.entries.computed_entries import ComputedStructureEntry
from pymatgen.io.vasp.sets import MPRelaxSet
from pymatgen.io.vasp.inputs import Poscar
from xml.etree.ElementTree import ParseError
from matplotlib.patches import FancyArrowPatch
from matplotlib.colors import LinearSegmentedColormap
import matplotlib.patheffects as pe
from pymatgen.io.vasp.outputs import Vasprun
from pymatgen.io.vasp.outputs import Oszicar

# Pressure unit conversions from kilobar (kB).
KB_TO_GIGAPASC = 0.1  # 1 kB = 0.1 GPa
KB_TO_MEGAPASC = 100  # 1 kB = 100 MPa

# Global matplotlib styling shared by every plot in this script.
plt.rcParams['font.family'] = 'Arial'
plt.rcParams['font.size'] = 12
# NOTE(review): these calls run at import time on pyplot's implicit "current"
# figure, so they only affect the first figure drawn afterwards.
# plot_variable_vs_pstress later overrides the y label, but the placeholder
# 'X-axis' label and the x limits stay on that figure. Confirm they are wanted.
plt.xlabel('X-axis', fontsize=14)
plt.ylabel('Y-axis', fontsize=14)
plt.yticks(fontsize=12)
plt.xlim(-5.51, 5.51)

def vr_to_dict(filename="vasprun.xml",
               output="results.json"):
    """Save a finished VASP run as a ``ComputedStructureEntry`` JSON file.

    This caches vasprun results in a dict format for faster future loading.
    The run is written only if all three checks below pass; otherwise nothing
    is written:

    * ``filename`` parsed without an XML ``ParseError``;
    * the run reports itself as converged;
    * no ionic step used exactly NELM electronic steps.

    When the checks pass, the final structure and final energy are written to
    ``output`` as ``ComputedStructureEntry.as_dict()``.

    Args:
        filename: Path of the vasprun.xml file to read.
        output: Path of the JSON file to write.

    Returns:
        bool: True if ``output`` was written, False if the run was skipped.
    """
    try:
        vasprun = Vasprun(filename)
    except ParseError:
        # Truncated or corrupt XML: treat the run as not completed.
        return False

    # Ionic relaxation did not converge.
    if not vasprun.converged:
        return False

    # Look NELM up in the INCAR first, then in the full parameter set stored in
    # vasprun.xml, so a run relying on the default NELM does not raise KeyError.
    max_e = vasprun.incar.get("NELM", vasprun.parameters.get("NELM"))
    # An ionic step that used every allowed electronic step hit the limit.
    if any(len(step["electronic_steps"]) == max_e
           for step in vasprun.ionic_steps):
        return False

    entry = ComputedStructureEntry(structure=vasprun.final_structure,
                                   energy=vasprun.final_energy)
    with open(output, "w", encoding="utf-8") as f:
        json.dump(entry.as_dict(), f)
    return True


def get_attribute_values(list_of_objects, attribute):
    """Collect ``attribute`` from every object, keeping the input order.

    Args:
        list_of_objects: Iterable of objects to read from.
        attribute: Name of the attribute to read from each object.

    Returns:
        list: One value per object, in iteration order.
    """
    collected = []
    for item in list_of_objects:
        collected.append(getattr(item, attribute))
    return collected


def plot_variable_vs_pstress(results, variable_name, y_axis, ticks=None):  # Style _1
    """Scatter-plot one attribute of each result against pressure, then show the figure.

    Args:
        results: Iterable of objects exposing ``pstress``, ``color`` and the
            attribute named by ``variable_name``.
        variable_name: Name of the attribute plotted on the y axis.
        y_axis: Y-axis label text.
        ticks: Optional x-tick positions, forwarded to ``plt.xticks``.
    """
    plt.xticks(ticks=ticks, fontsize=12)
    plt.ylabel(y_axis, fontsize=14, labelpad=8)
    for result in results:
        # pstress / 10 converts kB to GPa (1 kB = 0.1 GPa).
        plt.scatter(x=result.pstress / 10, y=getattr(result, variable_name), c=result.color, s=10, marker="s",
                    alpha=0.7)
    # Show once, after all points are drawn. Calling show() inside the loop
    # split the points across separate figures, and only the first figure
    # received the tick and label settings above.
    plt.show()


class Result:
    """Lattice and energy data for one relaxed run of a metal system.

    Constructor arguments are stored unchanged as attributes of the same name.
    ``get_c_true`` and ``get_color`` add the derived attributes ``c_true`` and
    ``color``. Each also returns the value it stored.
    """

    # Plot colour for each metal label handled by ``get_color``.
    _COLORS = {"Cobalt": "Blue", "Nickel": "Grey", "Manganese": "Purple"}

    def __init__(self, metal, pstress, a, b, c, beta, d_1, d_2, vol, i_oo4, step, energy):
        self.metal = metal
        self.pstress = pstress
        self.a = a
        self.b = b
        self.c = c
        self.beta = beta
        self.d_1 = d_1
        self.d_2 = d_2
        self.vol = vol
        self.i_oo4 = i_oo4
        self.step = step
        self.energy = energy

    def get_c_true(self):
        """Store and return ``3 * c * sin(beta)`` as ``self.c_true``.

        ``beta`` is in degrees. It is converted with the exact ``math.radians``
        instead of the former hard-coded ``beta / 57.295`` approximation.
        """
        # For a monoclinic cell (alpha = gamma = 90 deg), c * sin(beta) is the
        # height of c above the a-b plane. NOTE(review): the factor 3 is kept
        # from the original; its meaning (e.g. stacked layers) is unconfirmed.
        self.c_true = self.c * math.sin(math.radians(self.beta)) * 3
        return self.c_true

    def get_color(self):
        """Store and return the plot colour for ``self.metal`` as ``self.color``.

        An unrecognised metal gets ``None``, which lets matplotlib fall back
        to its default colour. Previously ``color`` was left unset and later
        attribute access failed.
        """
        self.color = self._COLORS.get(self.metal)
        return self.color


# Collect one Result per relaxed run of each half-lithiated metal system.
# Runs live in <head>/<metal_folder>/<step>/<step><pressure>/. A run is
# skipped when its CONTCAR or OSZICAR cannot be read.
co_res = []
mn_res = []
ni_res = []
head = "data/max_cells_response/"
# Folder name -> (metal label understood by Result.get_color, target list).
metal_targets = {
    "cobalt_half_lithiation": ("Cobalt", co_res),
    "manganese_half_lithiation": ("Manganese", mn_res),
    "nickel_half_lithiation": ("Nickel", ni_res),
}
for metal_folder, (st, metal_results) in metal_targets.items():
    head_dir = os.path.join(head, metal_folder)
    for step in ["A", "B", "C", "D", "E"]:
        for pstress in range(-50, 50, 2):
            # Directory names write negative pressures with an underscore
            # instead of a minus sign: -10 -> "A_10", 10 -> "A10".
            if pstress < 0:
                pstress_string = f"{step}_{abs(pstress)}"
            else:
                pstress_string = f"{step}{pstress}"
            run_dir = os.path.join(head_dir, step, pstress_string)
            # Best-effort: skip runs with missing or unreadable outputs.
            # ``except Exception`` (not a bare ``except``) so Ctrl-C and
            # SystemExit still stop the script. Reading the energy inside the
            # guard also skips an OSZICAR that holds no ionic steps.
            try:
                structure = Structure.from_file(filename=os.path.join(run_dir, "CONTCAR"))
                o = Oszicar(filename=os.path.join(run_dir, "OSZICAR"))
                energy = float(o.final_energy)
            except Exception:
                continue

            lattice = structure.lattice
            result = Result(metal=st, pstress=pstress,
                            a=lattice.a, b=lattice.b, c=lattice.c,
                            beta=lattice.beta, vol=lattice.volume,
                            energy=energy, step=step,
                            d_1=None, d_2=None, i_oo4=None)
            result.get_color()
            metal_results.append(result)


var = "energy"
# The y label must describe ``var``. The former label "c axis / Å" did not match
# the energies being plotted. OSZICAR energies from VASP are in eV.
plot_variable_vs_pstress(co_res[0:50], var, "Energy / eV")

import matplotlib.pyplot as plt

# Each row is [pressure, A, B, C, D, E]; columns A-E match the legend labels
# below (column A is 0 throughout, presumably the reference structure).
data = [
    [-50, 0, 0.097688, 0.014406, 0.038302, 0.330949],
    [-48, 0, 0.092327, 0.039124, 0.094874, 0.329381],
    [-46, 0, 0.100423, 0.095833, 0.118475, 0.330233],
    [-44, 0, 0.1211306, 0.079584, 0.17004, 0.338621]
]

# Transpose the rows into one list per column.
x, y1, y2, y3, y4, y5 = (list(column) for column in zip(*data))

# (y values, line colour, legend label) for each curve, in drawing order.
curves = (
    (y1, 'red', 'Undefected Structure (A)'),
    (y2, 'green', 'Tetrahedral Lithium site Energy (B)'),
    (y3, 'blue', 'Octahedral Lithium site Energy (C)'),
    (y4, 'orange', 'Lithium in dumbbell site energy (D)'),
    (y5, 'purple', 'Dumbbell structure (E)'),
)
for y_values, line_colour, curve_label in curves:
    plt.plot(x, y_values, color=line_colour, label=curve_label)

plt.xlabel('Pressures')
plt.ylabel('Curves')
plt.title('Data Plot')
plt.legend()
plt.show()