from collections import defaultdict
from pathlib import Path

import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from pymatgen.core import Structure
from pymatgen.io.vasp import Oszicar


# Chemistries, space groups and computational setups to loop over; each string is
# also a directory name under data/ (see the paths built below).
metals = ["cobalt", "nickel", "manganese"]
space_groups = ["fd", "o2", "o3", "p2"]
methods = ["noU", "no_IVDW", "VDW_1", "VDW_2"]
# Parallel lists: "<folder_name>_lithiation" directory prefix and the Li fraction x it holds.
folder_names = ["no", "quarter", "third", "half", "twothird", "threequarter", "full"]
concs = [0, 0.25, 1/3, 0.5, 2/3, 0.75, 1.0]

# Global plot styling only. Axis labels, tick sizes and limits are set per figure in the
# plotting loop, so no figure is created here at import time.
plt.rcParams['font.family'] = 'Arial'
plt.rcParams['font.size'] = 12


def draw_rectangles(width=0.04, fill=True):
    """Shade a translucent green vertical band at each reference Li concentration.

    Each band is ``width`` wide and centred on its concentration. It spans y = -0.35 to 0.65
    and is drawn on the current axes at zorder 0, i.e. behind the scatter points.
    ``fill`` selects a filled band (True) or an outline only (False).
    """
    axes = plt.gca()
    for centre in (0.0, 0.25, 0.3333, 0.5, 0.6666, 0.75, 1.0):
        lower_left = (centre - width / 2, -0.35)
        band = Rectangle(lower_left, width, 1, fill=fill, color="g", alpha=0.2, zorder=0)
        axes.add_patch(band)


def scrape_list(obj_list, vals=5):
    """Keep only the ``vals`` lowest-``norm_energy`` objects for each ``conc_number``.

    Objects are grouped by their exact ``conc_number``. Within each group they are ranked by
    ``norm_energy`` and at most ``vals`` are kept; ties are broken by original list order.

    Args:
        obj_list: objects exposing ``conc_number`` and ``norm_energy`` attributes.
        vals: maximum number of objects to keep per ``conc_number`` group.

    Returns:
        A new list holding the surviving objects in their original ``obj_list`` order.
    """
    groups = defaultdict(list)
    for obj in obj_list:
        groups[obj.conc_number].append(obj)

    # Select by object identity rather than by energy value, so duplicated energies cannot
    # let more than ``vals`` objects through for a group.
    keep = set()
    for members in groups.values():
        lowest = sorted(members, key=lambda o: o.norm_energy)[:vals]
        keep.update(id(o) for o in lowest)

    return [obj for obj in obj_list if id(obj) in keep]


# One calculation (metal / space group / method / Li concentration) and its energy relative to the expected reference energy.
class Data_point:
    """
    A single calculation (one metal, space group, method and Li concentration) together
    with its energy measured against an expected/reference energy.

    Derived attributes set during construction:
        norm_energy -- energy - expected_energy
        conc_number -- the supplied Li fraction, nudged by tiny_shift() so that points from
                       different metals / methods / space groups do not share an x value.
    """

    def __init__(self, energy, expected_energy, conc_number, concentration, metal, system, method, index, size=25, color="black", shape="s"):
        # Raw calculation data
        self.energy, self.expected_energy = energy, expected_energy
        self.conc_number, self.concentration = conc_number, concentration
        self.metal, self.system, self.method = metal, system, method
        self.index = index
        # Marker styling used when plotting
        self.size, self.color, self.shape = size, color, shape

        # Derived values (order matters: tiny_shift mutates conc_number last)
        self.get_energy_norm()
        self.tiny_shift()

    def get_energy_norm(self):
        """Store the energy relative to the expected energy in ``self.norm_energy``."""
        self.norm_energy = self.energy - self.expected_energy

    def tiny_shift(self):
        """Nudge ``conc_number`` by a label-dependent tiny amount.

        Distinct method / metal / space-group combinations therefore never share an x value,
        so they are not grouped (and scraped) together. Labels not listed leave
        ``conc_number`` untouched.
        """
        offset_tables = (
            (1E-07, {"noU": -2, "no_IVDW": -1, "VDW_1": 0, "VDW_2": 1}, self.method),
            (1E-06, {"cobalt": -1, "manganese": 1}, self.metal),
            (1E-05, {"fd": -2, "o2": -1, "o3": 0, "p2": 1}, self.system),
        )
        for scale, steps, label in offset_tables:
            if label in steps:
                self.conc_number += scale * steps[label]

    def __repr__(self):
        return f"{self.metal} {self.method}: Li conc:{self.concentration}  E:{self.energy}, idx: {self.index}"

# Reference (end-member) energies for every metal / space group / method:
#   zero_li[metal][space_group][method] -> per-unit energy of the no_lithiation calculation
#   deltas[metal][space_group][method]  -> per-unit full_lithiation energy minus the above,
#                                          rounded to 5 decimals
# Combinations whose files cannot be read are reported and left out of both dicts.
deltas = {metal: {space_group: {} for space_group in space_groups} for metal in metals}
zero_li = {metal: {space_group: {} for space_group in space_groups} for metal in metals}


def _end_member_energy(calc_dir, atoms_per_unit):
    """Return the final OSZICAR energy in ``calc_dir`` divided by (CONTCAR atoms / atoms_per_unit)."""
    structure = Structure.from_file(filename=f"{calc_dir}/CONTCAR")
    oszicar = Oszicar(filename=f"{calc_dir}/OSZICAR")
    return oszicar.final_energy / (len(structure) / atoms_per_unit)


for metal in metals:
    for space_group in space_groups:
        for method in methods:
            base_dir = f"data/{metal}/{space_group}"
            try:
                # Divisors presumably correspond to LiMO2 (4 atoms) and MO2 (3 atoms) units.
                full_e = _end_member_energy(f"{base_dir}/full_lithiation/{method}", 4)
                none_e = _end_member_energy(f"{base_dir}/no_lithiation/{method}", 3)
            except Exception as err:
                # Best-effort as before, but report the skip. Unlike a bare ``except:``,
                # this no longer swallows KeyboardInterrupt/SystemExit.
                print(f"Skipping reference {metal} {space_group} {method}: {err}")
                continue
            deltas[metal][space_group][method] = round(full_e - none_e, 5)
            zero_li[metal][space_group][method] = none_e

# Marker colour and shape per space group; any other system (i.e. "fd") is a purple diamond.
_SYSTEM_STYLES = {"o2": ("green", "o"), "o3": ("blue", "^"), "p2": ("red", "s")}
_DEFAULT_STYLE = ("purple", "D")


def _energy_per_formula_unit(calc_dir):
    """Return the final OSZICAR energy in ``calc_dir`` per formula unit.

    The unit count is (number of non-Li sites in ``calc_dir/POSCAR``) / 3. Li is excluded and
    the remaining sites are presumably grouped as MO2 (metal + 2 O). Any error raised while
    reading POSCAR/OSZICAR propagates to the caller.
    """
    structure = Structure.from_file(filename=f"{calc_dir}/POSCAR")
    host_sites = [site for site in structure if site.species_string != "Li"]
    oszicar = Oszicar(filename=f"{calc_dir}/OSZICAR")
    return oszicar.final_energy / (len(host_sites) / 3)


data_points = []
for metal in metals:
    for method in methods:
        # Expected energy is a straight line between the o3 end members:
        # E(x) = E_o3(x=0) + x * delta_o3. A missing o3 reference raises KeyError here.
        o3_zero = zero_li[metal]["o3"][method]
        o3_delta = deltas[metal]["o3"][method]
        for space_group in space_groups:
            color, shape = _SYSTEM_STYLES.get(space_group, _DEFAULT_STYLE)
            for conc_number, concentration in zip(concs, folder_names):
                expected_energy = o3_zero + conc_number * o3_delta
                method_dir = f"data/{metal}/{space_group}/{concentration}_lithiation/{method}"

                if concentration in ("full", "no"):
                    # End members hold one calculation directly in method_dir; read
                    # failures here are not caught and stop the script.
                    energy = _energy_per_formula_unit(method_dir)
                    data_points.append(Data_point(energy=energy, expected_energy=expected_energy,
                                                  concentration=concentration, conc_number=conc_number,
                                                  metal=metal, system=space_group, method=method,
                                                  index="n/a", color=color, shape=shape))
                    continue

                # Intermediate concentrations: probe numbered sub-directories 0..9 and
                # report (then skip) any that are absent or unreadable.
                for index in range(10):
                    try:
                        energy = _energy_per_formula_unit(f"{method_dir}/{index}")
                    except Exception as err:
                        print(f"ERROR:{metal} {space_group} {concentration} {method} {index}: {err}")
                        continue
                    data_points.append(Data_point(energy=energy, expected_energy=expected_energy,
                                                  concentration=concentration, conc_number=conc_number,
                                                  metal=metal, system=space_group, method=method,
                                                  index=index, color=color, shape=shape))


dp_2 = [x for x in data_points if -0.3 < x.norm_energy < 0.3]  # remove bugged/crappy results
dp_2 = scrape_list(dp_2, vals=3)

# x-axis label per metal. Raw strings: "\m" is not a valid escape sequence and would
# trigger a SyntaxWarning (DeprecationWarning before Python 3.12) in a plain literal.
x_labels = {
    "cobalt": r"x in $\mathrm{Li_{x}CoO_{2}}$",
    "nickel": r"x in $\mathrm{Li_{x}NiO_{2}}$",
    "manganese": r"x in $\mathrm{Li_{x}MnO_{2}}$",
}

# Horizontal offset per space group so the four systems at one concentration do not overlap.
shift_var = 0.0125
system_shifts = {"fd": shift_var, "o2": shift_var / 2, "o3": -shift_var, "p2": -shift_var / 2}

# savefig does not create missing directories; fail-safe before any plotting happens.
Path("figures_lq").mkdir(exist_ok=True)

for metal in metals:
    s_ = x_labels[metal]
    for method in methods:
        results_list = [x for x in dp_2 if x.metal == metal and x.method == method]
        for result in results_list:
            x_val = result.conc_number + system_shifts.get(result.system, 0.0)
            # Points above the reference line fade out in proportion to their energy.
            if result.norm_energy > 0.0:
                alpha = max(0.6 - 0.5 * result.norm_energy, 0.0)
            else:
                alpha = 0.6
            plt.scatter(x=x_val, y=result.norm_energy, s=result.size, marker=result.shape,
                        facecolor="None", edgecolors=result.color, alpha=alpha)
        plt.ylabel("Energy formula unit/ eV", labelpad=8, fontsize=14)
        plt.xlabel(s_, labelpad=6, fontsize=14)
        plt.grid(True, linestyle="--", linewidth=0.5, alpha=0.4)
        plt.yticks(fontsize=12)
        plt.ylim(-0.3, 0.3)
        draw_rectangles(fill=True)
        draw_rectangles(fill=False)
        # Lay out last, once labels, ticks and limits are final.
        plt.tight_layout()
        plt.savefig(f"figures_lq/{metal}_{method}.png", dpi=300)
        plt.close()

