# -*- coding: utf-8 -*-
"""
Created on Tue Dec  1 16:31:38 2020

@author: Daniel Powell
"""

import numpy as np
from matplotlib import pyplot as plt
from numba import jit, njit, prange, float64
import numba as nb


@njit(nogil=True)
def taylor_green_velocity(x, y, umag, a):
    """
    Evaluates the 2D Taylor-Green fluid velocity at the point (x, y).

    u_x = umag cos(ax) sin(ay)
    u_y = -umag sin(ax) cos(ay)

    Parameters
    ----------
    x : float
        x-coordinate (m).
    y : float
        y-coordinate (m).
    umag : float
        maximum fluid velocity (m/s).
    a : float
        flow parameter (/m).

    Returns
    -------
    tg_vel : array of float
        two-element array holding the x and y fluid velocity components
        (m/s).

    """
    sin_ax = np.sin(a * x)
    cos_ax = np.cos(a * x)
    sin_ay = np.sin(a * y)
    cos_ay = np.cos(a * y)

    return umag * np.array([cos_ax * sin_ay, -sin_ax * cos_ay])


#@njit(float64[:](float64, float64, float64, float64, float64, float64[:]), nogil=True)
def taylor_green_ee_velocity(x, y, umag, a, tau_p, u):
    """
    Evaluates the 2D Taylor-Green Equilibrium Eulerian (EE) velocity at the
    point (x, y) by adding an inertial correction to the supplied fluid
    velocity.

    v_x = u_x + (tau_p umag^2 a / 2) sin(2ax)
    v_y = u_y + (tau_p umag^2 a / 2) sin(2ay)

    Parameters
    ----------
    x : float
        x-coordinate (m).
    y : float
        y-coordinate (m).
    umag : float
        maximum fluid velocity (m/s).
    a : float
        flow parameter (/m).
    tau_p: float
        particle relaxation time (s).
    u: array of float
        local fluid velocity vector (m/s). Not modified.

    Returns
    -------
    ee_vel : array of float
        x and y EE velocity components (m/s), as a new array.

    """
    # prefactor shared by both components of the correction
    prefactor = 0.5 * tau_p * umag * umag * a
    drift = prefactor * np.array([np.sin(2.0 * a * x), np.sin(2.0 * a * y)])

    return u + drift


@njit(nogil=True)
def taylor_green_mee_velocity(x, y, umag, a, tau_p, u):
    """
    Evaluates the 2D Taylor-Green Modified Equilibrium Eulerian (MEE)
    velocity at the point (x, y) by adding a correction to the supplied
    fluid velocity.

    The fluid velocity field is

    u_x = umag cos(ax) sin(ay)
    u_y = -umag sin(ax) cos(ay)

    Parameters
    ----------
    x : float
        x-coordinate (m).
    y : float
        y-coordinate (m).
    umag : float
        maximum fluid velocity (m/s).
    a : float
        flow parameter (/m).
    tau_p: float
        particle relaxation time (s).
    u: array of float
        local fluid velocity vector (m/s). Left unmodified.

    Returns
    -------
    mee_vel : array of float
        x and y MEE velocity components (m/s), as a new array.

    """
    sin_ax = np.sin(a * x)
    cos_ax = np.cos(a * x)
    sin_ay = np.sin(a * y)
    cos_ay = np.cos(a * y)
    sin_2ax = np.sin(2 * a * x)
    sin_2ay = np.sin(2 * a * y)
    cos_2ax = np.cos(2 * a * x)
    cos_2ay = np.cos(2 * a * y)

    # dimensionless group tau_p umag a multiplying the second-order terms
    inertia = tau_p * umag * a

    coeff = (tau_p * (umag * umag) * a) / (
        2 + (tau_p * tau_p * umag * umag * a * a * (cos_2ax + cos_2ay)))

    correction = np.array([
        sin_2ax + (inertia * ((sin_2ax * sin_ax * sin_ay) +
                              (sin_2ay * cos_ax * cos_ay))),
        sin_2ay - (inertia * ((sin_2ax * cos_ax * cos_ay) +
                              (sin_2ay * sin_ax * sin_ay)))])

    # Build a fresh array: accumulating into `u` itself would overwrite the
    # caller's fluid velocity, making any later (mee_vel - u) identically
    # zero.
    mee_vel = u + coeff * correction

    return mee_vel


@njit(nogil=True)
def taylor_green_ee_strain_rate(x, y, umag, a, tau_p):
    """
    Evaluates the strain rate magnitude of the Equilibrium Eulerian particle
    velocity field in the 2D Taylor-Green flow at the point (x, y).

    The underlying fluid velocity field is

    u_x = umag cos(ax) sin(ay)
    u_y = -umag sin(ax) cos(ay)

    Parameters
    ----------
    x : float
        x-coordinate (m).
    y : float
        y-coordinate (m).
    umag : float
        maximum fluid velocity (m/s).
    a : float
        flow parameter (/m).
    tau_p: float
        particle relaxation time (s).

    Returns
    -------
    s_mag_v : float
        particle velocity field strain rate magnitude (/s).

    """
    sin_ax = np.sin(a * x)
    sin_ay = np.sin(a * y)
    cos_2ax = np.cos(2 * a * x)
    cos_2ay = np.cos(2 * a * y)

    # contribution independent of tau_p
    fluid_part = 4.0 * (umag ** 2) * (a ** 2) * (sin_ax ** 2) * (sin_ay ** 2)
    # contribution quadratic in tau_p
    inertial_part = 2.0 * (tau_p ** 2) * (umag ** 4) * (a ** 4) * (
        (cos_2ax ** 2) + (cos_2ay ** 2))
    # contribution linear in tau_p
    cross_part = 4.0 * tau_p * (umag ** 3) * (a ** 3) * sin_ax * sin_ay * (
        cos_2ay - cos_2ax)

    return np.sqrt(fluid_part + inertial_part + cross_part)


#@njit(nogil=True)
def get_all_velocities_analytical(x, y, umag, a, tau_p):
    """
    Evaluates, at a single point of the 2D Taylor-Green flow, the fluid
    velocity, the analytical Equilibrium Eulerian (EE) particle velocity and
    their difference.

    Parameters
    ----------
    x : float
        x-coordinate (m).
    y : float
        y-coordinate (m).
    umag : float
        maximum fluid velocity (m/s).
    a : float
        flow parameter (/m).
    tau_p : float
        particle relaxation time (s).

    Returns
    -------
    fluid_vel : array of float
        x and y fluid velocity components (m/s).
    ee_vel : array of float
        x and y EE particle velocity components (m/s).
    relative_vel : array of float
        x and y components of ee_vel - fluid_vel (m/s).

    """
    carrier = taylor_green_velocity(x, y, umag, a)
    particle = taylor_green_ee_velocity(x, y, umag, a, tau_p, carrier)

    return carrier, particle, particle - carrier


@njit(nogil=True)
def get_all_velocities_analytical_mee(x, y, umag, a, tau_p):
    """
    Evaluates, at a single point of the 2D Taylor-Green flow, the fluid
    velocity, the analytical Modified Equilibrium Eulerian (MEE) particle
    velocity and their difference.

    Parameters
    ----------
    x : float
        x-coordinate (m).
    y : float
        y-coordinate (m).
    umag : float
        maximum fluid velocity (m/s).
    a : float
        flow parameter (/m).
    tau_p : float
        particle relaxation time (s).

    Returns
    -------
    fluid_vel : array of float
        x and y fluid velocity components (m/s).
    mee_vel : array of float
        x and y MEE particle velocity components (m/s).
    relative_vel : array of float
        x and y components of mee_vel - fluid_vel (m/s).

    """
    fluid_vel = taylor_green_velocity(x, y, umag, a)

    # Hand the helper its own copy so that fluid_vel is guaranteed to remain
    # the unperturbed fluid velocity even if the helper accumulates into the
    # array it receives; otherwise relative_vel would collapse to zero.
    mee_vel = taylor_green_mee_velocity(x, y, umag, a, tau_p,
                                        fluid_vel.copy())

    relative_vel = mee_vel - fluid_vel

    return fluid_vel, mee_vel, relative_vel


def create_uniform_grid(N, l):
    """
    Builds a square uniform mesh spanning [-l, l] in both the x and y
    directions, centred at the origin.

    Parameters
    ----------
    N : int
        number of cells in each coordinate direction.
    l : float
        half-width of the mesh, which spans [-l, l] (m).

    Returns
    -------
    X : array of float
        cell centroid x coordinates, shape (N, N) (m).
    Y : array of float
        cell centroid y coordinates, shape (N, N) (m).
    XF : array of float
        face x coordinates, shape (N + 1, N + 1) (m).
    YF : array of float
        face y coordinates, shape (N + 1, N + 1) (m).
    dx : float
        gridspacing (m).
    Area : float
        face area, dx squared (m^2).
    V : float
        cell volume, dx cubed (m^3).

    """
    # interleaved positions: even entries are faces, odd entries centroids
    nodes = np.linspace(-l, l, num=(2 * N) + 1, dtype=np.float64)
    centroids = nodes[1:-1:2]
    # one face to the west of every cell plus the eastern face of the last
    faces = nodes[::2]

    dx = centroids[1] - centroids[0]  # m

    # 'ij' indexing: first array index follows x, second follows y
    X, Y = np.meshgrid(centroids, centroids, indexing='ij')
    XF, YF = np.meshgrid(faces, faces, indexing='ij')

    return (X, Y, XF, YF, dx, dx ** 2, dx ** 3)


@njit(nogil=True)
def get_compass(i, j, N):
    """
    Finds the neighbour indices of cell (i, j) on the structured,
    fully-periodic N x N mesh.

    west and east are 'i' indices; north and south are 'j' indices. Indices
    wrap around at the edges of the mesh, so the first cell's west/south
    neighbour is the last cell and the last cell's east/north neighbour is
    the first cell.

    Parameters
    ----------
    i : int
        x-index of cell P.
    j : int
        y-index of cell P.
    N : int
        number of cells in each coordinate direction.

    Returns
    -------
    west : int
        x-index of cell W.
    east : int
        x-index of cell E.
    north : int
        y-index of cell N.
    south : int
        y-index of cell S.

    """
    # the first row/column is tested before the last one, and no index is
    # ever computed by subtracting from zero
    south = N - 1 if j == 0 else j - 1
    north = 1 if j == 0 else (0 if j + 1 == N else j + 1)

    west = N - 1 if i == 0 else i - 1
    east = 1 if i == 0 else (0 if i + 1 == N else i + 1)

    return west, east, north, south


#@njit(nogil=True, parallel=True)
def velocities_and_neighbours(X, Y, N, umag, a, tau_p):
    """
    Tabulates, for every cell of the N x N mesh, the fluid velocity, the
    Equilibrium Eulerian (EE) particle velocity, the relative velocity and
    the periodic neighbour indices.

    Parameters
    ----------
    X : array of float
        x meshgrid of cell centroids.
    Y : array of float
        y meshgrid of cell centroids.
    N : int
        number of cells in each coordinate direction.
    umag : float
        maximum fluid velocity (m/s).
    a : float
        flow parameter (/m).
    tau_p : float
        particle relaxation time (s).

    Returns
    -------
    u : array of float
        fluid velocity array, shape (N, N, 2) (m/s).
    v : array of float
        particle velocity array, shape (N, N, 2) (m/s).
    w : array of float
        relative velocity array w = v - u, shape (N, N, 2) (m/s).
    neighbours : array of int
        neighbour indices ordered (west, east, north, south),
        shape (N, N, 4).

    """
    vector_shape = (N, N, 2)
    u = np.zeros(vector_shape, dtype=np.float64)
    v = np.zeros(vector_shape, dtype=np.float64)
    w = np.zeros(vector_shape, dtype=np.float64)
    neighbours = np.zeros((N, N, 4), dtype=np.int64)

    for i in prange(N):
        for j in prange(N):
            # the velocity fields are steady, so they are evaluated once
            fluid, particle, relative = get_all_velocities_analytical(
                X[i, j], Y[i, j], umag, a, tau_p)
            u[i, j, :] = fluid
            v[i, j, :] = particle
            w[i, j, :] = relative
            neighbours[i, j, :] = get_compass(i, j, N)

    return u, v, w, neighbours


@njit(nogil=True, parallel=True)
def velocities_and_neighbours_mee(X, Y, N, umag, a, tau_p):
    """
    Computes and returns the fluid velocities, Modified Equilibrium Eulerian
    (MEE) velocities, relative velocities and the array of neighbour indices
    for the entire domain.

    Parameters
    ----------
    X : array of float
        x meshgrid of cell centroids (m).
    Y : array of float
        y meshgrid of cell centroids (m).
    N : int
        number of cells in each coordinate direction.
    umag : float
        maximum fluid velocity (m/s).
    a : float
        flow parameter (/m).
    tau_p : float
        particle relaxation time (s).

    Returns
    -------
    u : array of float
        fluid velocity array, shape (N, N, 2) (m/s).
    v : array of float
        particle (MEE) velocity array, shape (N, N, 2) (m/s).
    w : array of float
        relative velocity array, shape (N, N, 2) (m/s).
    neighbours : array of int
        neighbour indices ordered (west, east, north, south),
        shape (N, N, 4).

    """
    # NumPy dtypes (rather than Numba type objects) keep the allocations
    # valid both under the JIT and when the function runs as plain Python.
    # fluid velocity
    u = np.zeros((N, N, 2), dtype=np.float64)
    # particle MEE velocity
    v = np.zeros((N, N, 2), dtype=np.float64)
    # relative velocity w = v - u
    w = np.zeros((N, N, 2), dtype=np.float64)
    # neighbour indices (west, east, north, south)
    neighbours = np.zeros((N, N, 4), dtype=np.int64)

    for i in prange(N):
        for j in prange(N):
            # calculate velocities (these remain constant)
            u[i, j, :], v[i, j, :], w[i, j, :] = get_all_velocities_analytical_mee(X[i, j], Y[i, j], umag, a, tau_p)
            neighbours[i, j, :] = get_compass(i, j, N)

    return u, v, w, neighbours


@njit(float64(float64, float64, float64), nogil=True)
def mass_fraction_to_volume_fraction(c, rho_p, rho_f):
    """
    Converts a particulate-phase mass fraction into the corresponding volume
    fraction.

    alpha = rho_f c / (rho_p + (rho_f - rho_p) c)

    Parameters
    ----------
    c : float
        particle mass fraction.
    rho_p : float
        particulate phase density (kg/m^3).
    rho_f : float
        fluid density (kg/m^3).

    Returns
    -------
    alpha : float
        particle volume fraction.

    """
    numerator = rho_f * c
    denominator = rho_p + ((rho_f - rho_p) * c)

    return numerator / denominator


@njit(float64(float64, float64, float64), nogil=True)
def volume_fraction_to_mass_fraction(alpha, rho_p, rho_f):
    """
    Converts a particulate-phase volume fraction into the corresponding mass
    fraction.

    c = rho_p alpha / (rho_p alpha + (1 - alpha) rho_f)

    Parameters
    ----------
    alpha : float
        particle volume fraction.
    rho_p : float
        particulate phase density (kg/m^3).
    rho_f : float
        fluid density (kg/m^3).

    Returns
    -------
    c : float
        particle mass fraction.

    """
    # mass per unit mixture volume carried by each phase
    particle_mass = rho_p * alpha
    fluid_mass = (1 - alpha) * rho_f

    return particle_mass / (particle_mass + fluid_mass)


@njit(float64(float64, float64), nogil=True)
def volume_fraction_to_number_density(alpha, d_p):
    """
    Converts a particulate-phase volume fraction into a particle number
    density for particles of diameter d_p.

    n = 6 alpha / (pi d_p^3)

    Parameters
    ----------
    alpha : float
        particle volume fraction.
    d_p : float
        particle diameter (m).

    Returns
    -------
    n : float
        particle number density (#/m^3).

    """
    numerator = 6 * alpha
    denominator = np.pi * (d_p ** 3)

    return numerator / denominator


@njit(float64(float64, float64), nogil=True)
def number_density_to_volume_fraction(n, d_p):
    """
    Converts a particle number density into the particulate-phase volume
    fraction for particles of diameter d_p.

    alpha = pi d_p^3 n / 6

    Parameters
    ----------
    n : float
        particle number density (#/m^3).
    d_p : float
        particle diameter (m).

    Returns
    -------
    alpha : float
        particle volume fraction.

    """
    six_alpha = np.pi * (d_p ** 3) * n

    return six_alpha / 6


@njit(float64(float64, float64, float64), nogil=True)
def hertzian_max_contact_area(v_rel, d_p, rho_p):
    """
    Evaluates the maximum contact area predicted by Hertzian contact theory
    for a collision between two identical spheres.

    The Young's modulus (7.13e10 Pa) and Poisson's ratio (0.17) of the
    particle material are fixed inside the function.

    Parameters
    ----------
    v_rel : float
        relative impact velocity (m/s).
    d_p : float
        particle diameter (m).
    rho_p : float
        particle density (kg/m^3).

    Returns
    -------
    A_max : float
        maximum contact area (m^2).

    """
    youngs_modulus = 7.13e10  # Pa
    poisson_ratio = 0.17

    radius = d_p / 2  # m
    mass = (np.pi / 6) * rho_p * (d_p ** 3)  # kg

    # effective quantities for a pair of identical spheres
    r_star = radius / 2
    m_star = mass / 2
    E_star = youngs_modulus / (2 * (1 - (poisson_ratio ** 2)))

    stiffness_ratio = (15 * m_star) / (32 * E_star * np.sqrt(r_star))

    return 2 * np.pi * r_star * (stiffness_ratio ** 0.4) * (v_rel ** 0.8)


@njit(float64(float64, float64, float64), nogil=True)
def hertzian_contact_time(v_rel, d_p, rho_p):
    """
    Evaluates the contact time predicted by Hertzian contact theory for a
    collision between two spheres.

    The Young's modulus (7.13e10 Pa) and Poisson's ratio (0.17) of the
    particle material are fixed inside the function.

    Parameters
    ----------
    v_rel : float
        relative impact velocity (m/s).
    d_p : float
        particle diameter (m).
    rho_p : float
        particle density (kg/m^3).

    Returns
    -------
    t_c : float
        contact time (s).

    """
    integral = 1.4716375921623521  # integral (hertzian_integral.py)
    youngs_modulus = 7.13e10  # Pa
    poisson_ratio = 0.17

    geometric = 5 * np.pi * np.sqrt(2) / 4
    compliance = (1 - (poisson_ratio ** 2)) / youngs_modulus
    base = geometric * rho_p * compliance

    return integral * (base ** 0.4) * d_p * (v_rel ** -0.2)


@njit(nogil=True)
def check_coefficients(a_P, a_E, a_W, a_N, a_S, S_p):
    """
    Asserts that the coefficients of the linear system have the expected
    signs: a_P > 0, every neighbour coefficient >= 0 and S_p <= 0 for all
    entries. An AssertionError is raised by the first check that fails.

    Parameters
    ----------
    a_P : array of float
        central coefficient.
    a_E : array of float
        east coefficient.
    a_W : array of float
        west coefficient.
    a_N : array of float
        north coefficient.
    a_S : array of float
        south coefficient.
    S_p : array of float
        implicit source term contribution.

    Returns
    -------
    None.

    """
    # central coefficient: strictly positive
    central_ok = np.all(a_P > 0)
    assert central_ok

    # neighbour coefficients: non-negative
    east_ok = np.all(a_E >= 0)
    assert east_ok
    west_ok = np.all(a_W >= 0)
    assert west_ok
    north_ok = np.all(a_N >= 0)
    assert north_ok
    south_ok = np.all(a_S >= 0)
    assert south_ok

    # implicit source contribution: non-positive
    source_ok = np.all(S_p <= 0)
    assert source_ok


@njit(nogil=True)
def upwind(i, j, flux_func, u, Area, X, Y, dx, umag, a, west, east, south, north,
           a_W, a_E, a_S, a_N, alpha, neighbours, source):
    """
    Writes the first-order upwind neighbour coefficients associated with the
    east and north faces of cell P = (i, j).

    The flow rate on each face is obtained from flux_func. The east face
    sets a_E of cell P and a_W of cell E; the north face sets a_N of cell P
    and a_S of cell N. The coefficient arrays are modified in place.

    Parameters
    ----------
    i : int
        x-index of cell P.
    j : int
        y-index of cell P.
    flux_func : callable
        called as flux_func(P_i, P_j, n_i, n_j, u, alpha, face, Area, X, Y,
        dx, umag, a); returns the flow rate on the requested face.
    u, Area, X, Y, dx, umag, a, alpha
        forwarded unchanged to flux_func.
    west, south : int
        unused by this scheme.
    east : int
        x-index of cell E.
    north : int
        y-index of cell N.
    a_W, a_E, a_S, a_N : array of float
        neighbour coefficient arrays (written to).
    neighbours, source
        unused by this scheme.

    Returns
    -------
    None.

    """
    east_flow = flux_func(i, j, east, j, u, alpha, 'east', Area, X, Y, dx,
                          umag, a)
    north_flow = flux_func(i, j, i, north, u, alpha, 'north', Area, X, Y, dx,
                           umag, a)

    # east face: coefficient of cell P, then of its east neighbour
    a_E[i, j] = max(-east_flow, 0.0)
    a_W[east, j] = max(east_flow, 0.0)

    # north face: coefficient of cell P, then of its north neighbour
    a_N[i, j] = max(-north_flow, 0.0)
    a_S[i, north] = max(north_flow, 0.0)


@njit(nogil=True)
def linear_upwind(i, j, flux_func, u, Area, X, Y, dx, umag, a, west, east,
                  south, north, a_W, a_E, a_S, a_N, alpha, neighbours, source):
    """
    Applies a 2nd order linear upwind scheme on the east and north faces of
    cell P = (i, j), assuming a uniform grid.

    The neighbour coefficients are the first-order upwind ones; the
    higher-order part is accumulated into the source array, using the
    upstream gradient of alpha selected by the sign of the face flux. All
    output arrays are modified in place.

    Parameters
    ----------
    i, j : int
        x- and y-index of cell P.
    flux_func : callable
        called as flux_func(P_i, P_j, n_i, n_j, u, alpha, face, Area, X, Y,
        dx, umag, a); returns the flux on the requested face.
    u, Area, X, Y, dx, umag, a
        forwarded unchanged to flux_func.
    west, east : int
        x-indices of cells W and E.
    south, north : int
        y-indices of cells S and N.
    a_W, a_E, a_S, a_N : array of float
        neighbour coefficient arrays (written to).
    alpha : array of float
        transported variable, read at P and its neighbours.
    neighbours : array of int
        neighbour indices ordered (west, east, north, south) for each cell.
    source : array of float
        source term array (accumulated into).

    Returns
    -------
    None.

    """
    # compute face fluxes
    flux_e = flux_func(i, j, east, j, u, alpha, 'east', Area, X, Y, dx,
                       umag, a)
    flux_n = flux_func(i, j, i, north, u, alpha, 'north', Area, X, Y, dx,
                       umag, a)

    # second neighbours: east of E and north of N (other entries unused)
    _w_of_e, EE, _n_of_e, _s_of_e = neighbours[east, j]
    _w_of_n, _e_of_n, NN, _s_of_n = neighbours[i, north]

    # upwind neighbour coefficients
    a_E[i, j] = max(-flux_e, 0.0)
    a_W[east, j] = max(flux_e, 0.0)
    a_N[i, j] = max(-flux_n, 0.0)
    a_S[i, north] = max(flux_n, 0.0)

    # NOTE(review): the same signed correction is added to both cells that
    # share a face; confirm against the caller's sign convention that this
    # is intended rather than equal-and-opposite contributions.
    if flux_e > 0:
        east_correction = -0.5 * flux_e * (alpha[i, j] - alpha[west, j])
    else:
        east_correction = 0.5 * flux_e * (alpha[east, j] - alpha[EE, j])
    source[i, j] += east_correction
    source[east, j] += east_correction

    if flux_n > 0:
        north_correction = -0.5 * flux_n * (alpha[i, j] - alpha[i, south])
    else:
        north_correction = 0.5 * flux_n * (alpha[i, north] - alpha[i, NN])
    source[i, j] += north_correction
    source[i, north] += north_correction


@njit(nogil=True)
def min_mod(r):
    """
    Min-mod flux limiter: r clipped to the interval [0, 1].

    Parameters
    ----------
    r : float
        ratio of successive gradients.

    Returns
    -------
    psi : float
        limiter value.

    """
    capped = min(r, 1.0)

    return max(0.0, capped)


@njit(nogil=True)
def SUPERBEE(r):
    """
    SUPERBEE flux limiter: max(0, min(2r, 1), min(r, 2)).

    Parameters
    ----------
    r : float
        ratio of successive gradients.

    Returns
    -------
    psi : float
        limiter value.

    """
    steep_branch = min(2.0 * r, 1.0)
    shallow_branch = min(r, 2)

    return max(0.0, steep_branch, shallow_branch)


@njit(nogil=True)
def linear_upwind_limited(i, j, flux_func, u, Area, X, Y, dx, umag, a, west,
                          east, south, north, a_W, a_E, a_S, a_N, alpha,
                          neighbours, source, limiter):
    """
    Applies a flux-limited 2nd order upwind scheme on the east and north
    faces of cell P = (i, j), assuming a uniform grid.

    The neighbour coefficients are the first-order upwind ones; the limited
    higher-order part is accumulated into source[i, j]. This should NOT be
    used as a convective scheme directly; use a scheme with the limiter name
    in the function name instead.

    Parameters
    ----------
    i, j : int
        x- and y-index of cell P.
    flux_func : callable
        called as flux_func(P_i, P_j, n_i, n_j, u, alpha, face, Area, X, Y,
        dx, umag, a); returns the flux on the requested face.
    u, Area, X, Y, dx, umag, a
        forwarded unchanged to flux_func.
    west, east : int
        x-indices of cells W and E.
    south, north : int
        y-indices of cells S and N.
    a_W, a_E, a_S, a_N : array of float
        neighbour coefficient arrays (written to).
    alpha : array of float
        transported variable, read at P and its neighbours.
    neighbours : array of int
        neighbour indices ordered (west, east, north, south) for each cell.
    source : array of float
        source term array (accumulated into at [i, j]).
    limiter : callable
        maps the gradient ratio r to the limiter value psi.

    Returns
    -------
    None.

    """
    # compute face fluxes
    flux_e = flux_func(i, j, east, j, u, alpha, 'east', Area, X, Y, dx,
                       umag, a)
    flux_n = flux_func(i, j, i, north, u, alpha, 'north', Area, X, Y, dx,
                       umag, a)

    # second neighbours: east of E and north of N (other entries unused)
    _w_of_e, EE, _n_of_e, _s_of_e = neighbours[east, j]
    _w_of_n, _e_of_n, NN, _s_of_n = neighbours[i, north]

    # upwind neighbour coefficients
    a_E[i, j] = max(-flux_e, 0.0)
    a_W[east, j] = max(flux_e, 0.0)
    a_N[i, j] = max(-flux_n, 0.0)
    a_S[i, north] = max(flux_n, 0.0)

    # ---- east face ----
    jump_e = alpha[east, j] - alpha[i, j]
    if np.abs(jump_e) < 1e-15:
        # vanishing gradient across the face: skip the slope ratio
        psi_e = 1.0
    elif flux_e > 0:
        psi_e = limiter((alpha[i, j] - alpha[west, j]) / jump_e)
    else:
        psi_e = limiter((alpha[east, j] - alpha[EE, j]) / -jump_e)

    if flux_e > 0:
        source[i, j] += -0.5 * flux_e * psi_e * jump_e
    else:
        source[i, j] += 0.5 * flux_e * psi_e * jump_e

    # ---- north face ----
    jump_n = alpha[i, north] - alpha[i, j]
    if np.abs(jump_n) < 1e-15:
        # vanishing gradient across the face: skip the slope ratio
        psi_n = 1.0
    elif flux_n > 0:
        psi_n = limiter((alpha[i, j] - alpha[i, south]) / jump_n)
    else:
        psi_n = limiter((alpha[i, north] - alpha[i, NN]) / -jump_n)

    if flux_n > 0:
        source[i, j] += -0.5 * flux_n * psi_n * jump_n
    else:
        source[i, j] += 0.5 * flux_n * psi_n * jump_n


@njit(nogil=True)
def linear_upwind_min_mod(i, j, flux_func, u, Area, X, Y, dx, umag, a, west,
                          east, south, north, a_W, a_E, a_S, a_N, alpha,
                          neighbours, source):
    """
    Computes a 2nd order linear upwind approximation to a variable var,
    assuming a uniform grid. The flux is limited using a min-mod scheme to
    provide a 2nd order TVD scheme.

    All arguments are forwarded unchanged to linear_upwind_limited, with
    min_mod supplied as the limiter. The coefficient arrays (a_W, a_E, a_S,
    a_N) and the source array are updated in place by that call.

    Returns
    -------
    None.

    """
    # delegate to the generic limited scheme with the min-mod limiter
    linear_upwind_limited(i, j, flux_func, u, Area, X, Y, dx, umag, a, west,
                              east, south, north, a_W, a_E, a_S, a_N, alpha,
                              neighbours, source, min_mod)


@njit(nogil=True)
def linear_upwind_SUPERBEE(i, j, flux_func, u, Area, X, Y, dx, umag, a, west,
                          east, south, north, a_W, a_E, a_S, a_N, alpha,
                          neighbours, source):
    """
    Computes a 2nd order linear upwind approximation to a variable var,
    assuming a uniform grid. The flux is limited using a SUPERBEE scheme to
    provide a 2nd order TVD scheme.

    All arguments are forwarded unchanged to linear_upwind_limited, with
    SUPERBEE supplied as the limiter. The coefficient arrays (a_W, a_E, a_S,
    a_N) and the source array are updated in place by that call.

    Returns
    -------
    None.

    """
    # delegate to the generic limited scheme with the SUPERBEE limiter
    linear_upwind_limited(i, j, flux_func, u, Area, X, Y, dx, umag, a, west,
                              east, south, north, a_W, a_E, a_S, a_N, alpha,
                              neighbours, source, SUPERBEE)


@njit(nogil=True)
def QUICK(i, j, flux_func, u, Area, X, Y, dx, umag, a, west, east, south, north,
          a_W, a_E, a_S, a_N, alpha, neighbours, source):
    """
    Applies the QUICK scheme on the east and north faces of cell P = (i, j),
    assuming a uniform grid.

    The neighbour coefficients are the first-order upwind ones; the
    remaining part of the QUICK stencil is accumulated into source[i, j].
    All output arrays are modified in place.

    Parameters
    ----------
    i, j : int
        x- and y-index of cell P.
    flux_func : callable
        called as flux_func(P_i, P_j, n_i, n_j, u, alpha, face, Area, X, Y,
        dx, umag, a); returns the flux on the requested face.
    u, Area, X, Y, dx, umag, a
        forwarded unchanged to flux_func.
    west, east : int
        x-indices of cells W and E.
    south, north : int
        y-indices of cells S and N.
    a_W, a_E, a_S, a_N : array of float
        neighbour coefficient arrays (written to).
    alpha : array of float
        transported variable, read at P and its neighbours.
    neighbours : array of int
        neighbour indices ordered (west, east, north, south) for each cell.
    source : array of float
        source term array (accumulated into at [i, j]).

    Returns
    -------
    None.

    """
    # compute face fluxes
    flux_e = flux_func(i, j, east, j, u, alpha, 'east', Area, X, Y, dx,
                       umag, a)
    flux_n = flux_func(i, j, i, north, u, alpha, 'north', Area, X, Y, dx,
                       umag, a)

    # second neighbours: east of E and north of N (other entries unused)
    _w_of_e, EE, _n_of_e, _s_of_e = neighbours[east, j]
    _w_of_n, _e_of_n, NN, _s_of_n = neighbours[i, north]

    # upwind neighbour coefficients
    a_E[i, j] = max(-flux_e, 0.0)
    a_W[east, j] = max(flux_e, 0.0)
    a_N[i, j] = max(-flux_n, 0.0)
    a_S[i, north] = max(flux_n, 0.0)

    # east face: stencil chosen by the sign of the flux
    if flux_e > 0:
        stencil_e = (-alpha[west, j] - (2 * alpha[i, j])
                     + (3 * alpha[east, j]))
    else:
        stencil_e = ((3 * alpha[i, j]) - (2 * alpha[east, j])
                     - alpha[EE, j])
    source[i, j] += 0.125 * flux_e * stencil_e

    # north face: stencil chosen by the sign of the flux
    if flux_n > 0:
        stencil_n = (-alpha[i, south] - (2 * alpha[i, j])
                     + (3 * alpha[i, north]))
    else:
        stencil_n = ((3 * alpha[i, j]) - (2 * alpha[i, north])
                     - alpha[i, NN])
    source[i, j] += 0.125 * flux_n * stencil_n
    

@njit(nogil=True)
def central_difference_uniform(P_i, n_i, P_j, n_j, var):
    """
    Central differencing approximation to a variable on the face between
    the gridpoints P and n of a uniform grid, i.e. the arithmetic mean of
    the two cell values.

    Parameters
    ----------
    P_i : int
        i index of cell P.
    n_i : int
        i index of cell n.
    P_j : int
        j index of cell P.
    n_j : int
        j index of cell n.
    var : array of float
        array holding the value of the variable.

    Returns
    -------
    var_face : float/array of float
        central difference approximation to the variable at the face.

    """
    value_P = var[P_i, P_j]
    value_n = var[n_i, n_j]

    return 0.5 * (value_P + value_n)


@njit(nogil=True)
def volumetric_flow_cd_uniform(P_i, P_j, n_i, n_j, u, alpha, face, Area, X, Y, dx,
                               umag, a):
    """
    Central differencing approximation to the volumetric flow rate on the
    face between the gridpoints P and n, assuming a uniform grid.

    Parameters
    ----------
    P_i : int
        i index of cell P.
    P_j : int
        j index of cell P.
    n_i: int
        i index of neighbour cell.
    n_j: int
        j index of neighbour cell.
    u : array of float
        velocity array; the last axis holds the x and y components.
    alpha : array of float
        unused.
    face : str
        'east' or 'west' selects the x velocity component; any other value
        selects the y velocity component.
    Area : float
        face area (m^2).
    X, Y, dx, umag, a
        unused.

    Returns
    -------
    vdot : float
        volumetric flow rate on face.

    """
    face_velocity = central_difference_uniform(P_i, n_i, P_j, n_j, u)

    if (face == 'east') or (face == 'west'):
        return face_velocity[0] * Area

    return face_velocity[1] * Area


@njit(nogil=True)
def volumetric_flow_exact(P_i, P_j, n_i, n_j, u, alpha, face, Area, X, Y, dx, umag,
                          a):
    """
    Analytical volumetric flow rate on a face of cell P, obtained by
    evaluating the Taylor-Green velocity component normal to the face at a
    point offset by dx / 2 from the centroid of P.

    Parameters
    ----------
    P_i : int
        i index of cell P.
    P_j : int
        j index of cell P.
    n_i: int
        i index of neighbour cell (unused).
    n_j: int
        j index of neighbour cell (unused).
    u : array of float
        unused.
    alpha : array of float
        unused.
    face : str
        'east', 'west' or 'north'; any other value is treated as the south
        face.
    Area : float
        face area (m^2).
    X : array of float
        x meshgrid of cell centroids.
    Y : array of float
        y meshgrid of cell centroids.
    dx : float
        gridspacing (m).
    umag : float
        maximum fluid velocity (m/s).
    a : float
        flow parameter (/m).

    Returns
    -------
    vdot : float
        volumetric flow rate on face.

    """
    x_c = X[P_i, P_j]
    y_c = Y[P_i, P_j]

    if face == 'east' or face == 'west':
        # x-normal face: shift in x, use the x velocity component
        x_f = x_c + (dx / 2) if face == 'east' else x_c - (dx / 2)
        normal_velocity = umag * np.cos(a * x_f) * np.sin(a * y_c)
    else:
        # y-normal face: shift in y, use the y velocity component
        y_f = y_c + (dx / 2) if face == 'north' else y_c - (dx / 2)
        normal_velocity = -umag * np.sin(a * x_c) * np.cos(a * y_f)

    return normal_velocity * Area


@njit(nogil=True)
def charge_flux(P_i, P_j, n_i, n_j, u, alpha, face, Area, X, Y, dx, umag, a):
    """
    Computes and returns the analytical solution to the volumetric flow rate
    on the face multiplied by a central differencing approximation to alpha
    on the face.

    Parameters
    ----------
    P_i : int
        i index of cell P.
    P_j : int
        j index of cell P.
    n_i: int
        i index of neighbour cell.
    n_j: int
        j index of neighbour cell.
    u : array of float
        forwarded to volumetric_flow_exact.
    alpha : array of float
        transported variable, averaged between cells P and n.
    face : str
        string indicating the face: 'east', 'west' or 'north'; any other
        value is treated as the south face.
    Area : float
        face area (m^2).
    X : array of float
        x meshgrid of cell centroids.
    Y : array of float
        y meshgrid of cell centroids.
    dx : float
        gridspacing (m).
    umag : float
        maximum fluid velocity (m/s).
    a : float
        flow parameter (/m).

    Returns
    -------
    flux : float
        volumetric flow rate on the face times the face value of alpha.

    """
    # reuse the shared analytical face flow rate rather than duplicating it
    vdot = volumetric_flow_exact(P_i, P_j, n_i, n_j, u, alpha, face, Area,
                                 X, Y, dx, umag, a)

    alpha_f = central_difference_uniform(P_i, n_i, P_j, n_j, alpha)

    flux = vdot * alpha_f

    return flux


@njit(float64(float64, float64), nogil=True)
def first_order_backward_euler(V, dt):
    """
    Transient contribution to the a_P coefficient from a first order
    backward Euler scheme, V / dt.

    Parameters
    ----------
    V : float
        cell volume (m^3).
    dt : float
        time step (s).

    Returns
    -------
    coeff : float
        cell volume divided by the time step (m^3/s).

    """
    return V / dt


@njit(float64(float64, float64, float64, float64, float64, float64, float64,
              float64), nogil=True)
def diffusion_coefficient_none(x_f, y_f, phi_f, umag, a, tau_p, d_p, dx):
    """
    Diffusion coefficient model for the no-diffusion case: always zero. All
    arguments are ignored.

    Parameters
    ----------
    x_f : float
        face centroid x-coordinate (m).
    y_f : float
        face centroid y-coordinate (m).
    phi_f : float
        face interpolated value of the transported variable.
    umag : float
        maximum Taylor-Green velocity (m/s).
    a : float
        Taylor-Green flow parameter (/m).
    tau_p : float
        particle relaxation time (s).
    d_p : float
        particle diameter (m).
    dx : float
        gridspacing (m).

    Returns
    -------
    D : float
        diffusion coefficient (m^2/s), identically 0.0.

    """
    return 0.0


@njit(float64(float64, float64, float64, float64, float64, float64, float64,
              float64), nogil=True)
def diffusion_coefficient_brownian(x_f, y_f, phi_f, umag, a, tau_p, d_p, dx):
    """
    Computes and returns the Brownian diffusion coefficient
    D = k_B T / (3 pi mu d_p), without a Cunningham slip correction.

    The fluid temperature (300 K) and dynamic viscosity (1.7894e-5 kg/ms)
    are hard-coded; only d_p influences the result.

    Parameters
    ----------
    x_f : float
        face centroid x-coordinate (m, unused).
    y_f : float
        face centroid y-coordinate (m, unused).
    phi_f : float
        face interpolated value of the transported scalar (unused).
    umag : float
        maximum Taylor-Green velocity (m/s, unused).
    a : float
        Taylor-Green flow parameter (/m, unused).
    tau_p : float
        particle relaxation time (s, unused).
    d_p : float
        particle diameter (m).
    dx : float
        grid spacing (m, unused).

    Returns
    -------
    float
        diffusion coefficient (m^2/s).

    """
    boltzmann = 1.380649e-23  # J/K
    temperature = 300.0  # K
    viscosity = 1.7894e-5  # kg/ms

    thermal_energy = boltzmann * temperature
    drag_factor = 3 * np.pi * viscosity * d_p

    return thermal_energy / drag_factor


@njit(float64(float64, float64, float64, float64, float64, float64, float64,
              float64), nogil=True)
def diffusion_coefficient_turb(x_f, y_f, phi_f, umag, a, tau_p, d_p, dx):
    """
    Computes and returns the LES turbulent diffusion coefficient
    D = (C_s dx)^2 |S|, with C_s = 0.1 and |S| taken from
    taylor_green_ee_strain_rate at the face centroid.

    Parameters
    ----------
    x_f : float
        face centroid x-coordinate (m).
    y_f : float
        face centroid y-coordinate (m).
    phi_f : float
        face interpolated value of the transported scalar (unused).
    umag : float
        maximum Taylor-Green velocity (m/s).
    a : float
        Taylor-Green flow parameter (/m).
    tau_p : float
        particle relaxation time (s).
    d_p : float
        particle diameter (m, unused).
    dx : float
        grid spacing (m).

    Returns
    -------
    float
        diffusion coefficient (m^2/s).

    """
    # turbulent Schmidt number = 1, so the eddy viscosity is used directly
    smagorinsky_constant = 0.1
    strain_rate = taylor_green_ee_strain_rate(x_f, y_f, umag, a, tau_p)
    length_scale = dx * smagorinsky_constant

    return length_scale * length_scale * strain_rate


@njit(float64(float64, float64), nogil=True)
def particle_mean_free_path(d_p, alpha):
    """
    Computes and returns the particle mean free path from Gidaspow,
    d_p / (6 sqrt(2) alpha).

    Parameters
    ----------
    d_p : float
        particle diameter (m).
    alpha : float
        volume fraction.

    Returns
    -------
    float
        particle mean free path (m).

    """
    # 1 / (6 sqrt(2)), precalculated to save processing time
    prefactor = 0.1178511301977579

    return prefactor * d_p / alpha


@njit(float64(float64), nogil=True)
def packing_function(alpha):
    """
    Computes and returns the packing function
    chi = 1 / (1 - (alpha / alpha_max)^(1/3)).

    Parameters
    ----------
    alpha : float
        volume fraction.

    Returns
    -------
    float
        packing function value; it grows without bound as alpha approaches
        the close packing limit alpha_max.

    """
    # close packing limit (pi / (3 sqrt(2)))
    close_packing = 0.7404804896930609

    packing_ratio = alpha / close_packing
    cube_root = packing_ratio ** (1.0 / 3.0)

    return 1.0 / (1 - cube_root)


@njit(float64(float64, float64, float64, float64, float64, float64, float64,
              float64), nogil=True)
def diffusion_coefficient_collisions(x_f, y_f, alpha_f, umag, a, tau_p, d_p, dx):
    """
    Computes and returns the collisional diffusion coefficient, i.e. the
    particle mean free path (capped at the grid spacing) multiplied by a
    representative velocity dx * |S|.

    Parameters
    ----------
    x_f : float
        face centroid x-coordinate (m).
    y_f : float
        face centroid y-coordinate (m).
    alpha_f : float
        face interpolated volume fraction.
    umag : float
        maximum Taylor-Green velocity (m/s).
    a : float
        Taylor-Green flow parameter (/m).
    tau_p : float
        particle relaxation time (s).
    d_p : float
        particle diameter (m).
    dx : float
        grid spacing (m).

    Returns
    -------
    float
        diffusion coefficient (m^2/s).

    """
    # dense correction from the packing function
    dense_correction = packing_function(alpha_f)

    # mean free path; the cap at dx stops it growing very large as the
    # volume fraction vanishes
    free_path = 0.11785113019775793 * d_p / (alpha_f * dense_correction)
    capped_path = min(free_path, dx)

    # representative velocity from the EE particle velocity strain rate
    strain_rate = taylor_green_ee_strain_rate(x_f, y_f, umag, a, tau_p)
    velocity_scale = dx * strain_rate

    return capped_path * velocity_scale


@njit(float64(float64, float64, float64, float64, float64, float64, float64,
              float64), nogil=True)
def diffusion_coefficient_collisions_no_correction(x_f, y_f, alpha_f, umag, a,
                                                   tau_p, d_p, dx):
    """
    Computes and returns the collisional diffusion coefficient without
    capping the particle mean free path, so the result grows without limit
    as the volume fraction vanishes.

    Parameters
    ----------
    x_f : float
        face centroid x-coordinate (m).
    y_f : float
        face centroid y-coordinate (m).
    alpha_f : float
        face interpolated volume fraction.
    umag : float
        maximum Taylor-Green velocity (m/s).
    a : float
        Taylor-Green flow parameter (/m).
    tau_p : float
        particle relaxation time (s).
    d_p : float
        particle diameter (m).
    dx : float
        grid spacing (m).

    Returns
    -------
    float
        diffusion coefficient (m^2/s).

    """
    # dense correction from the packing function
    dense_correction = packing_function(alpha_f)

    # uncapped particle mean free path
    free_path = 0.11785113019775793 * d_p / (alpha_f * dense_correction)

    # representative velocity from the EE particle velocity strain rate
    strain_rate = taylor_green_ee_strain_rate(x_f, y_f, umag, a, tau_p)
    velocity_scale = dx * strain_rate

    return free_path * velocity_scale


@njit(float64(float64, float64, float64, float64, float64, float64, float64,
              float64), nogil=True)
def diffusion_coefficient_brownian_and_collisions(x_f, y_f, alpha_f, umag, a,
                                                  tau_p, d_p, dx):
    """
    Computes and returns the combined diffusion coefficient: the result of
    diffusion_coefficient_brownian plus that of
    diffusion_coefficient_collisions, both evaluated with the same arguments.

    Parameters
    ----------
    x_f : float
        face centroid x-coordinate (m).
    y_f : float
        face centroid y-coordinate (m).
    alpha_f : float
        face interpolated volume fraction.
    umag : float
        maximum Taylor-Green velocity (m/s).
    a : float
        Taylor-Green flow parameter (/m).
    tau_p : float
        particle relaxation time (s).
    d_p : float
        particle diameter (m).
    dx : float
        grid spacing (m).

    Returns
    -------
    float
        diffusion coefficient (m^2/s).

    """
    brownian_part = diffusion_coefficient_brownian(
        x_f, y_f, alpha_f, umag, a, tau_p, d_p, dx)
    collision_part = diffusion_coefficient_collisions(
        x_f, y_f, alpha_f, umag, a, tau_p, d_p, dx)

    return brownian_part + collision_part


@jit(nogil=True)
def source_zero(i, j, east, west, north, south, w, phi, Area, source, S_p,
                limit=0.0):
    """
    Source function that adds nothing: neither ``source`` nor ``S_p`` is
    modified. The signature matches the other source_* functions so it can
    be passed wherever a source function is expected.

    Parameters
    ----------
    i : int
        cell centroid i-index.
    j : int
        cell centroid j-index.
    east : int
        east neighbour cell centroid i-index.
    west : int
        west neighbour cell centroid i-index.
    north : int
        north neighbour cell centroid j-index.
    south : int
        south neighbour cell centroid j-index.
    w : array of float
        relative velocity array.
    phi : array of float
        transported scalar array.
    Area : float
        cell face area (m^2).
    source : array of float
        explicit source term array (left untouched).
    S_p : array of float
        implicit source term array (left untouched).
    limit : float, optional
        scalar limit. The default is 0.0.

    Returns
    -------
    None

    """
    return None


@jit(nogil=True)
def source_uniform_grid_explicit(i, j, east, west, north, south, w, alpha, Area,
                                 source, S_p, limit=0.0):
    """
    Adds the fully explicit source term of cell (i, j) to ``source`` on a
    uniform grid. Nothing is returned; ``source`` and ``S_p`` are updated
    in place.

    The contribution is 0.5 * Area times the difference of w * alpha between
    opposite neighbours (west minus east using component 0 of w, south minus
    north using component 1).

    Parameters
    ----------
    i : int
        cell centroid i-index.
    j : int
        cell centroid j-index.
    east : int
        east neighbour cell centroid i-index.
    west : int
        west neighbour cell centroid i-index.
    north : int
        north neighbour cell centroid j-index.
    south : int
        south neighbour cell centroid j-index.
    w : array of float
        relative velocity array, indexed [i, j, component].
    alpha : array of float
        volume fraction array.
    Area : float
        cell face area (m^2).
    source : array of float
        explicit source term array, modified in place.
    S_p : array of float
        implicit source term array; only 0.0 is added to it.
    limit : float, optional
        volume fraction below which the source of the cell is set to zero.
        The default is 0.0.

    Returns
    -------
    None

    """
    # w * alpha at each of the four neighbouring cell centroids
    flux_west = w[west, j, 0] * alpha[west, j]
    flux_east = w[east, j, 0] * alpha[east, j]
    flux_south = w[i, south, 1] * alpha[i, south]
    flux_north = w[i, north, 1] * alpha[i, north]

    net_flux = flux_west - flux_east + flux_south - flux_north
    source[i, j] += (0.5 * Area) * net_flux

    # turn the source off if the volume fraction is below the limit
    if alpha[i, j] < limit:
        source[i, j] = 0.0

    # no implicit contribution
    S_p[i, j] += 0.0


@njit(nogil=True)
def source_uniform_grid_implicit(i, j, east, west, north, south, w, alpha, Area,
                                 source, S_p, limit=0.0):
    """
    Adds the linearised source term of cell (i, j) to ``source`` (explicit
    part) and ``S_p`` (implicit part) on a uniform grid. Nothing is
    returned; both arrays are updated in place.

    The term is split into a gradient part, w . grad(alpha), which is always
    explicit, and a divergence part, alpha div(w). The divergence part goes
    into ``S_p`` when it is negative and into ``source`` (multiplied by the
    current alpha) when it is positive.

    Parameters
    ----------
    i : int
        cell centroid i-index.
    j : int
        cell centroid j-index.
    east : int
        east neighbour cell centroid i-index.
    west : int
        west neighbour cell centroid i-index.
    north : int
        north neighbour cell centroid j-index.
    south : int
        south neighbour cell centroid j-index.
    w : array of float
        relative velocity array, indexed [i, j, component].
    alpha : array of float
        volume fraction array.
    Area : float
        cell face area (m^2).
    source : array of float
        explicit source term array, modified in place.
    S_p : array of float
        implicit source term array, modified in place.
    limit : float, optional
        volume fraction below which both source terms of the cell are set
        to zero. The default is 0.0.

    Returns
    -------
    None

    """
    neg_half_area = -0.5 * Area

    # central differences of alpha, weighted by the local relative velocity
    gradient_term = neg_half_area * (
        (w[i, j, 0] * (alpha[east, j] - alpha[west, j])) +
        (w[i, j, 1] * (alpha[i, north] - alpha[i, south])))

    # central difference divergence of the relative velocity
    divergence_term = neg_half_area * (
        w[east, j, 0] - w[west, j, 0] + w[i, north, 1] - w[i, south, 1])

    source[i, j] += gradient_term + max(divergence_term, 0.0) * alpha[i, j]
    S_p[i, j] += min(divergence_term, 0.0)

    # limits
    # turn source off if volume fraction below limit
    if alpha[i, j] < limit:
        source[i, j] = 0.0
        S_p[i, j] = 0.0


@njit(float64(float64, float64, float64, float64), nogil=True)
def ireland_charge_transfer(A_c, sigma_0, tau, t_c):
    """
    Computes and returns the charge transferred in a single contact,
    A_c * sigma_0 * (1 - exp(-t_c / tau)).

    Parameters
    ----------
    A_c : float
        contact area (m^2).
    sigma_0 : float
        surface charge density scale (C/m^2).
    tau : float
        charging time constant (s).
    t_c : float
        contact time (s).

    Returns
    -------
    float
        transferred charge (C).

    """
    charged_fraction = 1 - np.exp(-t_c / tau)

    return A_c * sigma_0 * charged_fraction


@njit(nogil=True)
def source_charge_implicit(i, j, east, west, north, south, w, alpha, q, Area,
                           source, S_p, x, y, umag, a, tau_p, rho_p, d_p,
                           limit=0.0):
    """
    Adds the linearised source term of the charge equation for cell (i, j)
    to ``source`` (explicit part) and ``S_p`` (implicit part) on a uniform
    grid. Nothing is returned; both arrays are updated in place.

    Three contributions are assembled: a gradient part
    alpha w . grad(q) (explicit), a divergence part q div(alpha w) (implicit
    when negative, explicit when positive) and a collisional charge
    generation term multiplied by the cell volume Area ** 1.5.

    Parameters
    ----------
    i : int
        cell centroid i-index.
    j : int
        cell centroid j-index.
    east : int
        east neighbour cell centroid i-index.
    west : int
        west neighbour cell centroid i-index.
    north : int
        north neighbour cell centroid j-index.
    south : int
        south neighbour cell centroid j-index.
    w : array of float
        relative velocity array, indexed [i, j, component].
    alpha : array of float
        volume fraction array.
    q : array of float
        charge array.
    Area : float
        cell face area (m^2).
    source : array of float
        explicit source term array, modified in place.
    S_p : array of float
        implicit source term array, modified in place.
    x : float
        cell centroid x-coordinate (m).
    y : float
        cell centroid y-coordinate (m).
    umag : float
        maximum Taylor-Green velocity (m/s).
    a : float
        Taylor-Green flow parameter (/m).
    tau_p : float
        particle relaxation time (s).
    rho_p : float
        particle density (kg/m^3).
    d_p : float
        particle diameter (m).
    limit : float, optional
        charge below which the implicit part is zeroed and a negative
        explicit part is clipped to zero. The default is 0.0.

    Returns
    -------
    None

    """
    # tribocharging constants; the time constant is the value the original
    # notes describe as obtained "using plate velocity"
    time_constant = 1.426e-7
    surface_charge = 4.413e-6  # C/m^2

    strain_rate = taylor_green_ee_strain_rate(x, y, umag, a, tau_p)  # /s
    cell_volume = Area ** 1.5  # m^3
    impact_speed = strain_rate * np.sqrt(Area)  # m/s

    # Hertzian contact quantities
    contact_area = hertzian_max_contact_area(impact_speed, d_p, rho_p)  # m^2
    contact_time = hertzian_contact_time(impact_speed, d_p, rho_p)  # s

    # charge transfer from a single collision (C)
    charge_per_collision = ireland_charge_transfer(contact_area, surface_charge,
                                                   time_constant, contact_time)

    # number density at cell centroid (p/m^3)
    number_density = volume_fraction_to_number_density(alpha[i, j], d_p)

    # charge source term (C/kgs)
    generation = 4 * (number_density ** 2) * (d_p ** 3) * strain_rate * \
        charge_per_collision / (3 * rho_p)

    neg_half_area = -0.5 * Area

    # central difference divergence of alpha * w
    divergence_term = neg_half_area * (
        (alpha[east, j] * w[east, j, 0]) -
        (alpha[west, j] * w[west, j, 0]) +
        (alpha[i, north] * w[i, north, 1]) -
        (alpha[i, south] * w[i, south, 1]))

    S_p[i, j] += min(divergence_term, 0.0)

    # central differences of q, weighted by the local relative velocity
    charge_gradient = ((w[i, j, 0] * (q[east, j] - q[west, j])) +
                       (w[i, j, 1] * (q[i, north] - q[i, south])))
    gradient_term = neg_half_area * alpha[i, j] * charge_gradient

    source[i, j] += gradient_term + \
        (max(divergence_term, 0.0) * q[i, j]) + \
        (generation * cell_volume)

    # below the limit only a non-negative explicit source is allowed
    if q[i, j] < limit:
        S_p[i, j] = 0.0
        if source[i, j] < 0.0:
            source[i, j] = 0.0


@njit(nogil=True)
def source_electic_potential(i, j, east, west, north, south, w, alpha, q, Area,
                           source, S_p, x, y, umag, a, tau_p, rho_p, d_p,
                           limit=0.0):
    """
    Adds the source term of the electric potential equation for cell (i, j)
    to ``source``: minus the cell charge density times the cell volume
    (Area ** 1.5), divided by the permittivity of air. Nothing is returned.

    Only i, j, alpha, q, Area, source, S_p and rho_p are used; the remaining
    parameters keep the signature aligned with source_charge_implicit.

    Parameters
    ----------
    i : int
        cell centroid i-index.
    j : int
        cell centroid j-index.
    alpha : array of float
        volume fraction array.
    q : array of float
        charge array.
    Area : float
        cell face area (m^2).
    source : array of float
        explicit source term array, modified in place.
    S_p : array of float
        implicit source term array; only 0.0 is added to it.
    rho_p : float
        particle density (kg/m^3).

    Returns
    -------
    None

    """
    # permittivity of air: relative permittivity times vacuum permittivity
    permittivity = 1.0006 * 8.8541878128e-12

    # charge density of all particles in cell (C/m^3)
    charge_density = rho_p * alpha[i, j] * q[i, j]

    # no implicit contribution
    S_p[i, j] += 0.0
    source[i, j] += -charge_density * (Area ** 1.5) / permittivity



@njit(nogil=True, parallel=True)
def calculate_a_neighbours(a_E, a_W, a_N, a_S, S_p, source, neighbours, N, u,
                           alpha, X, Y, dx, Area, umag, a, tau_p, d_p, w,
                           flux_func, source_func, diff_func, conv_func, V, dt,
                           alpha_old, limit):
    """
    Computes and returns the neighbour coefficients, along with both the
    explicit and implicit parts of the source term for a steady simulation.

    ``source`` and ``S_p`` are zeroed first. Then, for every cell (i, j) of
    the N x N grid, ``source_func`` and ``conv_func`` are called and the
    diffusion contribution dx * D of the cell's east and north faces is added
    to the coefficients on both sides of each face. All arrays are updated in
    place and also returned.

    ``V``, ``dt`` and ``alpha_old`` are not used here; they are accepted so
    the signature matches calculate_a_neighbours_unsteady.

    NOTE(review): the neighbour coefficient arrays are not reset in this
    function; presumably ``conv_func`` assigns them - verify against the
    convection schemes.

    NOTE(review): ``a_W[east, j]`` is updated from row i while the outer loop
    is a ``prange`` over i, so a different thread may be working on row
    ``east`` at the same time. Whether this is safe depends on what
    ``conv_func`` writes - confirm.

    Parameters
    ----------
    a_E, a_W, a_N, a_S : array of float
        east, west, north and south neighbour coefficient arrays.
    S_p : array of float
        implicit source term array.
    source : array of float
        explicit source term array.
    neighbours : array of int
        for each cell, the (west, east, north, south) neighbour indices.
    N : int
        number of cells in each direction.
    u : array of float
        velocity array passed on to ``conv_func``.
    alpha : array of float
        volume fraction array.
    X : array of float
        x meshgrid of cell centroids.
    Y : array of float
        y meshgrid of cell centroids.
    dx : float
        grid spacing (m).
    Area : float
        cell face area (m^2).
    umag : float
        maximum Taylor-Green velocity (m/s).
    a : float
        Taylor-Green flow parameter (/m).
    tau_p : float
        particle relaxation time (s).
    d_p : float
        particle diameter (m).
    w : array of float
        relative velocity array.
    flux_func : function
        face flux function passed on to ``conv_func``.
    source_func : function
        source term function (same signature as source_uniform_grid_implicit).
    diff_func : function
        diffusion coefficient function (same signature as
        diffusion_coefficient_none).
    conv_func : function
        convection scheme function.
    V : float
        cell volume (m^3, unused).
    dt : float
        time step (s, unused).
    alpha_old : array of float
        volume fraction at the previous time step (unused).
    limit : float
        limit passed on to ``source_func``.

    Returns
    -------
    tuple
        (a_E, a_W, a_N, a_S, source, S_p), the same arrays that were passed
        in.

    """
    # zero all source term arrays before updating values
    for i in prange(N):
        for j in prange(N):
            source[i, j] = 0.0
            S_p[i, j] = 0.0

    for i in prange(N):
        for j in prange(N):
            # get neighbour indices
            west, east, north, south = neighbours[i, j]
            
            # SOURCE TERM
            source_func(i, j, east, west, north, south, w, alpha, Area, source,
                        S_p, limit)

            # CONVECTION
            conv_func(i, j, flux_func, u, Area, X, Y, dx, umag, a, west, east,
                      south, north, a_W, a_E, a_S, a_N, alpha, neighbours, source)

            # DIFFUSION
            # get face coordinates (only the east and north faces are visited;
            # the west and south faces are handled by the neighbouring cells)
            x_e = X[i, j] + (0.5 * dx)
            y_e = Y[i, j]
            x_n = X[i, j]
            y_n = Y[i, j] + (0.5 * dx)

            # interpolate volume fraction to faces
            alpha_e = central_difference_uniform(i, east, j, j, alpha)
            alpha_n = central_difference_uniform(i, i, j, north, alpha)

            # calculate diffusion coefficients at faces
            D_col_e = diff_func(x_e, y_e, alpha_e, umag, a, tau_p, d_p, dx)
            D_col_n = diff_func(x_n, y_n, alpha_n, umag, a, tau_p, d_p, dx)
    
            # include diffusion contribution: each face coefficient is added
            # to this cell and to the neighbour sharing the face
            coeff_e = dx * D_col_e
            coeff_n = dx * D_col_n
            a_E[i, j] += coeff_e
            a_W[east, j] += coeff_e
            a_N[i, j] += coeff_n
            a_S[i, north] += coeff_n

    return a_E, a_W, a_N, a_S, source, S_p


@njit(nogil=True, parallel=True)
def calculate_a_neighbours_charge(a_E, a_W, a_N, a_S, S_p, source, neighbours,
                                  N, u, alpha, q, X, Y, dx, Area, umag, a,
                                  tau_p, d_p, w, V, rho_p, limit):
    """
    Computes and returns the neighbour coefficients, along with both the
    explicit and implicit parts of the source term, for the steady charge
    equation.

    ``source`` and ``S_p`` are zeroed first, then for every cell (i, j) of
    the N x N grid source_charge_implicit and the upwind convection scheme
    (with charge_flux) are applied. All arrays are updated in place and also
    returned.

    The zeroing matters because source_charge_implicit accumulates with
    ``+=``: without it every call adds another copy of the source terms on
    top of the values left by the previous call.

    NOTE(review): ``limit`` is accepted but source_charge_implicit is called
    with a hard-coded ``limit=0.0``; confirm whether the argument should be
    forwarded. ``V`` and ``d_p``-independent arguments such as ``V`` are not
    used in this function.

    Parameters
    ----------
    a_E, a_W, a_N, a_S : array of float
        east, west, north and south neighbour coefficient arrays.
    S_p : array of float
        implicit source term array.
    source : array of float
        explicit source term array.
    neighbours : array of int
        for each cell, the (west, east, north, south) neighbour indices.
    N : int
        number of cells in each direction.
    u : array of float
        velocity array passed on to the convection scheme.
    alpha : array of float
        volume fraction array.
    q : array of float
        charge array.
    X : array of float
        x meshgrid of cell centroids.
    Y : array of float
        y meshgrid of cell centroids.
    dx : float
        grid spacing (m).
    Area : float
        cell face area (m^2).
    umag : float
        maximum Taylor-Green velocity (m/s).
    a : float
        Taylor-Green flow parameter (/m).
    tau_p : float
        particle relaxation time (s).
    d_p : float
        particle diameter (m).
    w : array of float
        relative velocity array.
    V : float
        cell volume (m^3, unused).
    rho_p : float
        particle density (kg/m^3).
    limit : float
        currently unused (see note above).

    Returns
    -------
    tuple
        (a_E, a_W, a_N, a_S, source, S_p), the same arrays that were passed
        in.

    """
    # zero all source term arrays before updating values, as
    # calculate_a_neighbours does; the source function only ever adds to them
    for i in prange(N):
        for j in prange(N):
            source[i, j] = 0.0
            S_p[i, j] = 0.0

    for i in prange(N):
        for j in prange(N):
            # get neighbour indices
            west, east, north, south = neighbours[i, j]
            x = X[i, j]
            y = Y[i, j]

            # SOURCE TERM
            source_charge_implicit(i, j, east, west, north, south, w, alpha, q,
                                   Area, source, S_p, x, y, umag, a, tau_p,
                                   rho_p, d_p, limit=0.0)

            # CONVECTION
            upwind(i, j, charge_flux, u, Area, X, Y, dx, umag, a, west, east,
                   south, north, a_W, a_E, a_S, a_N, alpha, neighbours, source)

    return a_E, a_W, a_N, a_S, source, S_p


@njit(nogil=True, parallel=True)
def calculate_a_neighbours_electric(a_E, a_W, a_N, a_S, S_p, source, neighbours,
                                  N, u, alpha, q, X, Y, dx, Area, umag, a,
                                  tau_p, d_p, w, V, rho_p, limit):
    """
    Computes and returns the neighbour coefficients and source term of the
    electric potential equation.

    For every cell (i, j) of the N x N grid source_electic_potential is
    applied and the four coefficients touching the cell's east and north
    faces are set to dx (a unit diffusion coefficient). All arrays are
    updated in place and also returned.

    NOTE(review): ``source`` and ``S_p`` are not zeroed here while the source
    function accumulates with ``+=``; presumably the caller resets them -
    verify. ``limit`` is accepted but the source function is called with a
    hard-coded ``limit=0.0``.

    Parameters
    ----------
    a_E, a_W, a_N, a_S : array of float
        east, west, north and south neighbour coefficient arrays.
    S_p : array of float
        implicit source term array.
    source : array of float
        explicit source term array.
    neighbours : array of int
        for each cell, the (west, east, north, south) neighbour indices.
    N : int
        number of cells in each direction.
    alpha : array of float
        volume fraction array.
    q : array of float
        charge array.
    X : array of float
        x meshgrid of cell centroids.
    Y : array of float
        y meshgrid of cell centroids.
    dx : float
        grid spacing (m).
    Area : float
        cell face area (m^2).
    rho_p : float
        particle density (kg/m^3).
    u, umag, a, tau_p, d_p, w, V, limit
        not used by this function itself; u, V and limit are ignored, the
        others are only passed on to source_electic_potential.

    Returns
    -------
    tuple
        (a_E, a_W, a_N, a_S, source, S_p), the same arrays that were passed
        in.

    """
    # DIFFUSION
    # unit diffusion coefficient on every face, so the face coefficient is
    # the same for all cells
    unit_diffusivity = 1.0
    face_coeff = dx * unit_diffusivity

    for i in prange(N):
        for j in prange(N):
            # get neighbour indices
            west, east, north, south = neighbours[i, j]

            # SOURCE TERM
            source_electic_potential(i, j, east, west, north, south, w, alpha,
                                     q, Area, source, S_p, X[i, j], Y[i, j],
                                     umag, a, tau_p, rho_p, d_p, limit=0.0)

            # coefficients on both sides of the east and north faces
            a_E[i, j] = face_coeff
            a_W[east, j] = face_coeff
            a_N[i, j] = face_coeff
            a_S[i, north] = face_coeff

    return a_E, a_W, a_N, a_S, source, S_p


@njit(nogil=True, parallel=True)
def calculate_a_neighbours_unsteady(a_E, a_W, a_N, a_S, S_p, source,
                                    neighbours, N, u, alpha, X, Y, dx, Area,
                                    umag, a, tau_p, d_p, w,
                                    flux_func, source_func, diff_func,
                                    conv_func, V, dt, alpha_old, limit=0.0):
    """
    Computes and returns the neighbour coefficients, along with both the
    explicit and implicit parts of the source term for an unsteady
    simulation.

    The steady assembly is delegated to calculate_a_neighbours with the same
    arguments; afterwards the backward Euler term (V / dt) * alpha_old is
    added to the explicit source of every cell.

    Returns
    -------
    tuple
        (a_E, a_W, a_N, a_S, source, S_p) as returned by
        calculate_a_neighbours, with the transient term included in source.

    """
    # steady part of the coefficients and source terms
    a_E, a_W, a_N, a_S, source, S_p = calculate_a_neighbours(
        a_E, a_W, a_N, a_S, S_p, source, neighbours, N, u, alpha, X, Y, dx,
        Area, umag, a, tau_p, d_p, w, flux_func, source_func, diff_func,
        conv_func, V, dt, alpha_old, limit)

    # the transient coefficient is the same for every cell
    transient_coeff = first_order_backward_euler(V, dt)

    for i in prange(N):
        for j in prange(N):
            source[i, j] += transient_coeff * alpha_old[i, j]

    return a_E, a_W, a_N, a_S, source, S_p


@njit(nogil=True, parallel=True)
def calculate_a_P(a_P, a_E, a_W, a_N, a_S, S_p, N, V, dt):
    """
    Computes and returns the central coefficient for a steady simulation:
    the sum of the four neighbour coefficients minus the implicit source
    term. ``a_P`` is filled in place; ``V`` and ``dt`` are not used and only
    keep the signature aligned with calculate_a_P_unsteady.
    """
    for i in prange(N):
        for j in prange(N):
            neighbour_total = a_E[i, j] + a_W[i, j] + a_N[i, j] + a_S[i, j]
            a_P[i, j] = neighbour_total - S_p[i, j]

    return a_P


@njit(nogil=True, parallel=True)
def calculate_a_P_unsteady(a_P, a_E, a_W, a_N, a_S, S_p, N, V, dt):
    """
    Computes and returns the central coefficient for an unsteady simulation:
    the sum of the four neighbour coefficients minus the implicit source
    term plus the backward Euler coefficient V / dt. ``a_P`` is filled in
    place.
    """
    # the transient coefficient is the same for every cell
    transient_coeff = first_order_backward_euler(V, dt)

    for i in prange(N):
        for j in prange(N):
            neighbour_total = a_E[i, j] + a_W[i, j] + a_N[i, j] + a_S[i, j]
            a_P[i, j] = neighbour_total - S_p[i, j] + transient_coeff

    return a_P


@njit(nogil=True, cache=True)
def Gauss_Seidel_point_by_point(alpha, a_P, a_E, a_W, a_N, a_S, source, N, neighbours):
    """
    Performs one sweep of the standard point by point Gauss-Seidel method on
    the system

        a_P alpha_P = a_E alpha_E + a_W alpha_W + a_N alpha_N + a_S alpha_S
                      + source

    Cells are visited in row-major order and ``alpha`` is overwritten in
    place, so already updated neighbours are used immediately. The updated
    array is also returned.
    """
    for row in range(N):
        for col in range(N):
            west, east, north, south = neighbours[row, col]
            neighbour_sum = ((a_E[row, col] * alpha[east, col]) +
                             (a_W[row, col] * alpha[west, col]) +
                             (a_N[row, col] * alpha[row, north]) +
                             (a_S[row, col] * alpha[row, south]))
            alpha[row, col] = (neighbour_sum + source[row, col]) / a_P[row, col]

    return alpha


@njit(nogil=True, cache=True)
def Gauss_Seidel_point_by_point_relax(alpha, a_P, a_E, a_W, a_N, a_S, source, N, neighbours, relax=1.0):
    """
    Performs one sweep of the point by point Gauss-Seidel method with an
    optional relaxation factor on the system

        a_P alpha_P = a_E alpha_E + a_W alpha_W + a_N alpha_N + a_S alpha_S
                      + source

    Each cell is set to relax * (Gauss-Seidel value) + (1 - relax) * (old
    value). ``alpha`` is overwritten in place and also returned.
    """
    for row in range(N):
        for col in range(N):
            west, east, north, south = neighbours[row, col]
            neighbour_sum = ((a_E[row, col] * alpha[east, col]) +
                             (a_W[row, col] * alpha[west, col]) +
                             (a_N[row, col] * alpha[row, north]) +
                             (a_S[row, col] * alpha[row, south]))
            unrelaxed = (neighbour_sum + source[row, col]) / a_P[row, col]

            # blend the new value with the previous one
            alpha[row, col] = (relax * unrelaxed) + ((1 - relax) * alpha[row, col])

    return alpha


@njit(nogil=True, parallel=True, cache=True)
def Gauss_Seidel_parallel(alpha, a_P, a_E, a_W, a_N, a_S, source, N, neighbours, num_part=2):
    """
    Performs one red-black (checkerboard) Gauss-Seidel sweep on the system

        a_P alpha_P = a_E alpha_E + a_W alpha_W + a_N alpha_N + a_S alpha_S
                      + source

    Cells are coloured by the parity of i + j. The two colours are swept one
    after the other; within a colour the rows are processed in parallel.
    With a five point stencil every neighbour of a cell has the opposite
    colour, so no cell is read while another thread is writing it.

    NOTE(review): if ``neighbours`` wraps around periodically and N is odd,
    the wrap-around neighbours share a colour; presumably N is even for
    periodic grids - confirm.

    Parameters
    ----------
    alpha : array of float
        solution array, overwritten in place.
    a_P, a_E, a_W, a_N, a_S : array of float
        central and neighbour coefficient arrays.
    source : array of float
        explicit source term array.
    N : int
        number of cells in each direction.
    neighbours : array of int
        for each cell, the (west, east, north, south) neighbour indices.
    num_part : int, optional
        unused; kept so existing callers keep working. The default is 2.

    Returns
    -------
    alpha : array of float
        the updated solution array.

    """
    # the colours must run sequentially: the second colour needs the values
    # just computed for the first
    for colour in range(2):
        for i in prange(N):
            # checkerboard: the first column of this colour alternates
            # from row to row
            for j in range((i + colour) % 2, N, 2):
                west, east, north, south = neighbours[i, j]
                alpha[i, j] = (((a_E[i, j] * alpha[east, j]) + (a_W[i, j] * alpha[west, j]) +
                                (a_N[i, j] * alpha[i, north]) + (a_S[i, j] * alpha[i, south])) +
                               source[i, j]) / a_P[i, j]

    return alpha
            


@njit(nogil=True)
def Gauss_Seidel_relax_quick(alpha, a_P, a_E, a_EE, a_W, a_WW, a_N, a_NN, a_S, a_SS, source, N, neighbours, relax=1.0):
    """
    Performs one relaxed point by point Gauss-Seidel sweep on a system whose
    stencil also includes the second neighbours in each direction:

        a_P alpha_P = a_E alpha_E + a_EE alpha_EE + a_W alpha_W
                      + a_WW alpha_WW + a_N alpha_N + a_NN alpha_NN
                      + a_S alpha_S + a_SS alpha_SS + source

    The second neighbours are found by looking up the neighbours of the
    first neighbours. ``alpha`` is overwritten in place and also returned.
    """
    for row in range(N):
        for col in range(N):
            west, east, north, south = neighbours[row, col]

            # second neighbours; the other three entries of each lookup are
            # not needed
            _w1, east_east, _n1, _s1 = neighbours[east, col]
            west_west, _e2, _n2, _s2 = neighbours[west, col]
            _w3, _e3, north_north, _s3 = neighbours[row, north]
            _w4, _e4, _n4, south_south = neighbours[row, south]

            neighbour_sum = ((a_E[row, col] * alpha[east, col]) +
                             (a_EE[row, col] * alpha[east_east, col]) +
                             (a_W[row, col] * alpha[west, col]) +
                             (a_WW[row, col] * alpha[west_west, col]) +
                             (a_N[row, col] * alpha[row, north]) +
                             (a_NN[row, col] * alpha[row, north_north]) +
                             (a_S[row, col] * alpha[row, south]) +
                             (a_SS[row, col] * alpha[row, south_south]))
            unrelaxed = (neighbour_sum + source[row, col]) / a_P[row, col]

            # blend the new value with the previous one
            alpha[row, col] = (relax * unrelaxed) + ((1 - relax) * alpha[row, col])

    return alpha


@njit(nogil=True, parallel=True)
def calculate_scaled_residual(alpha, a_P, a_E, a_W, a_N, a_S, source, N, neighbours):
    """
    Calculate the normalised (scaled) residual of the discretised equation:
    the sum over all cells of
    |a_E alpha_E + a_W alpha_W + a_N alpha_N + a_S alpha_S + source
    - a_P alpha_P| divided by the sum over all cells of |a_P alpha_P|.
    """
    imbalance_total = 0.0
    scale_total = 0.0
    for i in prange(N):
        for j in prange(N):
            west, east, north, south = neighbours[i, j]
            central = a_P[i, j] * alpha[i, j]
            neighbour_sum = ((a_E[i, j] * alpha[east, j]) +
                             (a_W[i, j] * alpha[west, j]) +
                             (a_N[i, j] * alpha[i, north]) +
                             (a_S[i, j] * alpha[i, south]))

            imbalance_total += np.abs(neighbour_sum + source[i, j] - central)
            scale_total += np.abs(central)

    return imbalance_total / scale_total


@njit(nogil=True, parallel=True)
def calculate_scaled_residual_quick(alpha, a_P, a_E, a_EE, a_W, a_WW, a_N, a_NN, a_S, a_SS, source, N, neighbours):
    """
    Calculate the normalised (scaled) residual of the discretised equation
    whose stencil includes the second neighbours in each direction: the sum
    over all cells of |sum(a_nb alpha_nb) + source - a_P alpha_P| divided by
    the sum over all cells of |a_P alpha_P|.
    """
    imbalance_total = 0.0
    scale_total = 0.0
    for i in prange(N):
        for j in prange(N):
            west, east, north, south = neighbours[i, j]

            # second neighbours; the other three entries of each lookup are
            # not needed
            _w1, east_east, _n1, _s1 = neighbours[east, j]
            west_west, _e2, _n2, _s2 = neighbours[west, j]
            _w3, _e3, north_north, _s3 = neighbours[i, north]
            _w4, _e4, _n4, south_south = neighbours[i, south]

            central = a_P[i, j] * alpha[i, j]
            neighbour_sum = ((a_E[i, j] * alpha[east, j]) +
                             (a_EE[i, j] * alpha[east_east, j]) +
                             (a_W[i, j] * alpha[west, j]) +
                             (a_WW[i, j] * alpha[west_west, j]) +
                             (a_N[i, j] * alpha[i, north]) +
                             (a_NN[i, j] * alpha[i, north_north]) +
                             (a_S[i, j] * alpha[i, south]) +
                             (a_SS[i, j] * alpha[i, south_south]))

            imbalance_total += np.abs(neighbour_sum + source[i, j] - central)
            scale_total += np.abs(central)

    return imbalance_total / scale_total


@njit(nogil=True, parallel=True)
def it_resi(N, umag, a, X, Y, dx, Area, V, alpha, u, v, w, neighbours, a_P, a_E,
            a_W, a_N, a_S, source, S_p, rho_f, rho_p, tau_p, d_p, Nit,
            relax=1.0, limit=0.0, flux_func=volumetric_flow_exact,
            source_func=source_uniform_grid_implicit,
            diff_func=diffusion_coefficient_none, conv_func=upwind,
            a_nb_func=calculate_a_neighbours, a_P_func=calculate_a_P,
            res_criterion=1e-15):
    """
    Iterates the steady volume fraction equation until the scaled residual
    falls below ``res_criterion`` or ``Nit`` iterations have been performed.

    Each iteration performs one relaxed Gauss-Seidel sweep, recomputes the
    neighbour coefficients and source terms with calculate_a_neighbours,
    recomputes a_P with ``a_P_func``, stores the scaled residual and the
    statistics returned by calculate_alpha_quantities and r_moments, and
    finally clips negative volume fractions to zero. The statistics of an
    iteration are taken before the clipping.

    NOTE(review): ``a_nb_func`` is accepted but calculate_a_neighbours is
    always called; ``v``, ``rho_f`` and ``rho_p`` are not used.

    Parameters
    ----------
    N : int
        number of cells in each direction.
    alpha : array of float
        volume fraction array, updated in place.
    a_P, a_E, a_W, a_N, a_S, source, S_p : array of float
        coefficient and source term arrays used for the first sweep and
        recomputed every iteration.
    Nit : int
        maximum number of iterations.
    relax : float, optional
        relaxation factor of the Gauss-Seidel sweep. The default is 1.0.
    limit : float, optional
        limit passed on to the source function. The default is 0.0.
    res_criterion : float, optional
        convergence threshold on the scaled residual. The default is 1e-15.

    The remaining parameters are passed on unchanged to
    calculate_a_neighbours and ``a_P_func``.

    Returns
    -------
    tuple
        (alpha, residual, alpha_max, alpha_min, alpha_tot, mean_alpha,
        moment_2, moment_3, moment_4, a_P, a_E, a_W, a_N, a_S, source, S_p).
        The per-iteration history arrays are trimmed to the number of
        iterations performed when convergence is reached early.

    """
    # steady solve: no time step and no previous solution
    dt = None
    alpha_old = None
    residual = np.zeros(Nit, dtype=nb.float64)
    alpha_max = np.zeros(Nit, dtype=nb.float64)
    alpha_min = np.zeros(Nit, dtype=nb.float64)
    alpha_tot = np.zeros(Nit, dtype=nb.float64)
    r = np.sqrt((X ** 2) + (Y ** 2))
    mean_alpha = np.zeros(Nit, dtype=nb.float64)
    moment_2 = np.zeros(Nit, dtype=nb.float64)
    moment_3 = np.zeros(Nit, dtype=nb.float64)
    moment_4 = np.zeros(Nit, dtype=nb.float64)

    for i in range(Nit):
        alpha = Gauss_Seidel_point_by_point_relax(alpha, a_P, a_E, a_W, a_N, a_S, source, N, neighbours, relax)

        # recalculate coefficients
        a_E, a_W, a_N, a_S, source, S_p = calculate_a_neighbours(a_E, a_W, a_N, a_S, S_p, source, neighbours, N, u, alpha, X, Y,
                                                                 dx, Area, umag, a, tau_p, d_p, w, flux_func, source_func,
                                                                 diff_func, conv_func, V, dt, alpha_old, limit)
        a_P = a_P_func(a_P, a_E, a_W, a_N, a_S, S_p, N, V, dt)

        residual[i] = calculate_scaled_residual(alpha, a_P, a_E, a_W, a_N, a_S, source, N, neighbours)

        # calculate stats
        alpha_max[i], alpha_min[i], alpha_tot[i] = calculate_alpha_quantities(alpha)
        mean_alpha[i], moment_2[i], moment_3[i], moment_4[i] = r_moments(r, alpha, alpha_tot[i])

        if alpha_min[i] < 0:
            # clip negative volume fractions and carry on iterating; the
            # clipping loop uses its own indices so that the iteration
            # counter i read by the convergence check below is not
            # overwritten
            for clip_i in prange(N):
                for clip_j in prange(N):
                    if alpha[clip_i, clip_j] < 0.0:
                        alpha[clip_i, clip_j] = 0.0

        # if converged
        if residual[i] < res_criterion:
            # check if we did less iterations than the maximum
            if i < (Nit - 1):
                # get rid off padding zeroes on the arrays
                residual = residual[:i + 1]
                alpha_max = alpha_max[:i + 1]
                alpha_min = alpha_min[:i + 1]
                alpha_tot = alpha_tot[:i + 1]
                mean_alpha = mean_alpha[:i + 1]
                moment_2 = moment_2[:i + 1]
                moment_3 = moment_3[:i + 1]
                moment_4 = moment_4[:i + 1]

            break

    if residual[i] > res_criterion:
        print("Convergence criterion not reached after " + str(i + 1) + " iterations.")

    return (alpha, residual, alpha_max, alpha_min, alpha_tot, mean_alpha,
            moment_2, moment_3, moment_4, a_P, a_E, a_W, a_N, a_S, source, S_p)


@njit(nogil=True, parallel=True)
def it_charge_resi(N, umag, a, X, Y, dx, Area, V, alpha, q, u, v, w,
                   neighbours, a_P, a_E, a_W, a_N, a_S, source, S_p, rho_f,
                   rho_p, tau_p, d_p, Nit, relax=1.0, limit=0.0,
                   res_criterion=1e-15):
    """
    Iterates the steady charge equation until the scaled residual falls
    below ``res_criterion`` or ``Nit`` iterations have been performed.

    Each iteration performs one relaxed Gauss-Seidel sweep on ``q``,
    recomputes the neighbour coefficients and source terms with
    calculate_a_neighbours_charge, recomputes a_P with calculate_a_P and
    stores the scaled residual together with the statistics returned by
    calculate_alpha_quantities and r_moments.

    Returns
    -------
    tuple
        (q, residual, q_max, q_min, q_tot, mean_q, moment_2, moment_3,
        moment_4, a_P, a_E, a_W, a_N, a_S, source, S_p). The per-iteration
        history arrays are trimmed to the number of iterations performed
        when convergence is reached early.

    """
    # steady solve: no time step
    dt = None

    # distance of every cell centroid from the origin, used for the moments
    r = np.sqrt((X ** 2) + (Y ** 2))

    # per-iteration history
    residual = np.zeros(Nit, dtype=nb.float64)
    q_max = np.zeros(Nit, dtype=nb.float64)
    q_min = np.zeros(Nit, dtype=nb.float64)
    q_tot = np.zeros(Nit, dtype=nb.float64)
    mean_q = np.zeros(Nit, dtype=nb.float64)
    moment_2 = np.zeros(Nit, dtype=nb.float64)
    moment_3 = np.zeros(Nit, dtype=nb.float64)
    moment_4 = np.zeros(Nit, dtype=nb.float64)

    for it in range(Nit):
        q = Gauss_Seidel_point_by_point_relax(q, a_P, a_E, a_W, a_N, a_S,
                                              source, N, neighbours, relax)

        # recalculate coefficients
        a_E, a_W, a_N, a_S, source, S_p = calculate_a_neighbours_charge(
            a_E, a_W, a_N, a_S, S_p, source, neighbours, N, u, alpha, q, X, Y,
            dx, Area, umag, a, tau_p, d_p, w, V, rho_p, limit)
        a_P = calculate_a_P(a_P, a_E, a_W, a_N, a_S, S_p, N, V, dt)

        residual[it] = calculate_scaled_residual(q, a_P, a_E, a_W, a_N, a_S,
                                                 source, N, neighbours)

        # calculate stats
        q_max[it], q_min[it], q_tot[it] = calculate_alpha_quantities(q)
        mean_q[it], moment_2[it], moment_3[it], moment_4[it] = r_moments(
            r, q, q_tot[it])

        # not converged yet: next iteration
        if not (residual[it] < res_criterion):
            continue

        # converged; if this happened before the last allowed iteration,
        # get rid of the padding zeroes on the history arrays
        if it < (Nit - 1):
            n_done = it + 1
            residual = residual[:n_done]
            q_max = q_max[:n_done]
            q_min = q_min[:n_done]
            q_tot = q_tot[:n_done]
            mean_q = mean_q[:n_done]
            moment_2 = moment_2[:n_done]
            moment_3 = moment_3[:n_done]
            moment_4 = moment_4[:n_done]

        break

    if residual[it] > res_criterion:
        print("Convergence criterion not reached after " + str(it + 1) + " iterations.")

    return (q, residual, q_max, q_min, q_tot, mean_q, moment_2, moment_3,
            moment_4, a_P, a_E, a_W, a_N, a_S, source, S_p)


@njit(nogil=True, parallel=True)
def it_charge_resume(N, umag, a, X, Y, dx, Area, V, alpha, q, u, v, w,
                   neighbours, a_P, a_E, a_W, a_N, a_S, source, S_p, rho_f,
                   rho_p, tau_p, d_p, Nit, residual0, q_max0, q_min0, q_tot0,
                   mean_q0, moment_20, moment_30, moment_40, relax=1.0,
                   limit=0.0, res_criterion=1e-15):
    """
    Continue a charge iteration, appending to histories from an earlier run.

    Up to Nit further relaxation sweeps are applied to q. After every sweep
    the neighbour coefficients, source terms and a_P are rebuilt and the
    scaled residual and summary statistics of q are stored. Iteration stops
    early as soon as the residual drops below res_criterion, in which case
    the unused tail of every history array is cut off.

    Parameters
    ----------
    N, umag, a, X, Y, dx, Area, V, alpha, u, w, neighbours, rho_p, tau_p, d_p :
        forwarded unchanged to the relaxation / coefficient helpers.
    q : array of float
        field that is relaxed.
    v, rho_f :
        accepted but not used by this function.
    a_P, a_E, a_W, a_N, a_S, source, S_p : array of float
        coefficient and source arrays; rebuilt after every sweep.
    Nit : int
        maximum number of additional sweeps.
    residual0, q_max0, q_min0, q_tot0, mean_q0, moment_20, moment_30, moment_40 : array of float
        histories from the earlier run; they are copied to the front of the
        new histories. All are sliced using the length of residual0.
    relax : float
        passed to Gauss_Seidel_point_by_point_relax.
    limit : float
        passed to calculate_a_neighbours_charge.
    res_criterion : float
        residual threshold below which iteration stops.

    Returns
    -------
    tuple
        (q, residual, q_max, q_min, q_tot, mean_q, moment_2, moment_3,
        moment_4, a_P, a_E, a_W, a_N, a_S, source, S_p).

    """
    dt = None  # no time step is supplied to calculate_a_P
    n_prev = len(residual0)
    n_total = Nit + n_prev

    # radial distance of every grid point from the origin
    r = np.sqrt((X ** 2) + (Y ** 2))

    # history arrays sized for the old entries plus the maximum new ones
    residual = np.zeros(n_total, dtype=nb.float64)
    q_max = np.zeros(n_total, dtype=nb.float64)
    q_min = np.zeros(n_total, dtype=nb.float64)
    q_tot = np.zeros(n_total, dtype=nb.float64)
    mean_q = np.zeros(n_total, dtype=nb.float64)
    moment_2 = np.zeros(n_total, dtype=nb.float64)
    moment_3 = np.zeros(n_total, dtype=nb.float64)
    moment_4 = np.zeros(n_total, dtype=nb.float64)

    # copy the earlier run into the front of each history
    residual[:n_prev] = residual0
    q_max[:n_prev] = q_max0
    q_min[:n_prev] = q_min0
    q_tot[:n_prev] = q_tot0
    mean_q[:n_prev] = mean_q0
    moment_2[:n_prev] = moment_20
    moment_3[:n_prev] = moment_30
    moment_4[:n_prev] = moment_40

    # j is the position in the combined (old + new) history
    for j in range(n_prev, n_total):
        q = Gauss_Seidel_point_by_point_relax(q, a_P, a_E, a_W, a_N, a_S,
                                              source, N, neighbours, relax)

        # coefficients depend on the updated q, so rebuild them
        (a_E, a_W, a_N, a_S, source,
         S_p) = calculate_a_neighbours_charge(a_E, a_W, a_N, a_S, S_p, source,
                                              neighbours, N, u, alpha, q, X, Y,
                                              dx, Area, umag, a, tau_p, d_p, w,
                                              V, rho_p, limit)
        a_P = calculate_a_P(a_P, a_E, a_W, a_N, a_S, S_p, N, V, dt)

        residual[j] = calculate_scaled_residual(q, a_P, a_E, a_W, a_N, a_S,
                                                source, N, neighbours)

        # statistics of the current field
        q_max[j], q_min[j], q_tot[j] = calculate_alpha_quantities(q)
        (mean_q[j], moment_2[j], moment_3[j],
         moment_4[j]) = r_moments(r, q, q_tot[j])

        if residual[j] < res_criterion:
            # converged: drop the zero padding if we stopped early
            n_used = j + 1
            if n_used < n_total:
                residual = residual[:n_used]
                q_max = q_max[:n_used]
                q_min = q_min[:n_used]
                q_tot = q_tot[:n_used]
                mean_q = mean_q[:n_used]
                moment_2 = moment_2[:n_used]
                moment_3 = moment_3[:n_used]
                moment_4 = moment_4[:n_used]

            break

    if residual[j] > res_criterion:
        print("Convergence criterion not reached after " + str(j - n_prev + 1) + " additional iterations.")

    return (q, residual, q_max, q_min, q_tot, mean_q, moment_2, moment_3,
            moment_4, a_P, a_E, a_W, a_N, a_S, source, S_p)


@njit(nogil=True, parallel=True)
def it_electric_resi(N, umag, a, X, Y, dx, Area, V, alpha, q, phi, u, v, w,
                   neighbours, a_P, a_E, a_W, a_N, a_S, source, S_p, rho_f,
                   rho_p, tau_p, d_p, Nit, relax=1.0, limit=0.0,
                   res_criterion=1e-15):
    """
    Iterate the field phi until the scaled residual meets res_criterion.

    At most Nit relaxation sweeps are applied to phi. After every sweep the
    neighbour coefficients, source terms and a_P are rebuilt and the scaled
    residual and summary statistics of phi are stored. On early convergence
    the unused tail of every history array is cut off.

    Parameters
    ----------
    N, umag, a, X, Y, dx, Area, V, alpha, q, u, w, neighbours, rho_p, tau_p, d_p :
        forwarded unchanged to the relaxation / coefficient helpers.
    phi : array of float
        field that is relaxed.
    v, rho_f :
        accepted but not used by this function.
    a_P, a_E, a_W, a_N, a_S, source, S_p : array of float
        coefficient and source arrays; rebuilt after every sweep.
    Nit : int
        maximum number of sweeps.
    relax : float
        passed to Gauss_Seidel_point_by_point_relax.
    limit : float
        passed to calculate_a_neighbours_electric.
    res_criterion : float
        residual threshold below which iteration stops.

    Returns
    -------
    tuple
        (phi, residual, phi_max, phi_min, phi_tot, mean_phi, moment_2,
        moment_3, moment_4, a_P, a_E, a_W, a_N, a_S, source, S_p).

    """
    dt = None  # no time step is supplied to calculate_a_P

    # radial distance of every grid point from the origin
    r = np.sqrt((X ** 2) + (Y ** 2))

    # one history slot per possible sweep
    residual = np.zeros(Nit, dtype=nb.float64)
    phi_max = np.zeros(Nit, dtype=nb.float64)
    phi_min = np.zeros(Nit, dtype=nb.float64)
    phi_tot = np.zeros(Nit, dtype=nb.float64)
    mean_phi = np.zeros(Nit, dtype=nb.float64)
    moment_2 = np.zeros(Nit, dtype=nb.float64)
    moment_3 = np.zeros(Nit, dtype=nb.float64)
    moment_4 = np.zeros(Nit, dtype=nb.float64)

    for sweep in range(Nit):
        phi = Gauss_Seidel_point_by_point_relax(phi, a_P, a_E, a_W, a_N, a_S,
                                                source, N, neighbours, relax)

        # rebuild the coefficients after the sweep
        (a_E, a_W, a_N, a_S, source,
         S_p) = calculate_a_neighbours_electric(a_E, a_W, a_N, a_S, S_p,
                                                source, neighbours, N, u,
                                                alpha, q, X, Y, dx, Area, umag,
                                                a, tau_p, d_p, w, V, rho_p,
                                                limit)
        a_P = calculate_a_P(a_P, a_E, a_W, a_N, a_S, S_p, N, V, dt)

        residual[sweep] = calculate_scaled_residual(phi, a_P, a_E, a_W, a_N,
                                                    a_S, source, N, neighbours)

        # statistics of the current field
        (phi_max[sweep], phi_min[sweep],
         phi_tot[sweep]) = calculate_alpha_quantities(phi)
        (mean_phi[sweep], moment_2[sweep], moment_3[sweep],
         moment_4[sweep]) = r_moments(r, phi, phi_tot[sweep])

        if residual[sweep] < res_criterion:
            # converged: drop the zero padding if we stopped early
            n_used = sweep + 1
            if n_used < Nit:
                residual = residual[:n_used]
                phi_max = phi_max[:n_used]
                phi_min = phi_min[:n_used]
                phi_tot = phi_tot[:n_used]
                mean_phi = mean_phi[:n_used]
                moment_2 = moment_2[:n_used]
                moment_3 = moment_3[:n_used]
                moment_4 = moment_4[:n_used]

            break

    if residual[sweep] > res_criterion:
        print("Convergence criterion not reached after " + str(sweep + 1) + " iterations.")

    return (phi, residual, phi_max, phi_min, phi_tot, mean_phi, moment_2, moment_3,
            moment_4, a_P, a_E, a_W, a_N, a_S, source, S_p)


@njit(nogil=True, parallel=True)
def time_iterate(N, umag, a, X, Y, dx, Area, V, alpha, u, v, w, neighbours, a_P, a_E,
            a_W, a_N, a_S, source, S_p, rho_f, rho_p, tau_p, d_p, Nit, NT, dt,
            relax=1.0, limit=0.0, flux_func=volumetric_flow_exact, source_func=source_uniform_grid_implicit,
            diff_func=diffusion_coefficient_none, conv_func=upwind,
            time_func=first_order_backward_euler):
    """
    Advance alpha through NT time steps with Nit relaxation sweeps each.

    At the start of every time step the current alpha is copied to
    alpha_old. Each sweep relaxes alpha, rebuilds the unsteady coefficients
    and records the scaled residual together with the max, min and sum of
    alpha. There is no convergence test: exactly Nit * NT sweeps are done.

    Parameters
    ----------
    N, umag, a, X, Y, dx, Area, V, u, w, neighbours, tau_p, d_p :
        forwarded unchanged to the relaxation / coefficient helpers.
    alpha : array of float
        field that is advanced in time.
    v, rho_f, rho_p, time_func :
        accepted but not used by this function.
    a_P, a_E, a_W, a_N, a_S, source, S_p : array of float
        coefficient and source arrays; rebuilt after every sweep.
    Nit : int
        number of sweeps per time step.
    NT : int
        number of time steps.
    dt : float
        time step passed to the unsteady coefficient helpers.
    relax : float
        passed to Gauss_Seidel_point_by_point_relax.
    limit : float
        passed to calculate_a_neighbours_unsteady.
    flux_func, source_func, diff_func, conv_func : callable
        passed to calculate_a_neighbours_unsteady.

    Returns
    -------
    tuple
        (alpha, residual, alpha_max, alpha_min, alpha_tot, a_P, a_E, a_W,
        a_N, a_S, source, S_p); the four histories have Nit * NT entries.

    """
    n_records = Nit * NT
    residual = np.zeros(n_records, dtype=nb.float64)
    alpha_max = np.zeros(n_records, dtype=nb.float64)
    alpha_min = np.zeros(n_records, dtype=nb.float64)
    alpha_tot = np.zeros(n_records, dtype=nb.float64)

    for step in range(NT):
        # snapshot of the field at the start of this time step
        alpha_old = alpha.copy()
        first = step * Nit

        for sweep in range(Nit):
            k = first + sweep  # position in the flat history arrays
            alpha = Gauss_Seidel_point_by_point_relax(alpha, a_P, a_E, a_W,
                                                      a_N, a_S, source, N,
                                                      neighbours, relax)

            # rebuild the coefficients after the sweep
            (a_E, a_W, a_N, a_S, source,
             S_p) = calculate_a_neighbours_unsteady(a_E, a_W, a_N, a_S, S_p,
                                                    source, neighbours, N, u,
                                                    alpha, X, Y, dx, Area,
                                                    umag, a, tau_p, d_p, w,
                                                    flux_func, source_func,
                                                    diff_func, conv_func, V,
                                                    dt, alpha_old, limit)
            a_P = calculate_a_P_unsteady(a_P, a_E, a_W, a_N, a_S, S_p, N, V,
                                         dt)

            residual[k] = calculate_scaled_residual(alpha, a_P, a_E, a_W, a_N,
                                                    a_S, source, N, neighbours)

            # statistics of the current field
            (alpha_max[k], alpha_min[k],
             alpha_tot[k]) = calculate_alpha_quantities(alpha)

    return alpha, residual, alpha_max, alpha_min, alpha_tot, a_P, a_E, a_W, a_N, a_S, source, S_p


@njit(nogil=True, parallel=True)
def time_resi(N, umag, a, X, Y, dx, Area, V, alpha, u, v, w, neighbours, a_P,
              a_E, a_W, a_N, a_S, source, S_p, rho_f, rho_p, tau_p, d_p, Nit,
              NT, dt, relax=1.0, limit=0.0, flux_func=volumetric_flow_exact,
              source_func=source_uniform_grid_implicit,
              diff_func=diffusion_coefficient_none, conv_func=upwind,
              time_func=first_order_backward_euler,
              a_nb_func=calculate_a_neighbours, a_P_func=calculate_a_P,
              res_criterion=1e-15):
    """
    Advance alpha through NT time steps, recording residual and moments.

    At the start of every time step the current alpha is copied to
    alpha_old. Each of the Nit sweeps per step relaxes alpha, rebuilds the
    unsteady coefficients and records the scaled residual, the max / min /
    sum of alpha and its radial moments. There is no early exit: exactly
    Nit * NT sweeps are always performed, and res_criterion is only used to
    print a warning if the final residual is still above it.

    Parameters
    ----------
    N, umag, a, X, Y, dx, Area, V, u, w, neighbours, tau_p, d_p :
        forwarded unchanged to the relaxation / coefficient helpers.
    alpha : array of float
        field that is advanced in time.
    v, rho_f, rho_p, time_func, a_nb_func, a_P_func :
        accepted but not used by this function.
    a_P, a_E, a_W, a_N, a_S, source, S_p : array of float
        coefficient and source arrays; rebuilt after every sweep.
    Nit : int
        number of sweeps per time step.
    NT : int
        number of time steps.
    dt : float
        time step passed to the unsteady coefficient helpers.
    relax : float
        passed to Gauss_Seidel_point_by_point_relax.
    limit : float
        passed to calculate_a_neighbours_unsteady.
    flux_func, source_func, diff_func, conv_func : callable
        passed to calculate_a_neighbours_unsteady.
    res_criterion : float
        threshold the final residual is compared against for the warning.

    Returns
    -------
    tuple
        (alpha, residual, alpha_max, alpha_min, alpha_tot, mean_alpha,
        moment_2, moment_3, moment_4, a_P, a_E, a_W, a_N, a_S, source, S_p,
        time_step_ends). The histories have Nit * NT entries;
        time_step_ends[t] is the history index of the last sweep of time
        step t.

    """
    Ntot = Nit * NT
    residual = np.zeros(Ntot, dtype=nb.float64)
    alpha_max = np.zeros(Ntot, dtype=nb.float64)
    alpha_min = np.zeros(Ntot, dtype=nb.float64)
    alpha_tot = np.zeros(Ntot, dtype=nb.float64)
    # radial distance of every grid point from the origin
    r = np.sqrt((X ** 2) + (Y ** 2))
    mean_alpha = np.zeros(Ntot, dtype=nb.float64)
    moment_2 = np.zeros(Ntot, dtype=nb.float64)
    moment_3 = np.zeros(Ntot, dtype=nb.float64)
    moment_4 = np.zeros(Ntot, dtype=nb.float64)
    time_step_ends = np.zeros(NT, dtype=nb.int64)

    # i counts the sweeps completed so far; it is also the next free slot
    i = 0

    for t in range(NT):
        # snapshot of the field at the start of this time step
        alpha_old = alpha.copy()
        for it in range(Nit):
            alpha = Gauss_Seidel_point_by_point_relax(alpha, a_P, a_E, a_W, a_N, a_S, source, N, neighbours, relax)

            # recalculate coefficients
            a_E, a_W, a_N, a_S, source, S_p = calculate_a_neighbours_unsteady(a_E, a_W, a_N, a_S, S_p, source, neighbours, N, u, alpha, X, Y,
                                                                     dx, Area, umag, a, tau_p, d_p, w, flux_func, source_func, diff_func, conv_func,
                                                                     V, dt, alpha_old, limit)
            a_P = calculate_a_P_unsteady(a_P, a_E, a_W, a_N, a_S, S_p, N, V, dt)

            residual[i] = calculate_scaled_residual(alpha, a_P, a_E, a_W, a_N, a_S, source, N, neighbours)

            # calculate stats
            alpha_max[i], alpha_min[i], alpha_tot[i] = calculate_alpha_quantities(alpha)
            mean_alpha[i], moment_2[i], moment_3[i], moment_4[i] = r_moments(r, alpha, alpha_tot[i])
            i += 1

        # the last sweep written for this time step sits at index i - 1
        time_step_ends[t] = i - 1

    # i now equals the number of sweeps done, so the final residual lives at
    # i - 1 (indexing with i itself would read past the end of the array)
    if i > 0 and residual[i - 1] > res_criterion:
        print("Convergence criterion not reached after " + str(i) + " iterations.")

    return (alpha, residual, alpha_max, alpha_min, alpha_tot, mean_alpha,
            moment_2, moment_3, moment_4, a_P, a_E, a_W, a_N, a_S, source, S_p,
            time_step_ends)


@njit(nogil=True, parallel=True)
def calculate_alpha_quantities(alpha):
    """
    Return the largest value, smallest value and sum of an array.

    Parameters
    ----------
    alpha : array of float
        array to summarise.

    Returns
    -------
    alpha_max : float
        maximum over all elements.
    alpha_min : float
        minimum over all elements.
    alpha_tot : float
        sum over all elements.

    """
    return np.max(alpha), np.min(alpha), np.sum(alpha)


@njit(nogil=True, parallel=True)
def calculate_all_alpha_quantities(alpha_all):
    """
    Summarise every entry along the first axis of alpha_all.

    Parameters
    ----------
    alpha_all : array of float
        stack of arrays; alpha_all[i] is summarised independently.

    Returns
    -------
    alpha_max : array of float
        maximum of each alpha_all[i].
    alpha_min : array of float
        minimum of each alpha_all[i].
    alpha_tot : array of float
        sum of each alpha_all[i].

    """
    n_fields = alpha_all.shape[0]

    alpha_max = np.zeros(n_fields, dtype=nb.float64)
    alpha_min = np.zeros(n_fields, dtype=nb.float64)
    alpha_tot = np.zeros(n_fields, dtype=nb.float64)

    # every entry is independent, so the loop is declared parallel
    for k in prange(n_fields):
        largest, smallest, total = calculate_alpha_quantities(alpha_all[k])
        alpha_max[k] = largest
        alpha_min[k] = smallest
        alpha_tot[k] = total

    return alpha_max, alpha_min, alpha_tot


@njit(nogil=True, parallel=True)
def r_moments(r, alpha, alpha_tot):
    """
    Alpha-weighted mean radius and alpha-weighted central sums about it.

    The mean is sum(alpha * r) / alpha_tot. The three moments returned are
    the plain weighted sums sum(alpha * (r - mean) ** k) for k = 2, 3, 4;
    they are not divided by alpha_tot.

    Parameters
    ----------
    r : array of float
        2D array of radial co-ordinates, indexed like alpha.
    alpha : array of float
        2D weight array (any shape, not necessarily square).
    alpha_tot : float
        sum of alpha; used as the divisor for the mean, so it must be
        non-zero.

    Returns
    -------
    mean : float
        alpha-weighted mean of r.
    moment_2 : float
        sum of alpha * (r - mean) ** 2.
    moment_3 : float
        sum of alpha * (r - mean) ** 3.
    moment_4 : float
        sum of alpha * (r - mean) ** 4.

    """
    n_rows, n_cols = alpha.shape
    top_1 = 0.0
    moment_2 = 0.0
    moment_3 = 0.0
    moment_4 = 0.0

    # first pass: weighted sum of r. Only the outer loop is a prange; Numba
    # runs a prange nested inside another prange as a serial loop anyway.
    for i in prange(n_rows):
        for j in range(n_cols):
            top_1 += alpha[i, j] * r[i, j]

    mean = top_1 / alpha_tot

    # second pass: weighted sums of powers of the deviation from the mean
    for i in prange(n_rows):
        for j in range(n_cols):
            moment_2 += alpha[i, j] * ((r[i, j] - mean) ** 2)
            moment_3 += alpha[i, j] * ((r[i, j] - mean) ** 3)
            moment_4 += alpha[i, j] * ((r[i, j] - mean) ** 4)

    return mean, moment_2, moment_3, moment_4


@njit(nogil=True, parallel=True)
def potential_gradient(phi, dx, N, E, neighbours):
    """
    Fill E in place with central differences of phi.

    Parameters
    ----------
    phi : array of float
        2D field to differentiate.
    dx : float
        grid spacing; each difference is divided by 2 * dx.
    N : int
        number of points looped over along each axis.
    E : array of float
        output array; E[i, j, 0] receives the difference along the first
        index of phi and E[i, j, 1] the difference along the second index.
    neighbours : array of int
        neighbours[i, j] is unpacked as (west, east, north, south); west and
        east are used as first indices into phi, north and south as second
        indices.

    Returns
    -------
    None
        the result is written into E.

    """
    scale = 1.0 / (2. * dx)
    for i in prange(N):
        for j in prange(N):
            i_west, i_east, j_north, j_south = neighbours[i, j]
            # central differencing along each index direction
            d_first = phi[i_east, j] - phi[i_west, j]
            d_second = phi[i, j_north] - phi[i, j_south]
            E[i, j, 0] = d_first * scale
            E[i, j, 1] = d_second * scale


def electric_field_from_potential(phi, dx, N, neighbours):
    """
    Allocate an (N, N, 2) array and fill it using potential_gradient.

    Parameters
    ----------
    phi : array of float
        2D field handed to potential_gradient.
    dx : float
        grid spacing.
    N : int
        number of points along each axis.
    neighbours : array of int
        neighbour index table handed to potential_gradient.

    Returns
    -------
    field : array of float
        (N, N, 2) array exactly as filled by potential_gradient; no sign
        change is applied here.

    """
    field = np.zeros((N, N, 2), dtype=np.float64)
    # potential_gradient writes its result into the array passed to it
    potential_gradient(phi, dx, N, field, neighbours)
    return field



def plot(alpha_max, alpha_min, alpha_tot, residual):
    """
    Draw a 2 x 2 figure of the residual and the alpha max / min / sum.

    Parameters
    ----------
    alpha_max : array of float
        plotted on a linear axis (top right).
    alpha_min : array of float
        plotted on a logarithmic y-axis (bottom left).
    alpha_tot : array of float
        plotted on a linear axis (bottom right).
    residual : array of float
        plotted on a logarithmic y-axis (top left).

    """
    plt.subplots(2, 2)

    # (subplot position, data, title, logarithmic y-axis?)
    panels = (
        (221, residual, "Normalised Residual", True),
        (222, alpha_max, r"$\alpha_{max}$", False),
        (223, alpha_min, r"$\alpha_{min}$", True),
        (224, alpha_tot, r"$\alpha_{tot}$", False),
    )

    for position, series, title, log_y in panels:
        axis = plt.subplot(position)
        if log_y:
            axis.semilogy(series)
        else:
            axis.plot(series)
        axis.set_title(title)

    plt.tight_layout()


def plot_moments(mean_alpha, moment_2, moment_3, moment_4, residual, l):
    """
    Plot the radial moments and the residual against iteration number.

    Five panels are drawn on a 3 x 2 grid: mean / l, sqrt(moment_2) / l,
    moment_3 / moment_2 ** 1.5, moment_4 / moment_2 ** 2 and the residual
    (logarithmic y-axis). The sixth panel is left empty.

    Parameters
    ----------
    mean_alpha : array of float
        mean radial position per iteration.
    moment_2 : array of float
        second moment per iteration.
    moment_3 : array of float
        third moment per iteration.
    moment_4 : array of float
        fourth moment per iteration.
    residual : array of float
        residual per iteration.
    l : float
        length used to normalise the mean and sqrt(moment_2).

    """
    fig, ax = plt.subplots(3, 2)
    ax1 = plt.subplot(321)
    ax1.plot(mean_alpha / l)
    ax1.set_xlabel("Iteration")
    ax1.set_ylabel(r"$\frac{<r>}{l}$")

    ax2 = plt.subplot(322)
    ax2.plot(np.sqrt(moment_2) / l)
    ax2.set_xlabel("Iteration")
    ax2.set_ylabel(r"$\frac{M_2^{\frac{1}{2}}}{l}$")

    ax3 = plt.subplot(323)
    ax3.plot(moment_3 / (moment_2 ** 1.5))
    ax3.set_xlabel("Iteration")
    # the plotted quantity divides by M_2 to the power 3/2, so label it so
    ax3.set_ylabel(r"$\frac{M_3}{M_2^{\frac{3}{2}}}$")

    ax4 = plt.subplot(324)
    ax4.plot(moment_4 / (moment_2 ** 2))
    ax4.set_xlabel("Iteration")
    ax4.set_ylabel(r"$\frac{M_4}{M_2^2}$")

    ax5 = plt.subplot(325)
    ax5.semilogy(residual)
    ax5.set_xlabel("Iteration")
    ax5.set_ylabel(r"Normalised Residual")

    plt.tight_layout()


def plot_moments_unsteady(mean_alpha, moment_2, moment_3, moment_4, residual,
                          l, time_step_ends, dt, tau_p):
    """
    Plot the radial moments and the residual against normalised time.

    The histories are sampled at the indices in time_step_ends and plotted
    against t / tau_p, where t runs from dt to len(time_step_ends) * dt in
    equal steps. Five panels are drawn on a 3 x 2 grid: mean / l,
    sqrt(moment_2) / l, moment_3 / moment_2 ** 1.5, moment_4 / moment_2 ** 2
    and the residual (logarithmic y-axis).

    Parameters
    ----------
    mean_alpha : array of float
        mean radial position per iteration.
    moment_2 : array of float
        second moment per iteration.
    moment_3 : array of float
        third moment per iteration.
    moment_4 : array of float
        fourth moment per iteration.
    residual : array of float
        residual per iteration.
    l : float
        length used to normalise the mean and sqrt(moment_2).
    time_step_ends : array of int
        history indices at which each series is sampled, one per time step.
    dt : float
        time step size.
    tau_p : float
        particle relaxation time used to normalise the time axis.

    """
    t = np.linspace(dt, len(time_step_ends) * dt, len(time_step_ends))
    t_norm = t / tau_p
    fig, ax = plt.subplots(3, 2)
    ax1 = plt.subplot(321)
    ax1.plot(t_norm, mean_alpha[time_step_ends] / l)
    ax1.set_xlabel(r"$\frac{t}{\tau_p}$")
    ax1.set_ylabel(r"$\frac{<r>}{l}$")

    ax2 = plt.subplot(322)
    ax2.plot(t_norm, np.sqrt(moment_2[time_step_ends]) / l)
    ax2.set_xlabel(r"$\frac{t}{\tau_p}$")
    ax2.set_ylabel(r"$\frac{M_2^{\frac{1}{2}}}{l}$")

    ax3 = plt.subplot(323)
    ax3.plot(t_norm, moment_3[time_step_ends] /
             (moment_2[time_step_ends] ** 1.5))
    ax3.set_xlabel(r"$\frac{t}{\tau_p}$")
    # the plotted quantity divides by M_2 to the power 3/2, so label it so
    ax3.set_ylabel(r"$\frac{M_3}{M_2^{\frac{3}{2}}}$")

    ax4 = plt.subplot(324)
    ax4.plot(t_norm, moment_4[time_step_ends] /
             (moment_2[time_step_ends] ** 2))
    ax4.set_xlabel(r"$\frac{t}{\tau_p}$")
    ax4.set_ylabel(r"$\frac{M_4}{M_2^2}$")

    ax5 = plt.subplot(325)
    ax5.semilogy(t_norm, residual[time_step_ends])
    ax5.set_xlabel(r"$\frac{t}{\tau_p}$")
    ax5.set_ylabel(r"Normalised Residual")

    plt.tight_layout()


def plot_alpha(X, Y, alpha, l, alpha_0=5e-7):
    """
    Filled contour plot of alpha / alpha_0 on axes scaled by 2 * l.

    Parameters
    ----------
    X : array of float
        x-coordinates of the grid points.
    Y : array of float
        y-coordinates of the grid points.
    alpha : array of float
        field to plot.
    l : float
        length scale; both axes show the coordinate divided by 2 * l.
    alpha_0 : float, optional
        value alpha is divided by before plotting. The default, 5e-7, is the
        normalisation that was previously fixed inside the function.

    """
    plt.figure()
    plt.contourf(X / (2 * l), Y / (2 * l), alpha / alpha_0)
    plt.xlabel(r"$\frac{x}{2l}$", fontsize=16)
    plt.ylabel(r"$\frac{y}{2l}$", fontsize=16)
    plt.colorbar()
    plt.tight_layout()
    
    
def plot_diffusion_ratios():
    """
    Plot ratios of three diffusion coefficients against Stk and save them.

    For 10000 logarithmically spaced values of Stk between 1e-5 and 10 the
    Brownian, turbulent and (uncorrected) collisional diffusion coefficients
    are evaluated at two points: x = y = pi / (2 a), labelled 'Maximum', and
    x = y = 0, labelled 'Minimum'. Three log-log figures of the pairwise
    ratios are drawn and each is written to an .eps file in the current
    working directory. All physical inputs are fixed inside the function.

    """
    # fixed physical inputs
    rho_p = 2650.0
    a = 35520.91933986964
    umag = 0.29283613
    alpha = 5e-7
    mu_f = 1.7894e-5

    # derived scales
    l = np.pi / a  # m
    tau_f = l / umag
    Stk = np.logspace(-5, 1, num=10000)
    tau_p = tau_f * Stk
    d_p = np.sqrt(18 * mu_f * tau_p / rho_p)
    dx = l / 64.0

    def coefficients_at(x, y):
        # evaluate the three coefficients at (x, y) for every Stk value
        lam = np.zeros_like(Stk)
        turb = np.zeros_like(Stk)
        coll = np.zeros_like(Stk)
        for k in range(len(Stk)):
            lam[k] = diffusion_coefficient_brownian(x, y, alpha, umag, a,
                                                    tau_p[k], d_p[k], dx)
            turb[k] = diffusion_coefficient_turb(x, y, alpha, umag, a,
                                                 tau_p[k], d_p[k], dx)
            coll[k] = diffusion_coefficient_collisions_no_correction(
                x, y, alpha, umag, a, tau_p[k], d_p[k], dx)
        return lam, turb, coll

    # maximum
    quarter_period = np.pi / (2 * a)
    D_lam_max, D_t_max, D_coll_max = coefficients_at(quarter_period,
                                                     quarter_period)
    D_ratio_max_t_lam = D_t_max / D_lam_max
    D_ratio_max_coll_lam = D_coll_max / D_lam_max
    D_ratio_max_coll_t = D_coll_max / D_t_max

    # minimum
    D_lam_min, D_t_min, D_coll_min = coefficients_at(0, 0)
    D_ratio_min_t_lam = D_t_min / D_lam_min
    D_ratio_min_coll_lam = D_coll_min / D_lam_min
    D_ratio_min_coll_t = D_coll_min / D_t_min

    # (curve at the maximum, curve at the minimum, y-label, output file)
    figures = (
        (D_ratio_max_t_lam, D_ratio_min_t_lam,
         r"$\frac{D_{turb}}{D_{lam}}$",
         "diffusion_ratio_vs_stk_corrected.eps"),
        (D_ratio_max_coll_lam, D_ratio_min_coll_lam,
         r"$\frac{D_{coll}}{D_{lam}}$",
         "diffusion_ratio_coll_lam_vs_stk_corrected.eps"),
        (D_ratio_max_coll_t, D_ratio_min_coll_t,
         r"$\frac{D_{coll}}{D_{turb}}$",
         "diffusion_ratio_coll_turb_vs_stk_no_correction_corrected.eps"),
    )

    for ratio_max, ratio_min, y_label, file_name in figures:
        plt.figure()
        plt.loglog(Stk, ratio_max, label='Maximum')
        plt.loglog(Stk, ratio_min, label='Minimum')
        plt.xlabel(r"$Stk$", fontsize=16)
        plt.ylabel(y_label, fontsize=16)
        plt.legend()
        plt.tight_layout()
        plt.savefig(file_name)


def initialise(N=128, Stk=1e-3, alpha_0=5e-7, rho_f=1.0, rho_p=2650.0,
               umag=1.543, a=84765, flux_func=volumetric_flow_exact,
               source_func=source_uniform_grid_implicit,
               diff_func=diffusion_coefficient_none,
               conv_func=upwind):
    """
    Build the grid, a uniform alpha field and the initial coefficients.

    No time step is involved: dt and alpha_old are passed to the coefficient
    helpers as None. Velocities come from velocities_and_neighbours.

    Parameters
    ----------
    N : int
        grid size; every field and coefficient array has shape (N, N).
    Stk : float
        factor multiplying the flow time scale l / umag to give tau_p
        (presumably the Stokes number).
    alpha_0 : float
        uniform initial value of alpha.
    rho_f : float
        returned unchanged; not used in the set-up itself.
    rho_p : float
        used in d_p = sqrt(18 * mu_f * tau_p / rho_p).
    umag : float
        maximum fluid velocity (m/s).
    a : float
        flow parameter (/m); the length scale is l = pi / a.
    flux_func, source_func, diff_func, conv_func : callable
        passed to calculate_a_neighbours.

    Returns
    -------
    tuple
        (N, umag, a, X, Y, dx, Area, V, alpha, u, v, w, neighbours, a_P,
        a_E, a_W, a_N, a_S, source, S_p, rho_f, rho_p, tau_p, d_p).

    """
    mu_f = 1.7894e-5

    # derived scales
    l = np.pi / a  # m
    tau_p = (l / umag) * Stk
    d_p = np.sqrt(18 * mu_f * tau_p / rho_p)

    # grid; the two face-coordinate outputs are not needed here
    X, Y, _, _, dx, Area, V = create_uniform_grid(N, l)

    alpha = alpha_0 * np.ones((N, N), dtype=np.float64)

    u, v, w, neighbours = velocities_and_neighbours(X, Y, N, umag, a, tau_p)

    # seven independent zero arrays: a_P is the main diagonal, the rest are
    # the neighbour coefficients and the source terms
    a_P, a_E, a_W, a_N, a_S, source, S_p = (
        np.zeros((N, N), dtype=np.float64) for _ in range(7))

    # neighbour coefficients and source (dt and alpha_old are None)
    a_E, a_W, a_N, a_S, source, S_p = calculate_a_neighbours(
        a_E, a_W, a_N, a_S, S_p, source, neighbours, N, u, alpha, X, Y, dx,
        Area, umag, a, tau_p, d_p, w, flux_func, source_func, diff_func,
        conv_func, V, None, None, limit=0.0)

    # main diagonal
    a_P = calculate_a_P(a_P, a_E, a_W, a_N, a_S, S_p, N, V, None)

    # perform basic checks
    check_coefficients(a_P, a_E, a_W, a_N, a_S, S_p)

    return (N, umag, a, X, Y, dx, Area, V, alpha, u, v, w, neighbours, a_P, a_E,
            a_W, a_N, a_S, source, S_p, rho_f, rho_p, tau_p, d_p)


def initialise_mee(N=128, Stk=1e-3, alpha_0=5e-7, rho_f=1.0, rho_p=2650.0,
               umag=1.543, a=84765, flux_func=volumetric_flow_exact,
               source_func=source_uniform_grid_implicit,
               diff_func=diffusion_coefficient_none,
               conv_func=upwind):
    """
    Build the grid, a uniform alpha field and the initial coefficients,
    taking the velocities from velocities_and_neighbours_mee.

    No time step is involved: dt and alpha_old are passed to the coefficient
    helpers as None.

    Parameters
    ----------
    N : int
        grid size; every field and coefficient array has shape (N, N).
    Stk : float
        factor multiplying the flow time scale l / umag to give tau_p
        (presumably the Stokes number).
    alpha_0 : float
        uniform initial value of alpha.
    rho_f : float
        returned unchanged; not used in the set-up itself.
    rho_p : float
        used in d_p = sqrt(18 * mu_f * tau_p / rho_p).
    umag : float
        maximum fluid velocity (m/s).
    a : float
        flow parameter (/m); the length scale is l = pi / a.
    flux_func, source_func, diff_func, conv_func : callable
        passed to calculate_a_neighbours.

    Returns
    -------
    tuple
        (N, umag, a, X, Y, dx, Area, V, alpha, u, v, w, neighbours, a_P,
        a_E, a_W, a_N, a_S, source, S_p, rho_f, rho_p, tau_p, d_p).

    """
    mu_f = 1.7894e-5

    # derived scales
    l = np.pi / a  # m
    tau_p = (l / umag) * Stk
    d_p = np.sqrt(18 * mu_f * tau_p / rho_p)

    # grid; the two face-coordinate outputs are not needed here
    X, Y, _, _, dx, Area, V = create_uniform_grid(N, l)

    alpha = alpha_0 * np.ones((N, N), dtype=np.float64)

    u, v, w, neighbours = velocities_and_neighbours_mee(X, Y, N, umag, a,
                                                        tau_p)

    # seven independent zero arrays: a_P is the main diagonal, the rest are
    # the neighbour coefficients and the source terms
    a_P, a_E, a_W, a_N, a_S, source, S_p = (
        np.zeros((N, N), dtype=np.float64) for _ in range(7))

    # neighbour coefficients and source (dt and alpha_old are None)
    a_E, a_W, a_N, a_S, source, S_p = calculate_a_neighbours(
        a_E, a_W, a_N, a_S, S_p, source, neighbours, N, u, alpha, X, Y, dx,
        Area, umag, a, tau_p, d_p, w, flux_func, source_func, diff_func,
        conv_func, V, None, None, limit=0.0)

    # main diagonal
    a_P = calculate_a_P(a_P, a_E, a_W, a_N, a_S, S_p, N, V, None)

    # perform basic checks
    check_coefficients(a_P, a_E, a_W, a_N, a_S, S_p)

    return (N, umag, a, X, Y, dx, Area, V, alpha, u, v, w, neighbours, a_P, a_E,
            a_W, a_N, a_S, source, S_p, rho_f, rho_p, tau_p, d_p)


def initialise_unsteady(N=128, Stk=1e-3, alpha_0=5e-7, rho_f=1.0, rho_p=2650.0,
                        umag=1.543, a=84765, flux_func=volumetric_flow_exact,
                        source_func=source_uniform_grid_implicit,
                        diff_func=diffusion_coefficient_none,
                        time_func=first_order_backward_euler, dt=1e-8,
                        conv_func=upwind):
    """
    Build the grid, a uniform alpha field and the initial unsteady
    coefficients, then print the CFL number umag * dt / dx.

    The previous-time-level field handed to the coefficient helper is a copy
    of the initial alpha.

    Parameters
    ----------
    N : int
        grid size; every field and coefficient array has shape (N, N).
    Stk : float
        factor multiplying the flow time scale l / umag to give tau_p
        (presumably the Stokes number).
    alpha_0 : float
        uniform initial value of alpha.
    rho_f : float
        returned unchanged; not used in the set-up itself.
    rho_p : float
        used in d_p = sqrt(18 * mu_f * tau_p / rho_p).
    umag : float
        maximum fluid velocity (m/s).
    a : float
        flow parameter (/m); the length scale is l = pi / a.
    flux_func, source_func, diff_func, conv_func : callable
        passed to calculate_a_neighbours_unsteady.
    time_func : callable
        accepted but not used by this function.
    dt : float
        time step; passed to the unsteady helpers and returned.

    Returns
    -------
    tuple
        (N, umag, a, X, Y, dx, Area, V, alpha, u, v, w, neighbours, a_P,
        a_E, a_W, a_N, a_S, source, S_p, rho_f, rho_p, tau_p, d_p, dt).

    """
    mu_f = 1.7894e-5

    # derived scales
    l = np.pi / a  # m
    tau_p = (l / umag) * Stk
    d_p = np.sqrt(18 * mu_f * tau_p / rho_p)

    # grid; the two face-coordinate outputs are not needed here
    X, Y, _, _, dx, Area, V = create_uniform_grid(N, l)

    alpha = alpha_0 * np.ones((N, N), dtype=np.float64)
    previous_alpha = alpha.copy()

    u, v, w, neighbours = velocities_and_neighbours(X, Y, N, umag, a, tau_p)

    # seven independent zero arrays: a_P is the main diagonal, the rest are
    # the neighbour coefficients and the source terms
    a_P, a_E, a_W, a_N, a_S, source, S_p = (
        np.zeros((N, N), dtype=np.float64) for _ in range(7))

    # neighbour coefficients and source
    a_E, a_W, a_N, a_S, source, S_p = calculate_a_neighbours_unsteady(
        a_E, a_W, a_N, a_S, S_p, source, neighbours, N, u, alpha, X, Y, dx,
        Area, umag, a, tau_p, d_p, w, flux_func, source_func, diff_func,
        conv_func, V, dt, previous_alpha, limit=0.0)

    # main diagonal
    a_P = calculate_a_P_unsteady(a_P, a_E, a_W, a_N, a_S, S_p, N, V, dt)

    # perform basic checks
    check_coefficients(a_P, a_E, a_W, a_N, a_S, S_p)

    courant = umag * dt / dx
    print("CFL number = {}".format(courant))

    return (N, umag, a, X, Y, dx, Area, V, alpha, u, v, w, neighbours, a_P, a_E,
            a_W, a_N, a_S, source, S_p, rho_f, rho_p, tau_p, d_p, dt)


def initialise_charge(N, umag, a, X, Y, dx, Area, V, alpha, u, v, w,
                      neighbours, rho_f, rho_p, tau_p, d_p, q_0=0.0):
    """
    Create a uniform q field and its initial coefficient arrays.

    Parameters
    ----------
    N : int
        grid size; every coefficient array has shape (N, N).
    umag, a, X, Y, dx, Area, V, alpha, u, w, neighbours, rho_p, tau_p, d_p :
        forwarded unchanged to calculate_a_neighbours_charge /
        calculate_a_P.
    v, rho_f :
        accepted but not used by this function.
    q_0 : float
        uniform initial value of q; q has the shape of alpha.

    Returns
    -------
    tuple
        (q, a_P, a_E, a_W, a_N, a_S, source, S_p).

    """
    q = np.ones_like(alpha) * q_0

    # seven independent zero arrays: a_P is the main diagonal, the rest are
    # the neighbour coefficients and the source terms
    a_P, a_E, a_W, a_N, a_S, source, S_p = (
        np.zeros((N, N), dtype=np.float64) for _ in range(7))

    # the final positional argument (0.0) is the limit value
    a_E, a_W, a_N, a_S, source, S_p = calculate_a_neighbours_charge(
        a_E, a_W, a_N, a_S, S_p, source, neighbours, N, u, alpha, q, X, Y,
        dx, Area, umag, a, tau_p, d_p, w, V, rho_p, 0.0)

    # main diagonal; None is passed in place of a time step
    a_P = calculate_a_P(a_P, a_E, a_W, a_N, a_S, S_p, N, V, None)

    # perform basic checks
    check_coefficients(a_P, a_E, a_W, a_N, a_S, S_p)

    return (q, a_P, a_E, a_W, a_N, a_S, source, S_p)


def initialise_electric(N, umag, a, X, Y, dx, Area, V, alpha, q, u, v, w,
                      neighbours, rho_f, rho_p, tau_p, d_p):
    """
    Create a zero phi field and its initial coefficient arrays.

    Parameters
    ----------
    N : int
        grid size; every coefficient array has shape (N, N).
    umag, a, X, Y, dx, Area, V, alpha, q, u, w, neighbours, rho_p, tau_p, d_p :
        forwarded unchanged to calculate_a_neighbours_electric /
        calculate_a_P.
    v, rho_f :
        accepted but not used by this function.

    Returns
    -------
    tuple
        (phi, a_P, a_E, a_W, a_N, a_S, source, S_p); phi is all zeros with
        the shape of alpha.

    """
    phi = np.zeros_like(alpha)

    # seven independent zero arrays: a_P is the main diagonal, the rest are
    # the neighbour coefficients and the source terms
    a_P, a_E, a_W, a_N, a_S, source, S_p = (
        np.zeros((N, N), dtype=np.float64) for _ in range(7))

    # the final positional argument (0.0) is the limit value
    a_E, a_W, a_N, a_S, source, S_p = calculate_a_neighbours_electric(
        a_E, a_W, a_N, a_S, S_p, source, neighbours, N, u, alpha, q, X, Y,
        dx, Area, umag, a, tau_p, d_p, w, V, rho_p, 0.0)

    # main diagonal; None is passed in place of a time step
    a_P = calculate_a_P(a_P, a_E, a_W, a_N, a_S, S_p, N, V, None)

    # perform basic checks
    check_coefficients(a_P, a_E, a_W, a_N, a_S, S_p)

    return (phi, a_P, a_E, a_W, a_N, a_S, source, S_p)


# Script entry point. Nothing is executed when the module is run directly:
# the body is only `pass`. The commented-out calls below are kept as usage
# examples for initialise() and initialise_unsteady().
if  __name__ == '__main__':
# =============================================================================
#     (N, umag, a, X, Y, dx, Area, V, alpha, u, v, w, neighbours, a_P, a_E,
#      a_W, a_N, a_S, source, S_p, rho_f, rho_p, tau_p, d_p) = initialise(N=128, Stk=1e-3, alpha_0=5e-7, rho_f=1.0, rho_p=2650.0,
#                umag=0.29283613, a=35520.919, flux_func=volumetric_flow_exact,
#                source_func=source_uniform_grid_implicit,
#                diff_func=diffusion_coefficient_none,
#                conv_func=upwind)
# =============================================================================
# =============================================================================
#     (N, umag, a, X, Y, dx, Area, V, alpha, u, v, w, neighbours, a_P, a_E, a_W, a_N,
#      a_S, source, S_p, rho_f, rho_p, tau_p, d_p, dt) = initialise_unsteady(N=128, Stk=1e-3, alpha_0=5e-7,
#                                              rho_f=1.0, rho_p=2650.0,
#                                              umag=0.29283613,
#                                              a=35520.919,
#                                              flux_func=volumetric_flow_exact,
#                                              source_func=source_uniform_grid_implicit,
#                                              diff_func=diffusion_coefficient_collisions,
#                                              time_func=first_order_backward_euler,
#                                              dt=1e-4)
# =============================================================================
    pass
