# -*- coding: utf-8 -*-
"""
Created on Mon Nov  4 13:59:39 2019

@author: Daniel Powell
"""

import numpy as np
from matplotlib import pyplot as plt
from scipy.optimize import curve_fit
import cunningham_slip
from numba import njit, prange


# Reproducible legacy-API generator: a RandomState wrapped around an explicit
# MT19937 bit generator seeded through SeedSequence(1994).
# NOTE(review): none of the functions in this file draw from `rs`; they use
# the global np.random / Numba generators instead.
rs = np.random.RandomState(
    seed=np.random.MT19937(seed=np.random.SeedSequence(entropy=1994))
)


def step(x):
    """Advance a 1-D unbiased random walker by one unit step.

    A fair coin is drawn from NumPy's global generator with
    ``np.random.randint(0, 2)``: a draw of 0 moves the walker to ``x - 1``,
    a draw of 1 moves it to ``x + 1``.

    Parameters
    ----------
    x : number
        Current position of the walker.

    Returns
    -------
    number
        The new position, ``x - 1`` or ``x + 1``.
    """
    coin = np.random.randint(0, 2)
    return x + 1 if coin else x - 1


def plot(x):
    """Plot the ensemble RMS displacement of random walkers against time.

    Parameters
    ----------
    x : array_like
        Walker positions indexed as ``x[walker, time_step]``. Only the number
        of time steps (second axis) and the per-time-step mean square are used.

    Draws a dashed ``+/- sqrt(t)`` envelope (the expected RMS displacement of
    an unbiased unit-step walk, labelled "Analytical") and a dotted curve of
    the ensemble RMS ``sqrt(mean(x**2))`` at each time step ("Average"), then
    shows the figure.
    """
    n_steps = np.shape(x)[1]
    t = np.arange(0, n_steps, step=1)
    envelope = np.sqrt(t)
    rms = np.sqrt(np.square(x).mean(axis=0))

    plt.figure()
    plt.xlabel("Time")
    plt.ylabel("Displacement")
    plt.plot(t, envelope, 'k--', label="Analytical")
    plt.plot(t, -envelope, 'k--')
    plt.plot(t, rms, 'k:', label='Average')
    plt.legend()
    plt.tight_layout()
    plt.show()


def run(NP=100, NT=1000):
    """Simulate NP unit-step random walkers for NT time steps and plot them.

    Every walker starts at 0. Each column of the position array is produced
    from the previous one by calling ``step`` once per walker, in walker
    order, so random draws are consumed time step by time step.

    Parameters
    ----------
    NP : int
        Number of walkers (rows of the position array).
    NT : int
        Number of time samples (columns), including the initial one.
    """
    positions = np.zeros((NP, NT))

    for t_next in range(1, NT):
        previous = positions[:, t_next - 1]
        positions[:, t_next] = [step(value) for value in previous]

    plot(positions)


@njit(nogil=True)
def fluent_model(dt, tau_p, T_f, m_p, C_c):
    """Draw one sample of the random Brownian acceleration on a particle.

    Returns ``zeta * sqrt(2 * k_B * T_f / (m_p * tau_p * C_c * dt))`` where
    ``zeta`` is a standard-normal draw. With SI inputs (dt and tau_p in s,
    T_f in K, m_p in kg) the result is in m/s^2; its spread grows as the
    time step ``dt`` shrinks, as for discretised white noise.
    NOTE(review): the name suggests the ANSYS Fluent DPM Brownian-force
    formulation — confirm against the Fluent theory guide.

    Parameters
    ----------
    dt : float
        Integration time step.
    tau_p : float
        Particle relaxation time.
    T_f : float
        Fluid temperature.
    m_p : float
        Particle mass.
    C_c : float
        Cunningham slip correction factor.
    """
    boltzmann = 1.380649e-23  # J/K
    amplitude = np.sqrt(2 * boltzmann * T_f / (m_p * tau_p * C_c * dt))
    zeta = np.random.normal(0, 1)  # zero mean, unit standard deviation
    return zeta * amplitude



@njit(nogil=True, parallel=True)
def loop(NT, NP, X, dt, tau_p, T_f, m_p, C_c):
    """Integrate NP independent 1-D Langevin trajectories over NT samples.

    Each particle starts at rest at ``X[p, 0]`` and is advanced with explicit
    Euler steps of

        dV/dt = -V / (tau_p * C_c) + a_B,        dX/dt = V,

    where ``a_B`` is the Brownian acceleration drawn by ``fluent_model``.
    Particles are distributed over threads with ``prange``; each one writes
    only its own row of ``X``.

    Parameters
    ----------
    NT : int
        Number of time samples per trajectory; columns 1..NT-1 are written.
    NP : int
        Number of particles; rows 0..NP-1 are written.
    X : ndarray
        Position array of at least shape (NP, NT), filled in place (m).
    dt : float
        Time step (s).
    tau_p : float
        Particle relaxation time without slip correction (s).
    T_f : float
        Fluid temperature (K).
    m_p : float
        Particle mass (kg).
    C_c : float
        Cunningham slip correction factor.

    Returns
    -------
    ndarray
        ``X`` (the same array, modified in place).
    """
    # Drag relaxation time including slip. fluent_model's noise variance is
    # 2*k_B*T_f / (m_p * tau_p * C_c * dt), which pairs with this drag rate
    # so the velocity relaxes towards variance k_B*T_f / m_p. The previous
    # `V = accel * dt` discarded V every step (no memory, no drag), giving
    # independent position steps of accel*dt**2 and a diffusivity that was
    # off by a factor (dt / (tau_p * C_c))**2.
    relax = tau_p * C_c
    for p in prange(NP):
        V = 0.0  # m/s, each particle starts at rest
        for t in range(NT - 1):
            accel = fluent_model(dt, tau_p, T_f, m_p, C_c)
            V += (accel - V / relax) * dt
            X[p, t + 1] = X[p, t] + V * dt

    return X


def _simulate_rms_displacement(d_p, rho_p, mu_f, T_f, NP, NT):
    """Simulate one particle size with ``loop`` and fit sqrt(2*D*t) to its RMS.

    Parameters
    ----------
    d_p : float
        Particle diameter (m).
    rho_p : float
        Particle density (kg/m^3).
    mu_f : float
        Fluid dynamic viscosity (kg/m/s).
    T_f : float
        Fluid temperature (K).
    NP, NT : int
        Number of particles and number of time samples per trajectory.

    Returns
    -------
    tau_p : float
        Relaxation time rho_p*d_p**2 / (18*mu_f) (s), used to scale time.
    time : ndarray
        Sample times (s); column t of the trajectories is at ``t * dt``.
    rms : ndarray
        Ensemble root-mean-square displacement at each sample time (m).
    D : float
        Fitted coefficient of ``sqrt(2*D*t)`` (m^2/s).
    fit : ndarray
        ``sqrt(2*D*time)`` (m).
    """
    tau_p = rho_p * (d_p ** 2) / (18 * mu_f)  # s
    m_p = (np.pi / 6) * rho_p * (d_p ** 3)  # kg
    C_c = cunningham_slip.C_c_factor(d_p)
    dt = tau_p / 100  # s

    # np.float was removed in NumPy 1.24; it was an alias of float64.
    X = np.zeros((NP, NT), dtype=np.float64)  # m
    X = loop(NT, NP, X, dt, tau_p, T_f, m_p, C_c)

    # Square in place: X is not needed afterwards, and skipping a second
    # NP x NT temporary halves the peak memory of this step.
    np.square(X, out=X)
    mean_square = X.mean(axis=0)
    rms = np.sqrt(mean_square)

    # Column t holds positions after t steps (time t*dt), so the last sample
    # is at dt*(NT-1); spanning to dt*NT would stretch the time axis.
    time = np.linspace(0, dt * (NT - 1), num=NT)

    def f(t, D):
        return np.sqrt(2 * D * t)

    # Seed the fit from MSD = 2*D*t at the last sample; a fixed guess of
    # 1.0 m^2/s is typically far from the fitted value at these sizes.
    D0 = 1.0
    if NT > 1 and mean_square[-1] > 0:
        D0 = mean_square[-1] / (2 * time[-1])
    popt, _ = curve_fit(f, time, rms, p0=D0)
    D = popt[0]
    return tau_p, time, rms, D, f(time, D)


def big_and_small(NP=100000, NT=1000):
    """Compare simulated Brownian spreading of 1 um and 0.1 um particles.

    For each diameter, NP trajectories of NT samples (dt = tau_p / 100) are
    integrated with ``loop`` in air-like conditions at 300 K. The fitted
    coefficient D of ``sqrt(2*D*t)`` is printed. The ensemble RMS
    displacement and its fitted curve are both plotted against t/tau_p,
    scaled by the particle diameter.

    Parameters
    ----------
    NP : int, optional
        Particles per size (default 100000). The trajectory array holds
        NP*NT float64 values, about 800 MB at the defaults.
    NT : int, optional
        Time samples per trajectory (default 1000).
    """
    rho_p = 2650  # kg/m^3
    mu_f = 1.7894e-5  # kg/ms
    T_f = 300  # K

    plt.figure()
    # Raw strings: "\m" is an invalid escape sequence in a plain literal.
    for d_p, size_label in ((1e-6, r"$1 \mu m$"), (1e-7, r"$0.1 \mu m$")):
        tau_p, time, rms, D, fit = _simulate_rms_displacement(
            d_p, rho_p, mu_f, T_f, NP, NT)
        print(D)
        plt.plot(time / tau_p, rms / d_p, '-',
                 label=size_label + " Ensemble")
        plt.plot(time / tau_p, fit / d_p, '--',
                 label=size_label + " Analytical")

    plt.xlabel(r"$\frac{t}{\tau_p}$", fontsize=16)
    plt.ylabel(r"$\frac{\sqrt{\overline{X^2}}}{d_p}$", fontsize=16)
    plt.legend()
    plt.tight_layout()
    # plt.savefig("brownian_displacement.eps")
    plt.show()
