from matplotlib.colors import LinearSegmentedColormap
import matplotlib.pyplot as plt
import matplotlib.transforms
import matplotlib.path
import matplotlib as mpl
import numpy as np
from matplotlib.collections import LineCollection
from scipy.interpolate.interpolate import make_interp_spline
# Use the non-interactive Agg backend. Nothing in this script calls plt.show();
# every figure is written with savefig and closed. So no GUI toolkit (previously
# Qt5) or display is needed, and the PNG output is rendered by Agg either way.
mpl.use('Agg')


def quick_clean():
    """Apply the shared single-panel styling to the current pyplot axes.

    Sets the font rcParams (Arial, size 14). Note that this change is global
    and persists for every figure created afterwards. Also writes placeholder
    axis labels, sizes the y tick labels, relabels x positions 0, 2, 4, 6, 8
    as A-E, and fixes the x range to [-0.25, 8.25].
    """
    plt.rcParams.update({'font.family': 'Arial', 'font.size': 14})

    label_size = 18
    plt.xlabel('X-axis', fontsize=label_size)
    plt.ylabel('Y-axis', fontsize=label_size)
    plt.yticks(fontsize=14)

    step_positions = [0, 2, 4, 6, 8]
    step_names = ["A", "B", "C", "D", "E"]
    plt.xticks(step_positions, step_names)
    plt.xlim(-0.25, 8.25)

# Normalised energies for each pressure, preprocessed from vasprun.xml files
# using bash and Visual Basic. Each file is comma-separated with one header row.
def _load_half_lithiation(metal):
    """Return the interpolated, cleaned energy table for ``metal`` as an ndarray."""
    return np.genfromtxt(
        f"data/max_cells_response/{metal}_half_lithiation/interpolated_cleaned.dat",
        delimiter=",",
        skip_header=1,
    )


data_cobalt = _load_half_lithiation("cobalt")
data_nickel = _load_half_lithiation("nickel")
data_manganese = _load_half_lithiation("manganese")

smooth_fitting = 1500  # Number of points at which each cubic spline is sampled
colormap = "RdYlBu"  # Matplotlib colormap name used to colour the rows

def _plot_smoothed_profiles(ax, data):
    """Draw one cubic-spline-smoothed curve per row of ``data`` on ``ax``.

    The first column of ``data`` is dropped. The remaining columns are treated
    as values at consecutive pathway steps 0, 1, 2, ... and interpolated with
    a cubic (k=3) spline sampled at ``smooth_fitting`` points. Row ``i`` is
    drawn in entry ``i`` of ``colormap`` resampled to ``len(data)`` colours.
    Reads the module-level ``smooth_fitting`` and ``colormap`` settings.
    """
    energies = data[:, 1:]
    n_steps = energies.shape[1]
    steps = np.arange(n_steps)
    # Same sample grid for every row; raise smooth_fitting for smoother curves.
    xs = np.linspace(0, n_steps - 1, smooth_fitting)
    # plt.get_cmap works on all Matplotlib versions; matplotlib.cm.get_cmap
    # (plt.cm.get_cmap) was removed in Matplotlib 3.9.
    row_colors = plt.get_cmap(colormap, len(data))
    for i, row in enumerate(energies):
        spline = make_interp_spline(steps, row, k=3)
        ax.plot(xs, spline(xs), color=row_colors(i), linewidth=0.25, alpha=0.95)


# --- Combined three-panel figure (Co, Ni, Mn) --------------------------------
fig, (ax_co, ax_ni, ax_mn) = plt.subplots(1, 3, sharex=True, sharey=True, figsize=(18, 6))
for ax, data in ((ax_co, data_cobalt), (ax_ni, data_nickel), (ax_mn, data_manganese)):
    _plot_smoothed_profiles(ax, data)

for ax in (ax_co, ax_ni, ax_mn):
    ax.set_xlabel('Pathway Step', fontsize=22, labelpad=6)
    ax.grid(True, linestyle='--', linewidth=1.5, alpha=0.4)
    ax.tick_params(direction="in", labelsize=20)
    ax.set_ylim(-0.20, 1.4)
    ax.set_xlim(-0.25, 8.25)
    ax.set_xticks([0.0, 2.0, 4.0, 6.0, 8.0])
    ax.set_xticklabels(["A", "B", "C", "D", "E"])
    ax.tick_params(width=1.5)
    ax.patch.set_edgecolor("black")
    ax.patch.set_linewidth(1.5)
ax_co.set_ylabel('Energy Difference per unit cell / eV', fontsize=22, labelpad=6)

# Horizontal colorbar across the top, labelled from 4 kB tension to 4 kB
# compression. It uses the same `colormap` as the curves so the two stay in sync.
cax = fig.add_axes([0.15, 0.895, 0.7, 0.03])
sm = plt.cm.ScalarMappable(cmap=colormap, norm=plt.Normalize(vmin=-4, vmax=4))
sm.set_array([])  # Kept for compatibility with older Matplotlib releases
cbar = fig.colorbar(sm, cax=cax, orientation='horizontal', ticks=[-3.6, 0, 3.6])
cbar.set_label('', labelpad=0)
cbar.set_ticklabels(["4 kB Tension", "Equilibrium", "4 kB Compression"], fontsize=24)
cax.xaxis.set_ticks_position('top')  # Tick labels above the bar
fig.tight_layout()
fig.subplots_adjust(top=0.870)  # Leave room for the colorbar axes added above

fig.savefig("figures_lq/neb_colormap_combined.png", dpi=300)  # Style 1, no custom yticks
fig.savefig("figures_hq/neb_colormap_combined.png", dpi=1200)
plt.close(fig)

# --- One single-panel figure per metal --------------------------------------
for metal, data in (("cobalt", data_cobalt), ("nickel", data_nickel), ("manganese", data_manganese)):
    # gca() opens a fresh default-size figure because the previous one was closed.
    _plot_smoothed_profiles(plt.gca(), data)
    quick_clean()
    # Override quick_clean's placeholder axis labels.
    plt.xlabel('Pathway Step', fontsize=18, labelpad=6)
    plt.ylabel('Energy Difference per unit cell / eV', fontsize=18, labelpad=6)
    plt.grid(True, linestyle='--', linewidth=0.5, alpha=0.4)
    plt.tight_layout()
    plt.savefig(f"figures_lq/{metal}_neb_{colormap}.png", dpi=300)  # Style 1, no custom yticks
    plt.savefig(f"figures_hq/{metal}_neb_{colormap}.png", dpi=1200)  # Style 1, no custom yticks
    plt.close()

