# -*- coding: utf-8 -*-
"""
Created on Mon Apr 27 16:52:08 2020

@author: Daniel Powell
"""

import numpy as np
import sympy as sp
from sympy.vector import CoordSys3D, Vector
from sympy.physics.vector import ReferenceFrame, divergence, curl, gradient
from sympy import init_printing
from matplotlib import pyplot as plt
from cunningham_slip import C_c_factor


# Switch SymPy to its pretty-printing output for this session.
init_printing()

# NOTE(review): N is not referenced anywhere in the visible part of the file;
# it is kept because other code may still read it.
N = CoordSys3D('N')

# Frame whose coordinate variables R[0] and R[1] act as the spatial x and y
# positions, and whose unit vectors R.x / R.y carry the vector components.
R = ReferenceFrame('x')

# umag: velocity amplitude, a: wavenumber multiplying each coordinate,
# tau_p: coefficient of the acceleration term in the particle velocity.
# NOTE(review): the x and y symbols are unused in the visible code (the
# field is written in terms of R[0] / R[1]).
umag, a, x, y, tau_p = sp.symbols(r"u_{mag} a x y \tau_p")

# Scalar components of the fluid velocity field.
u_x = umag * sp.cos(a * R[0]) * sp.sin(a * R[1])
u_y = -umag * sp.sin(a * R[0]) * sp.cos(a * R[1])

# The same field as a vector in frame R.
u = (u_x * R.x) + (u_y * R.y)

# Fluid velocity-gradient tensor, nabla_u[i, j] = d(u_i)/d(x_j).
# The gradient of each scalar component is taken once and then projected
# onto the two in-plane unit vectors of R.
du_xdxj = gradient(u_x, R)
du_ydxj = gradient(u_y, R)
nabla_u = sp.Matrix([
    [du_xdxj.dot(R.x), du_xdxj.dot(R.y)],  # du_x/dx, du_x/dy
    [du_ydxj.dot(R.x), du_ydxj.dot(R.y)],  # du_y/dx, du_y/dy
])

# Convective acceleration a_i = u_j * d(u_i)/d(x_j), simplified per component
# and then reassembled as a vector in R.
accel_x = sp.simplify(u_x * nabla_u[0, 0] + u_y * nabla_u[0, 1])
accel_y = sp.simplify(u_x * nabla_u[1, 0] + u_y * nabla_u[1, 1])
accel = accel_x * R.x + accel_y * R.y

# Particle velocity field: the fluid velocity minus tau_p times the fluid
# convective acceleration computed above.
v = u - (tau_p * accel)

# In-plane components of v as a 2x1 column, so they can be indexed as plain
# scalar expressions.
v_matrix = sp.Matrix([v.dot(R.x), v.dot(R.y)])

# Fluid strain-rate tensor: the symmetric part of nabla_u,
# S_ij = (d(u_i)/d(x_j) + d(u_j)/d(x_i)) / 2.
# Dividing by the integer 2 keeps the coefficient an exact Rational; the
# float literal 0.5 would put inexact Floats into the symbolic result.
S_fluid = sp.Matrix([
    [nabla_u[0, 0], (nabla_u[0, 1] + nabla_u[1, 0]) / 2],
    [(nabla_u[0, 1] + nabla_u[1, 0]) / 2, nabla_u[1, 1]],
])

# Particle velocity-gradient tensor, nabla_v[i, j] = d(v_i)/d(x_j).
# The gradient of each component is evaluated once and projected onto both
# in-plane unit vectors, instead of being recomputed for every entry.
nabla_v = sp.Matrix([
    [grad_v_i.dot(R.x), grad_v_i.dot(R.y)]
    for grad_v_i in (gradient(v_i, R) for v_i in (v_matrix[0], v_matrix[1]))
])

# Particle strain-rate tensor: the symmetric part of nabla_v,
# S_ij = (d(v_i)/d(x_j) + d(v_j)/d(x_i)) / 2.
# Dividing by the integer 2 keeps the coefficient an exact Rational; the
# float literal 0.5 would put inexact Floats into the symbolic result.
S_part = sp.Matrix([
    [nabla_v[0, 0], (nabla_v[0, 1] + nabla_v[1, 0]) / 2],
    [(nabla_v[0, 1] + nabla_v[1, 0]) / 2, nabla_v[1, 1]],
])



# =============================================================================
# a = 4813.34
# L2 = 2 * np.pi / a
# l = np.pi / a
# umax = 100
# mu_f = 1.7894e-5
# D = 3.442516e-11
# nu_f = mu_f / 1.225
# 
# # use N = 1000 for pics
# N = 1000
# x = np.linspace(-l, l, N)
# y = np.linspace(-l, l, N)
# X, Y = np.meshgrid(x, y)
# d_p = 1e-7
# rho_p = 2650
# mu_f = 1.7894e-5
# C_c = C_c_factor(1e-7)
# tau_p_orig = (rho_p * (d_p ** 2) * C_c) / (18 * mu_f)
# Stk = 0.01
# tau_p = tau_p_orig * Stk
# 
# fig, ax = plt.subplots(1)
# 
# div_old = 2 * a ** 2 * umax ** 2 * (np.sin(a * X) ** 2 - np.cos(a * Y) ** 2)
# div = 2 * a ** 2 * umax ** 2 * ((np.sin(a * X) ** 2 * np.sin(a * Y) ** 2) - (np.cos(a * X) ** 2 * np.cos(a * Y) ** 2))
# 
# im = ax.pcolor(X / L2, Y / L2, div, cmap='viridis')
# cbar = fig.colorbar(im, ax=ax)
# cbar.ax.set_ylabel(r"$\frac{\partial}{\partial x_i} \left( u_j \frac{\partial u_i}{\partial x_j} \right)$", fontsize=16)
# ax.set_title(r"$Stk = $" + str(Stk), fontsize=12)
# plt.xlabel(r"$\frac{x}{2 \pi a}$", fontsize=16)
# plt.ylabel(r"$\frac{y}{2 \pi a}$", fontsize=16)
# plt.tight_layout()
# # plt.savefig("stk001_div_field_not_simple.png", dpi=600)
# plt.show()
# =============================================================================
