# -*- coding: utf-8 -*-
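"""
Order-of-accuracy checks for the Cash-Karp embedded Runge-Kutta 4(5) pair.

Integrates particle equations of motion (linear drag, optionally with gravity)
using both the 4th- and 5th-order Cash-Karp solutions, compares the results
with the analytic solutions, and plots the error against the time step on
log-log axes so the fitted slopes can be compared with the expected orders.

Created on Thu Apr  9 16:25:35 2020

@author: Daniel Powell
"""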

import numpy as np
from scipy.stats import linregress
from matplotlib import pyplot as plt
import cunningham_slip
from numba import njit, prange


# =============================================================================
# def f(t, v):
#     return -v
# 
# 
# def sol(t, v):
#     return np.exp(-t)
# =============================================================================

@njit(nogil=True)
def get_RK_ks(f, tn, vn, dt):
    """Evaluate the six Cash-Karp Runge-Kutta stage derivatives for one step.

    Parameters
    ----------
    f : callable
        Right-hand side ``f(t, v)`` of the ODE ``dv/dt = f(t, v)``. Every
        caller in this file passes an ``@njit`` function; presumably ``f``
        must be numba-compiled because this function is too.
    tn : float
        Time at the start of the step.
    vn : float or numpy.ndarray
        State at time ``tn`` (whatever ``f`` accepts).
    dt : float
        Step size.

    Returns
    -------
    tuple
        ``(k1, k2, k3, k4, k5, k6)``: ``f`` evaluated at the six stages.
    """
    # Stage 1 (c1 = 0): slope at the start of the step.
    k1 = f(tn, vn)

    # Stage 2 (c2 = 1/5): a21 = 1/5.
    t2 = tn + ((1 / 5) * dt)
    v2 = vn + ((1 / 5) * dt * k1)
    k2 = f(t2, v2)

    # Stage 3 (c3 = 3/10): a31 = 3/40, a32 = 9/40, with dt/40 factored out.
    t3 = tn + ((3 / 10) * dt)
    v3 = vn + ((dt / 40) * ((3 * k1) + (9 * k2)))
    k3 = f(t3, v3)

    # Stage 4 (c4 = 3/5): a41 = 3/10, a42 = -9/10, a43 = 6/5.
    t4 = tn + ((3 / 5) * dt)
    v4 = vn + (dt * (((3 / 10) * k1) - ((9 / 10) * k2) + ((6 / 5) * k3)))
    k4 = f(t4, v4)

    # Stage 5 (c5 = 1): a51 = -11/54, a52 = 5/2, a53 = -70/27, a54 = 35/27.
    t5 = tn + dt
    v5 = vn + (dt * (((-11 / 54) * k1) + ((5 / 2) * k2)
                     - ((70 / 27) * k3) + ((35 / 27) * k4)))
    k5 = f(t5, v5)

    # Stage 6 (c6 = 7/8): a61..a65 from the Cash-Karp tableau.
    t6 = tn + ((7 / 8) * dt)
    v6 = vn + (dt * (((1631 / 55296) * k1) + ((175 / 512) * k2)
                     + ((575 / 13824) * k3) + ((44275 / 110592) * k4)
                     + ((253 / 4096) * k5)))
    k6 = f(t6, v6)

    stages = (k1, k2, k3, k4, k5, k6)
    return stages


@njit(nogil=True)
def RK4_step(f, tn, vn, dt):
    """Advance one step using the 4th-order solution of the Cash-Karp pair.

    Parameters
    ----------
    f : callable
        Right-hand side ``f(t, v)`` (passed through to ``get_RK_ks``).
    tn : float
        Time at the start of the step.
    vn : float or numpy.ndarray
        State at time ``tn``.
    dt : float
        Step size.

    Returns
    -------
    float or numpy.ndarray
        The 4th-order estimate of the state at ``tn + dt``.
    """
    # The 4th-order weight on k2 is zero, so that stage is discarded.
    k1, _, k3, k4, k5, k6 = get_RK_ks(f, tn, vn, dt)

    weighted_slope = (((2825 / 27648) * k1) + ((18575 / 48384) * k3)
                      + ((13525 / 55296) * k4) + ((277 / 14336) * k5)
                      + ((1 / 4) * k6))
    return vn + (dt * weighted_slope)


@njit(nogil=True)
def RK5_step(f, tn, vn, dt):
    """Advance one step using the 5th-order solution of the Cash-Karp pair.

    Parameters
    ----------
    f : callable
        Right-hand side ``f(t, v)`` (passed through to ``get_RK_ks``).
    tn : float
        Time at the start of the step.
    vn : float or numpy.ndarray
        State at time ``tn``.
    dt : float
        Step size.

    Returns
    -------
    float or numpy.ndarray
        The 5th-order estimate of the state at ``tn + dt``.
    """
    # The 5th-order weights on k2 and k5 are zero, so those stages are discarded.
    k1, _, k3, k4, _, k6 = get_RK_ks(f, tn, vn, dt)

    weighted_slope = (((37 / 378) * k1) + ((250 / 621) * k3)
                      + ((125 / 594) * k4) + ((512 / 1771) * k6))
    return vn + (dt * weighted_slope)


@njit(nogil=True)
def RK45_step(f, tn, vn, dt):
    """Advance one step, returning both Cash-Karp solutions from shared stages.

    Parameters
    ----------
    f : callable
        Right-hand side ``f(t, v)`` (passed through to ``get_RK_ks``).
    tn : float
        Time at the start of the step.
    vn : float or numpy.ndarray
        State at time ``tn``.
    dt : float
        Step size.

    Returns
    -------
    tuple
        ``(4th-order estimate, 5th-order estimate)`` of the state at
        ``tn + dt``. Their difference can serve as a local error estimate.
    """
    # Both solutions reuse the same six stages; k2 has zero weight in each.
    k1, _, k3, k4, k5, k6 = get_RK_ks(f, tn, vn, dt)

    slope4 = (((2825 / 27648) * k1) + ((18575 / 48384) * k3)
              + ((13525 / 55296) * k4) + ((277 / 14336) * k5)
              + ((1 / 4) * k6))
    slope5 = (((37 / 378) * k1) + ((250 / 621) * k3)
              + ((125 / 594) * k4) + ((512 / 1771) * k6))

    return (vn + (dt * slope4), vn + (dt * slope5))


# =============================================================================
# dt_array = np.array([1, 0.5, 0.2, 0.1, 0.05, 0.02, 0.01, 0.005])
# RK4_sol = np.zeros_like(dt_array)
# RK5_sol = np.zeros_like(dt_array)
# t_end = 1
# soln = sol(t_end, 1)
# 
# for i, dt in enumerate(dt_array):
#     NT = int(t_end / dt)
#     v4 = 1
#     v5 = 1
# 
#     for j in range(NT):
#         tn = j * dt
#         v4 = RK4_step(f, tn, v4, dt)
#         v5 = RK5_step(f, tn, v5, dt)
# 
#     RK4_sol[i] = v4
#     RK5_sol[i] = v5
# 
# error_RK4 = np.abs(RK4_sol - soln)
# error_RK5 = np.abs(RK5_sol - soln)
# 
# m4, c4, r4, p4, stderr4 = linregress(np.log10(dt_array), np.log10(error_RK4))
# m5, c5, r5, p5, stderr5 = linregress(np.log10(dt_array), np.log10(error_RK5))
# 
# grad4 = (10 ** c4) * (dt_array ** m4)
# grad5 = (10 ** c5) * (dt_array ** m5)
# 
# 
# plt.figure()
# plt.loglog(dt_array, error_RK4, 'ko', label='$4^{th}$ order')
# plt.loglog(dt_array, error_RK5, 'k^', label='$5^{th}$ order')
# plt.plot(dt_array, grad4, 'k-', label='Gradient = {0}'.format(str(m4)[0:6]))
# plt.plot(dt_array, grad5, 'k--', label='Gradient = {0}'.format(str(m5)[0:6]))
# plt.xlabel(r"$\frac{\Delta t}{t_{max}}$", fontsize=16)
# plt.ylabel(r"$\frac{|v - v_{exact}|}{v_0}$", fontsize=16)
# plt.legend()
# plt.tight_layout()
# plt.savefig("cash_rk45_orders.png", dpi=600)
# plt.show()
# =============================================================================


# =============================================================================
# def fluent_check():
#     d_p = 1e-6
#     rho_p = 2650
#     mu_f = 1.7894e-5
# 
#     tau_p = (rho_p * (d_p ** 2)) / (18 * mu_f)
# 
#     dt = 1.8100418514e-5
#     # v0 = 1
# 
#     def drag(t, X):
#         return np.array([X[1], -X[1] / tau_p])
# 
#     X = np.array([0.001, 1])
# 
#     v4 = RK4_step(drag, 0, X, dt)
#     v5 = RK5_step(drag, 0, X, dt)
# 
#     def exact(t, X, tau_p):
#         v = X[1] * np.exp(-t / tau_p)
#         x = X[0] + (tau_p * X[1] * (1 - np.exp(-t / tau_p)))
#         return np.array([x, v])
# 
#     print(v4)
#     print(v5)
#     
#     error = np.abs(v5 - v4)
# 
#     exact_sol = exact(dt, X, tau_p)
#     
#     # print("error = {0}".format(str(error)))
#     
#     fluent_v = np.array([0.001, 1.3365554667e-01])
#     
#     print(fluent_v)
#     
#     # this is the same as RK5 so FLUENT is using that
# 
#     # accuracy control
#     
#     v4 = RK4_step(drag, 0, fluent_v, dt)
#     v5 = RK5_step(drag, 0, fluent_v, dt)
# 
#     #print(v4)
# 
#     #print("error = {0}".format(str(np.abs(v5 - v4))))
# 
#     return error, exact_sol
# =============================================================================



# -----------------------------------------------------------------------------
# Physical parameters shared by the drag-only and drag + gravity test cases.
# -----------------------------------------------------------------------------

# Particle: 1 micron diameter, density 2650 kg/m^3.
d_p = 1e-6  # particle diameter, m
rho_p = 2650  # particle density, kg/m^3
mu_f = 1.789e-5  # fluid dynamic viscosity, kg/(m s)
# Correction factor from the project-local cunningham_slip module (presumably
# the Cunningham slip correction for this diameter; confirm there). It scales
# the relaxation time below.
C_c = cunningham_slip.C_c_factor(d_p)

# Particle relaxation time: rho_p * d_p^2 / (18 * mu_f), scaled by C_c.
tau_p = rho_p * (d_p ** 2) * C_c / (18 * mu_f)  # s

# Reciprocal of tau_p, used by the ODE right-hand sides and exact solutions.
tau_p_inv = 1.0 / tau_p  # 1/s

g = 9.81  # gravitational acceleration, m/s^2

# Initial conditions for the drag-only case: V(0) = V_0 and X(0) = 0.
# NOTE: V_0 is rebound to 0.0 further down for the drag + gravity case.
# numba treats module globals as compile-time constants captured at an @njit
# function's first call. Any drag-only function first called after that
# rebinding would therefore see V_0 = 0.0.
V_0 = 1.0  # m/s

@njit(nogil=True)
def drag_only(t, z):
    """Right-hand side for linear drag with no body force.

    The state is ``z = [V, X]``. The system is ``dV/dt = -V / tau_p`` and
    ``dX/dt = V``. ``t`` is unused because the system is autonomous.
    """
    velocity = z[0]
    dV_dt = -tau_p_inv * velocity
    dX_dt = velocity
    return np.array([dV_dt, dX_dt])


@njit(nogil=True)
def drag_only_solution_V(t):
    """Exact velocity for the drag-only case: ``V(t) = V_0 * exp(-t / tau_p)``.

    ``V_0`` is read as a module global. Under numba it is frozen at the first
    call, so this must be called before ``V_0`` is rebound further down the
    module to get the drag-only initial velocity.
    """
    exponent = -t * tau_p_inv
    return V_0 * np.exp(exponent)


@njit(nogil=True)
def drag_only_solution_X(t):
    """Exact position for the drag-only case, starting from ``X(0) = 0``.

    ``X(t) = tau_p * V_0 * (1 - exp(-t / tau_p))``. This tends to the
    stopping distance ``tau_p * V_0`` as ``t`` grows.
    """
    stopping_distance = tau_p * V_0
    return stopping_distance * (1 - np.exp(-t * tau_p_inv))


# =============================================================================
# dt_array = np.array([1, 0.5, 0.4, 0.2, 0.1, 0.05]) * tau_p  # s
# RK4_sol = np.zeros((len(dt_array), 2))
# RK5_sol = np.zeros_like(RK4_sol)
# t_end = 10 * tau_p  # s
# solV = drag_only_solution_V(t_end)
# solX = drag_only_solution_X(t_end)
# 
# X_stop = tau_p * V_0
# 
# for i, dt in enumerate(dt_array):
#     NT = int(t_end / dt)
#     z4 = np.array([V_0, 0.0])
#     z5 = np.array([V_0, 0.0])
# 
#     for j in range(NT):
#         tn = (j + 1) * dt
#         z4 = RK4_step(drag_only, tn, z4, dt)
#         z5 = RK5_step(drag_only, tn, z5, dt)
# 
#     # check that it was integrated for the correct amount of time
#     assert tn == t_end
# 
#     RK4_sol[i] = z4
#     RK5_sol[i] = z5
# 
# error_RK4 = np.abs(RK4_sol - solX)
# error_RK5= np.abs(RK5_sol - solX)
# 
# 
# m4, c4, r4, p4, stderr4 = linregress(np.log10(dt_array), np.log10(error_RK4[:, 0]))
# m5, c5, r5, p5, stderr5 = linregress(np.log10(dt_array), np.log10(error_RK5[:, 0]))
# 
# grad4 = (10 ** c4) * (dt_array ** m4)
# grad5 = (10 ** c5) * (dt_array ** m5)
# 
# 
# plt.figure()
# plt.loglog(dt_array / tau_p, error_RK4[:, 0] / V_0, 'ko', label='$4^{th}$ order')
# plt.loglog(dt_array / tau_p, error_RK5[:, 0] / V_0, 'k^', label='$5^{th}$ order')
# plt.plot(dt_array / tau_p, grad4 / V_0, 'k-', label='Gradient = {0}'.format(str(m4)[0:6]))
# plt.plot(dt_array / tau_p, grad5 / V_0, 'k--', label='Gradient = {0}'.format(str(m5)[0:6]))
# plt.xlabel(r"$\frac{\Delta t}{\tau_p}$", fontsize=16)
# plt.ylabel(r"$\frac{|V - V_{exact}|}{V_0}$", fontsize=16)
# plt.legend()
# plt.tight_layout()
# plt.show()
# 
# m4, c4, r4, p4, stderr4 = linregress(np.log10(dt_array), np.log10(error_RK4[:, 1]))
# m5, c5, r5, p5, stderr5 = linregress(np.log10(dt_array), np.log10(error_RK5[:, 1]))
# 
# grad4 = (10 ** c4) * (dt_array ** m4)
# grad5 = (10 ** c5) * (dt_array ** m5)
# 
# 
# plt.figure()
# plt.loglog(dt_array / tau_p, error_RK4[:, 1] / X_stop, 'ko', label='$4^{th}$ order')
# plt.loglog(dt_array / tau_p, error_RK5[:, 1] / X_stop, 'k^', label='$5^{th}$ order')
# plt.plot(dt_array / tau_p, grad4 / X_stop, 'k-', label='Gradient = {0}'.format(str(m4)[0:6]))
# plt.plot(dt_array / tau_p, grad5 / X_stop, 'k--', label='Gradient = {0}'.format(str(m5)[0:6]))
# plt.xlabel(r"$\frac{\Delta t}{\tau_p}$", fontsize=16)
# plt.ylabel(r"$\frac{|X - X_{exact}|}{\tau_p V_0}$", fontsize=16)
# plt.legend()
# plt.tight_layout()
# plt.show()
# =============================================================================

# Drag + gravity case: the particle starts from rest, so V(0) = 0 (and the
# integration below also starts from X(0) = 0). The exact solutions
# drag_grav_solution_V/X below both evaluate to 0 at t = 0, consistent with
# this start.
V_0 = 0.0

@njit(nogil=True)
def drag_grav(t, z):
    """Right-hand side for linear drag plus constant gravity.

    The state is ``z = [V, X]``. The system is ``dV/dt = g - V / tau_p`` and
    ``dX/dt = V``. ``t`` is unused because the system is autonomous.
    """
    velocity = z[0]
    acceleration = g - (tau_p_inv * velocity)
    return np.array([acceleration, velocity])


@njit(nogil=True)
def drag_grav_solution_V(t):
    """Exact velocity for drag + gravity starting from rest.

    ``V(t) = tau_p * g * (1 - exp(-t / tau_p))``. This approaches the terminal
    velocity ``tau_p * g``.
    """
    v_terminal = tau_p * g
    return v_terminal * (1 - np.exp(-t * tau_p_inv))


@njit(nogil=True)
def drag_grav_solution_X(t):
    """Exact position for drag + gravity starting from rest at ``X(0) = 0``.

    ``X(t) = tau_p * g * (t + tau_p * (exp(-t / tau_p) - 1))``. This is the
    time integral of the exact velocity.
    """
    v_terminal = tau_p * g
    transient_lag = tau_p * (np.exp(-t * tau_p_inv) - 1)
    return v_terminal * (t + transient_lag)


# Convergence study for drag + gravity: integrate to t_end with a range of
# fixed step sizes, then fit the log-log slope of error vs dt for both
# Cash-Karp solutions (expected to be roughly 4 and 5 respectively).
dt_array = np.array([1, 0.5, 0.4, 0.25, 0.2, 0.1, 0.08, 0.05]) * tau_p  # s
RK4_sol = np.zeros((len(dt_array), 2))
RK5_sol = np.zeros_like(RK4_sol)
t_end = 10.0 * tau_p  # s
solV = drag_grav_solution_V(t_end)
solX = drag_grav_solution_X(t_end)

# Exact state at t_end, ordered like the integrator state z = [V, X].
sol = np.array([solV, solX])

V_terminal = tau_p * g

for i, dt in enumerate(dt_array):
    # Every dt divides t_end exactly in exact arithmetic. In floating point,
    # though, t_end / dt can land a hair below the integer (e.g. 24.999...),
    # and int() alone would then take one step too few.
    NT = int(round(t_end / dt))
    z4 = np.array([V_0, 0.0])
    z5 = np.array([V_0, 0.0])

    tn = 0.0
    for j in range(NT):
        # An explicit RK step is given the time at the START of the interval
        # it advances over; the stage times are offsets from it.
        t_start = j * dt
        z4 = RK4_step(drag_grav, t_start, z4, dt)
        z5 = RK5_step(drag_grav, t_start, z5, dt)
        tn = (j + 1) * dt  # time reached at the end of this step

    # Check that it was integrated for the correct amount of time. tn and
    # t_end are computed differently, so compare with a relative tolerance.
    # Use an explicit raise rather than assert, which `python -O` strips.
    if not np.isclose(tn, t_end, rtol=1e-9, atol=0.0):
        raise RuntimeError(
            "integrated to t = {0} s instead of t_end = {1} s (dt = {2} s)"
            .format(tn, t_end, dt))

    RK4_sol[i] = z4
    RK5_sol[i] = z5

error_RK4 = np.abs(RK4_sol - sol)
error_RK5 = np.abs(RK5_sol - sol)

# --- velocity error (column 0) ----------------------------------------------
m4, c4, r4, p4, stderr4 = linregress(np.log10(dt_array), np.log10(error_RK4[:, 0]))
m5, c5, r5, p5, stderr5 = linregress(np.log10(dt_array), np.log10(error_RK5[:, 0]))

# Power-law fit lines error = 10**c * dt**m for plotting.
grad4 = (10 ** c4) * (dt_array ** m4)
grad5 = (10 ** c5) * (dt_array ** m5)


plt.figure()
plt.loglog(dt_array / tau_p, error_RK4[:, 0] / V_terminal, 'ko', label='$4^{th}$ order')
plt.loglog(dt_array / tau_p, error_RK5[:, 0] / V_terminal, 'k^', label='$5^{th}$ order')
plt.plot(dt_array / tau_p, grad4 / V_terminal, 'k-', label='Gradient = {0}'.format(str(m4)[0:6]))
plt.plot(dt_array / tau_p, grad5 / V_terminal, 'k--', label='Gradient = {0}'.format(str(m5)[0:6]))
plt.xlabel(r"$\frac{\Delta t}{\tau_p}$", fontsize=16)
plt.ylabel(r"$\frac{|V - V_{exact}|}{\tau_p g}$", fontsize=16)
plt.legend()
plt.tight_layout()
# plt.savefig("cash_rk45_V.eps")


# --- position error (column 1), relative to the exact position --------------
m4, c4, r4, p4, stderr4 = linregress(np.log10(dt_array), np.log10(error_RK4[:, 1]))
m5, c5, r5, p5, stderr5 = linregress(np.log10(dt_array), np.log10(error_RK5[:, 1]))

grad4 = (10 ** c4) * (dt_array ** m4)
grad5 = (10 ** c5) * (dt_array ** m5)


plt.figure()
plt.loglog(dt_array / tau_p, error_RK4[:, 1] / solX, 'ko', label='$4^{th}$ order')
plt.loglog(dt_array / tau_p, error_RK5[:, 1] / solX, 'k^', label='$5^{th}$ order')
plt.plot(dt_array / tau_p, grad4 / solX, 'k-', label='Gradient = {0}'.format(str(m4)[0:6]))
plt.plot(dt_array / tau_p, grad5 / solX, 'k--', label='Gradient = {0}'.format(str(m5)[0:6]))
plt.xlabel(r"$\frac{\Delta t}{\tau_p}$", fontsize=16)
plt.ylabel(r"$|\frac{X - X_{exact}}{X_{exact}}|$", fontsize=16)
plt.legend()
plt.tight_layout()
# plt.savefig("cash_rk45_X.eps")
plt.show()
