import os
import math
import ctypes
import json
import matplotlib as mpl
import matplotlib.pyplot as plt
from pymatgen.core import Structure

mpl.use('Qt5Agg')  # select the Qt5 interactive backend

# Unit conversions from kilobar (kB).
KB_TO_GIGAPASC = 0.1   # 1 kB = 0.1 GPa
KB_TO_MEGAPASC = 100   # 1 kB = 100 MPa

plt.rcParams.update({'font.family': 'Arial', 'font.size': 12})

# These pyplot calls act on the implicit current figure, which is the first
# figure the plotting code below draws into.
plt.xlabel('X-axis', fontsize=20)
plt.ylabel('Y-axis', fontsize=20)
plt.yticks(fontsize=16)
plt.xlim(-5.51, 5.51)  # symmetric stress window (x axis is plotted in GPa)
plt.grid(True, linestyle='--', linewidth=1.5, alpha=0.4)

# Root folder containing the per-metal relaxation outputs.
head = "data/min_cells_response/"

def get_attribute_values(list_of_objects, attribute):  # helper method
    """Return ``getattr(obj, attribute)`` for every object, in input order.

    Raises AttributeError if any object lacks ``attribute``.
    """
    return list(map(lambda item: getattr(item, attribute), list_of_objects))




class Result:  # Class to handle a result and plot it easily.
    """One relaxed cell for a given metal and applied stress, plus plot helpers.

    The constructor arguments are stored as attributes of the same name.
    ``color`` and ``c_true`` are added by :meth:`get_color` and
    :meth:`get_c_true` respectively.
    """

    # Matplotlib colour name used for each metal's scatter points.
    _METAL_COLORS = {"Cobalt": "Blue", "Nickel": "Grey", "Manganese": "Purple"}

    def __init__(self, metal, pstress, a, b, c, beta, d_1, d_2, vol, i_oo4):
        self.metal = metal      # metal name, key into _METAL_COLORS
        self.pstress = pstress  # applied stress setting of this calculation
        self.a = a
        self.b = b
        self.c = c
        self.beta = beta        # lattice angle in degrees (see get_c_true)
        self.d_1 = d_1
        self.d_2 = d_2
        self.vol = vol
        self.i_oo4 = i_oo4

    def get_c_true(self):
        """Store and return ``c_true = 3 * c * sin(beta)`` with ``beta`` in degrees.

        ``c * sin(beta)`` is the component of ``c`` perpendicular to ``a``.
        The factor 3 presumably scales one layer repeat up to the three-layer
        R-3m hexagonal c axis; confirm.
        """
        # math.radians is exact; the previous `beta / 57.295` approximation of
        # 180/pi introduced a ~1e-5 relative error.
        self.c_true = self.c * math.sin(math.radians(self.beta)) * 3
        return self.c_true

    def get_color(self):
        """Store and return the plot colour for ``self.metal``.

        Raises:
            ValueError: if no colour is defined for the metal. Previously
                ``color`` was silently left unset, which caused a later AttributeError.
        """
        try:
            self.color = self._METAL_COLORS[self.metal]
        except KeyError:
            raise ValueError(f"No plot colour defined for metal {self.metal!r}") from None
        return self.color


# Data sub-folder -> metal name recorded on each Result.
_METAL_FOLDERS = {
    "cobalt_half_lithiation": "Cobalt",
    "manganese_half_lithiation": "Manganese",
    "nickel_half_lithiation": "Nickel",
}

# One Result per (metal, pstress) relaxation: folder order, then ascending pstress.
results = []
for metal_folder, st in _METAL_FOLDERS.items():
    for pstress in range(-50, 50, 2):
        # Negative stresses are stored as "A_<magnitude>", the rest as "A<value>".
        pstress_string = f"A_{abs(pstress)}" if pstress < 0 else f"A{pstress}"
        # Grab the relaxed CONTCAR as a structure.
        structure = Structure.from_file(
            os.path.join(head, metal_folder, pstress_string, "CONTCAR")
        )
        lattice = structure.lattice

        # NOTE(review): jimage=0 broadcasts to the zero-translation image, so this
        # is the in-cell distance between sites 3 and 4 rather than the
        # nearest-image distance; confirm this is intended.
        I_1o4 = structure.get_distance(i=3, j=4, jimage=0)

        # Fractional z of sites 3-6 (presumably the four O atoms; confirm against
        # the POSCAR site ordering).
        o_1, o_2, o_3, o_4 = (structure[k].frac_coords[2] for k in (3, 4, 5, 6))
        mp_o_1 = (o_1 + o_3) / 2
        mp_o_2 = (o_2 + o_4) / 2
        d1 = abs(mp_o_2 - mp_o_1)  # fractional gap between the two mid-planes
        d2 = 1 - d1                # complementary gap across the periodic boundary
        # Component of c perpendicular to a; beta is in degrees.
        c_perp = lattice.c * math.sin(math.radians(lattice.beta))

        result = Result(metal=st, pstress=pstress,
                        a=lattice.a, b=lattice.b, c=lattice.c,
                        beta=lattice.beta, vol=lattice.volume,
                        d_1=d1 * c_perp,
                        d_2=d2 * c_perp, i_oo4=I_1o4)
        result.get_color()
        result.get_c_true()
        results.append(result)


def plot_variable_vs_pstress(results, variable_name, y_axis, ticks=None):  # Style _1
    """Scatter one attribute of every result against its applied stress.

    Draws on pyplot's current axes and leaves the figure open, so the caller
    can finish styling it and then save it.

    Parameters
    ----------
    results : list of Result
        Each item needs ``pstress``, ``color`` and the attribute named by
        ``variable_name``.
    variable_name : str
        Attribute plotted on the y axis (e.g. ``"b"``, ``"vol"``, ``"d_1"``).
    y_axis : str
        Label for the y axis.
    ticks : sequence of float, optional
        Explicit x tick positions in GPa; ``None`` keeps the automatic ticks.
    """
    plt.xticks(ticks=ticks, fontsize=12)
    plt.ylabel(y_axis, fontsize=20, labelpad=8)
    if not results:
        return
    # Deliberately no plt.show(). It would block and, once the window is closed,
    # discard the figure before the caller gets to save it.
    # pstress / 10 converts kB to GPa (1 GPa = 10 kB).
    plt.scatter(
        x=[result.pstress / 10 for result in results],
        y=[getattr(result, variable_name) for result in results],
        c=[result.color for result in results],
        s=10,
        marker="s",
        alpha=0.7,
    )



def _start_stress_plot(variable_name, y_label):
    """Scatter ``variable_name`` of the loaded results vs stress and label the x axis."""
    plot_variable_vs_pstress(results, variable_name, y_label)
    plt.xlabel('Stress / GPa', labelpad=6, fontsize=20)


def _finish_and_save(filename):
    """Apply the shared axes styling, lay out, save low/high-res copies and close."""
    plt.grid(True, linestyle='--', linewidth=1.5, alpha=0.4)
    ax = plt.gca()
    ax.patch.set_linewidth(w=1.5)
    ax.patch.set_edgecolor('black')
    ax.tick_params(direction="in", labelsize=16)
    ax.tick_params(width=1.5)
    # Lay out only after the final tick-label size is set, so tight_layout
    # reserves room for the 16 pt labels instead of clipping them.
    plt.tight_layout()
    for folder, dpi in (("figures_lq", 300), ("figures_hq", 1200)):
        os.makedirs(folder, exist_ok=True)
        plt.savefig(os.path.join(folder, filename), dpi=dpi)
    plt.close()


# b lattice parameter (custom y ticks and y limits).
_start_stress_plot("b", r"$\it{b}$ Axis / Å")
plt.yticks(ticks=[2.75, 2.80, 2.85, 2.90, 2.95], fontsize=16)
plt.ylim(2.74, 2.96)
_finish_and_save("min_cell_b.png")

# Cell volume (custom y ticks up to 85 and y limits).
_start_stress_plot("vol", r"Volume / $\mathrm{Å^{3}}$")
plt.ylim(62, 87)
plt.yticks(ticks=[65, 70, 75, 80, 85], fontsize=16)
_finish_and_save("min_cell_vol.png")

# Raw c lattice parameter (C2/m cell); also depends on beta, so it is messy to interpret.
_start_stress_plot("c", "c axis / Å")
_finish_and_save("min_cell_c_c2m.png")

# c axis of the R-3m cell (Result.c_true).
_start_stress_plot("c_true", "c axis / Å")
_finish_and_save("min_cell_c_r3m.png")

# Li layer spacing (d_1), annotated with tension/compression arrows.
_start_stress_plot("d_1", "Li layer spacing / Å")
plt.arrow(x=-0.4, y=2.64, dx=-2.60, dy=0, width=0.0185, head_width=0.060, head_length=0.6, overhang=0.3, color="tab:red", alpha=0.7)
plt.arrow(x=+0.4, y=2.64, dx=+2.60, dy=0, width=0.0185, head_width=0.060, head_length=0.6, overhang=0.3, color="tab:blue", alpha=0.7)
plt.text(x=-1.70, y=2.6575, s="Tension", ha="center", fontsize=15)
plt.text(x=+1.70, y=2.6575, s="Compression", ha="center", fontsize=15)
plt.xlim(-5.51, 5.51)
plt.ylim(2.58, 3.22)
# Dashed guide at zero stress. It is drawn after the limits are fixed, so its
# long y extent does not trigger autoscaling.
plt.plot([0, 0], [-3, 3.5], '--', lw=1.5, alpha=0.65, color="grey")
_finish_and_save("min_cell_layer_space_arrows.png")

