import os

import matplotlib as mpl
import matplotlib.lines as mlines
import matplotlib.pyplot as plt
import numpy as np
import pylab

# Select the Qt5 interactive backend.
# NOTE(review): this runs after matplotlib.pyplot has already been imported,
# and the script below only saves files, so a non-interactive backend such as
# "Agg" may be sufficient — confirm before changing.
mpl.use('Qt5Agg')

# (metal name, normalisation value) pairs. The name is substituted into the
# data-file path; the value is passed to load_data as `normalisation_value`
# and subtracted from every energy column.
# Presumably these are reference total energies in eV — TODO confirm.
m_norm = [
    ("cobalt", -726.461244),
    ("manganese", -861.405957),
    ("nickel", -619.202114),
]

# One colour per structure/curve, in the same order as the legend entries.
colors = [
    "tab:red",
    "tab:orange",
    "tab:green",
    "tab:blue",
    "tab:purple",
]
num_curves = len(colors)  # one curve per colour


def quick_clean():
    """Apply boiler-plate pyplot styling to the current axes.

    Sets the global font to Arial at size 18 (mutating ``plt.rcParams``).
    Adds placeholder axis labels at font size 20 and sets both tick-label
    sizes to 20. Fixes the x range to [-4.21, 4.21].
    """
    for rc_key, rc_value in (('font.family', 'Arial'), ('font.size', 18)):
        plt.rcParams[rc_key] = rc_value

    for set_axis_label, label_text in ((plt.xlabel, 'X-axis'), (plt.ylabel, 'Y-axis')):
        set_axis_label(label_text, fontsize=20)

    for set_tick_style in (plt.yticks, plt.xticks):
        set_tick_style(fontsize=20)

    plt.xlim(-4.21, 4.21)


def load_data(file="None", normalisation_value=0.0, trim=5, x_scale=10):
    """Load a tab-separated results file and return ``(x, normalised y columns)``.

    The first line of the file is skipped as a header. In each remaining row,
    column 0 is the x value and every further column is one y series.

    :param file: path of the file to read (formatted to a string first).
    :param normalisation_value: value subtracted from every y entry.
    :param trim: number of rows dropped from *each* end of the data
        (default 5). ``0`` keeps every row.
    :param x_scale: divisor applied to the x column (default 10).
    :return: ``(x, y_col)``. ``x`` is a list of the scaled first-column
        values. ``y_col`` is a 2-D array of the remaining columns minus
        ``normalisation_value``, with one row per kept data row.
    """
    data = np.genfromtxt(f"{file}", delimiter="\t", skip_header=1)
    # Use an explicit stop index so that trim=0 keeps every row
    # (data[0:-0] would be empty).
    data = data[trim:len(data) - trim]
    # Vectorised subtraction: same values as the old per-row rebuild, but the
    # result stays 2-D (0, ncols) even when trimming leaves no rows.
    y_col = data[:, 1:] - normalisation_value
    x = list(data[:, 0] / x_scale)
    return x, y_col


def get_smooth_data(x, y_col, degree=6, n_curves=5, num_points=1000):
    """Fit a polynomial to each y column and sample it on an evenly spaced x grid.

    :param x: x values shared by every column (sequence or 1-D array).
    :param y_col: 2-D array of y data, one series per column.
    :param degree: degree of the polynomial fitted to each column (default 6).
    :param n_curves: number of leading columns of ``y_col`` to fit. Defaults
        to 5, the previous hard-coded value. Pass ``None`` to fit every column.
    :param num_points: number of samples in the smooth x grid (default 1000).
    :return: ``(xs, ys)``, two lists with one entry per fitted column.
        ``xs[i]`` is the grid from ``min(x)`` to ``max(x)``; every entry is
        the same array object. ``ys[i]`` is column i's fitted polynomial
        evaluated on that grid.
    """
    if n_curves is None:
        n_curves = y_col.shape[1]

    # The grid depends only on x, so build it once instead of once per column.
    x_smooth = np.linspace(min(x), max(x), num_points)

    xs = []
    ys = []
    for i in range(n_curves):
        poly = np.poly1d(np.polyfit(x, y_col[:, i], degree))
        xs.append(x_smooth)
        ys.append(poly(x_smooth))
    return xs, ys


# --- Combined static-cell figure ---------------------------------------------
# Three panels side by side, sharing both axes. The left-to-right panel order
# is Co, Ni, Mn, which differs from the order of ``m_norm`` (Co, Mn, Ni), so
# each metal is routed to its panel by name rather than by position.
fig, (ax_co, ax_ni, ax_mn) = plt.subplots(1, 3, sharex=True, sharey=True, figsize=(18, 6))
axes_by_metal = {"cobalt": ax_co, "manganese": ax_mn, "nickel": ax_ni}

for metal, norm_value in m_norm:
    ax = axes_by_metal[metal]
    x, y_col = load_data(file=f"data/max_cells_response/{metal}_half_lithiation/static_res.dat",
                         normalisation_value=norm_value)
    smooth_x, y_vals = get_smooth_data(x=x, y_col=y_col)
    # Fitted curves first, then the raw points (same draw order as before).
    for i, y_smooth in enumerate(y_vals):
        ax.plot(smooth_x[0], y_smooth, c=colors[i], linewidth=3.0, alpha=0.8)
    for i, color in enumerate(colors):
        ax.scatter(x=x, y=y_col[:, i], c=color, alpha=0.8, s=30, marker="D", linewidths=0)

# Shared cosmetics for every panel.
for ax in (ax_co, ax_ni, ax_mn):
    ax.set_xlabel('Stress / GPa', fontsize=22, labelpad=6)
    ax.grid(True, linestyle='--', linewidth=1.5, alpha=0.4)
    ax.tick_params(direction="in", labelsize=20)
    ax.tick_params(width=1.5)
    ax.patch.set_edgecolor("black")
    ax.patch.set_linewidth(1.5)

ax_co.set_ylabel('Energy Difference per unit cell / eV', fontsize=22, labelpad=6)

# Legend entries, paired with `colors` by position. Raw strings keep the
# mathtext backslashes literal; "\m" in a normal string is an invalid escape
# (SyntaxWarning on Python 3.12+), though its runtime value is the same.
legend_labels = [
    "Undefected structure (a)",
    r"$\mathrm{T_{d}}$ Li in vacancy chain (b)",
    r"$\mathrm{O_{h}}$ Li in vacancy chain (c)",
    r"$\mathrm{O_{h}}$ and $\mathrm{T_{d}}$ in vacancy chain (d)",
    "Dumbbell structure (e)",
]
hands = [mlines.Line2D([], [], color=color, marker="D", markersize=10, lw=3.0, label=label)
         for color, label in zip(colors, legend_labels)]

# Two stacked figure-level legends above the panels: (a)-(b) on the top row,
# (c)-(e) on the row below.
fig.legend(handles=hands[0:2], loc='upper center', bbox_to_anchor=(0.51, 1.04), ncol=2, fontsize=22,
           frameon=False, handletextpad=0.1, columnspacing=0.9)
fig.legend(handles=hands[2:], loc='upper center', bbox_to_anchor=(0.51, 0.985), ncol=3, fontsize=22,
           frameon=False, borderpad=0.45, columnspacing=0.9, handletextpad=0.1)

# Panel labels, placed in figure coordinates.
for x_pos, panel_label in ((0.10, 'i)'), (0.42, 'ii)'), (0.730, 'iii)')):
    fig.text(x_pos, 0.80, panel_label, ha='right', fontsize=28)

# A single tight_layout is enough: the figure-coordinate texts above do not
# feed into it. Then reserve headroom at the top for the two legends.
fig.tight_layout()
fig.subplots_adjust(top=0.870)

# Low- and high-resolution copies. Create the output folders if they are missing.
for out_dir, dpi in (("figures_lq", 300), ("figures_hq", 1200)):
    os.makedirs(out_dir, exist_ok=True)
    fig.savefig(f"{out_dir}/static_cell_combined.png", dpi=dpi)
plt.close(fig)