#!/usr/bin/env python
# -*- coding: utf-8 -*-

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import vtk
from matplotlib import patches as patches
import matplotlib.colors as mcolors

# Set the global Matplotlib style (rcParams) to the built-in 'classic' sheet
# when this module is imported.  To use a custom style sheet instead, pass its
# name here (e.g. 'mystyle').
plt.style.use('classic')

# Matplotlib linestyle tuples of the form (offset, dash_sequence), where the
# dash sequence alternates on/off lengths.  Indexed by box number in the
# __main__ section so that each nested box is drawn with a distinct pattern.
ls = [
    (0, dashes)
    for dashes in (
        (),                        # empty dash sequence
        (1.1, 1.1),                # short on / short off
        (2.8, 1.1),                # long on / short off
        (2.8, 1.1, 1.1, 1.1),      # long, short
        (3, 1, 1, 1, 1, 1),        # long, short, short
        (3, 1, 3, 1, 1, 1, 1, 1),  # long, long, short, short
    )
]

def read_txt(file, skiprows=0, delimiter=','):
    """Load a delimited numeric text file and return it column by column.

    Thin wrapper around ``np.loadtxt`` with ``unpack=True``, so row ``i`` of
    the returned array is column ``i`` of the file.

    Parameters
    ----------
    file : str or file-like
        Anything ``np.loadtxt`` accepts as its first argument.
    skiprows : int
        Number of leading lines (e.g. headers) to skip.
    delimiter : str
        Field separator; defaults to a comma.
    """
    options = dict(skiprows=skiprows, delimiter=delimiter, unpack=True)
    return np.loadtxt(file, **options)


def read_vti(file):
    """Read a parallel VTK image-data file into numpy arrays.

    The file is opened with ``vtkXMLPImageDataReader``.  The point arrays
    named ``"Velocity"`` and ``"Pressure"`` are extracted, and the coordinates
    of every grid point are generated from the data bounds.

    Parameters
    ----------
    file : str
        Path of the file to read.

    Returns
    -------
    vec : numpy.ndarray, shape (3, nx, ny, nz)
        ``"Velocity"`` with the component axis first, followed by the x, y
        and z point indices.  Assumes the array has 3 components.
    sca : numpy.ndarray, shape (1, nx, ny, nz)
        ``"Pressure"`` with a leading singleton axis, laid out like ``vec``.
    grid3D : numpy.ndarray, shape (3, nx, ny, nz)
        x, y and z coordinates of every point.

    Raises
    ------
    ValueError
        If VTK returns no array for ``"Velocity"`` or ``"Pressure"``.
    """
    reader = vtk.vtkXMLPImageDataReader()
    reader.SetFileName(file)
    reader.Update()
    data = reader.GetOutput()
    pointData = data.GetPointData()

    nx, ny, nz = data.GetDimensions()
    # Point data is flattened with x varying fastest, so its C-order shape is
    # (nz, ny, nx); everything is transposed to (x, y, z) order on return.
    sh = (nz, ny, nx)
    ndims = len(sh)

    velocity = pointData.GetVectors("Velocity")
    pressure = pointData.GetScalars('Pressure')
    if velocity is None or pressure is None:
        raise ValueError("%s: missing 'Velocity' or 'Pressure' point array" % file)

    # get vector field, component axis first: (ndims, nz, ny, nx)
    v = np.array(velocity).reshape(sh + (ndims,))
    vec = np.moveaxis(v, -1, 0)
    print('vec', vec.shape)
    # get scalar field; the singleton axis must come first so the transpose
    # below yields (1, nx, ny, nz), matching the vector field layout.
    sca = np.array(pressure).reshape((1,) + sh)

    # Point coordinates from the bounds and point counts, so non-unit spacing
    # still gives exactly nx x ny x nz points.
    xmin, xmax, ymin, ymax, zmin, zmax = data.GetBounds()
    grid3D = np.array(np.meshgrid(np.linspace(xmin, xmax, nx),
                                  np.linspace(ymin, ymax, ny),
                                  np.linspace(zmin, zmax, nz),
                                  indexing='ij'))

    return np.transpose(vec, (0, 3, 2, 1)), np.transpose(sca, (0, 3, 2, 1)), grid3D


def read_vtr(fname):
    """Read a parallel VTK rectilinear-grid file into numpy arrays.

    The file is opened with ``vtkXMLPRectilinearGridReader``.  The point
    arrays named ``"Velocity"`` and ``"Pressure"`` are extracted together with
    the grid's per-axis coordinate arrays.

    Parameters
    ----------
    fname : str
        Path of the file to read.

    Returns
    -------
    vec : numpy.ndarray, shape (3, nx, ny, nz)
        ``"Velocity"`` with the component axis first, followed by the x, y
        and z point indices.  Assumes the array has 3 components.
    sca : numpy.ndarray, shape (1, nx, ny, nz)
        ``"Pressure"`` with a leading singleton axis, laid out like ``vec``.
    grid : numpy.ndarray
        The x, y and z coordinate arrays.  When all three have the same
        length this is a float array of shape (3, n).  Otherwise it is a
        length-3 object array holding the three 1-D arrays, because NumPy
        cannot stack arrays of different lengths into one rectangular array.

    Raises
    ------
    ValueError
        If VTK returns no array for ``"Velocity"`` or ``"Pressure"``.
    """
    reader = vtk.vtkXMLPRectilinearGridReader()
    reader.SetFileName(fname)
    reader.Update()
    data = reader.GetOutput()
    pointData = data.GetPointData()

    nx, ny, nz = data.GetDimensions()
    # Point data is flattened with x varying fastest, so its C-order shape is
    # (nz, ny, nx); the fields are transposed to (x, y, z) order on return.
    sh = (nz, ny, nx)
    ndims = len(sh)

    velocity = pointData.GetVectors("Velocity")
    pressure = pointData.GetScalars('Pressure')
    if velocity is None or pressure is None:
        raise ValueError("%s: missing 'Velocity' or 'Pressure' point array" % fname)

    # get vector field, component axis first: (ndims, nz, ny, nx)
    v = np.array(velocity).reshape(sh + (ndims,))
    vec = np.moveaxis(v, -1, 0)
    print('vec', vec.shape)
    # get scalar field; the singleton axis must come first so the transpose
    # below yields (1, nx, ny, nz), matching the vector field layout.
    sca = np.array(pressure).reshape((1,) + sh)

    # get grid
    x = np.array(data.GetXCoordinates())
    y = np.array(data.GetYCoordinates())
    z = np.array(data.GetZCoordinates())
    print(x.shape, y.shape, z.shape)
    if x.shape == y.shape == z.shape:
        grid = np.array([x, y, z])
    else:
        # Ragged axes: keep one 1-D array per axis so grid[0..2] still works.
        grid = np.empty(3, dtype=object)
        grid[0], grid[1], grid[2] = x, y, z

    return np.transpose(vec, (0, 3, 2, 1)), np.transpose(sca, (0, 3, 2, 1)), grid


def get_circulation(vort, box, d=3, slice=1, debug=True):
    """Return the absolute sum of one vorticity component over a rectangle.

    Parameters
    ----------
    vort : numpy.ndarray
        Field indexed as ``vort[component, i, j, k]``.  Assumed 4-D, e.g. the
        first array returned by ``read_vti``; TODO confirm for other callers.
    box : sequence of 4 ints
        ``(xmin, ymin, xmax, ymax)`` index bounds.  The half-open ranges
        ``[xmin, xmax)`` along axis 1 and ``[ymin, ymax)`` along axis 2 are
        summed.
    d : int
        1-based component number; component ``d - 1`` is summed.
    slice : int
        Index along the last axis of ``vort``.
    debug : bool
        If true, show a filled-contour plot of the summed component in that
        slice with the box outlined.

    Returns
    -------
    tuple
        ``(abs(sum), area)``, where ``area = (xmax - xmin) * (ymax - ymin)``
        in grid-index units.  The plain sum equals the circulation only for
        unit grid spacing; NOTE(review): confirm spacing.
    """
    xmin, ymin, xmax, ymax = box
    length = xmax - xmin
    width = ymax - ymin
    area = length * width
    component = int(d - 1)
    omega = np.sum(vort[component, xmin:xmax, ymin:ymax, slice], axis=(-2, -1))
    if debug:
        fig, ax = plt.subplots(1)
        # Plot the same component that was summed so the picture matches the
        # returned value; the transpose puts the first index on the x axis.
        p = ax.contourf(vort[component, :, :, slice].T, cmap='RdBu', levels=51)
        rect = patches.Rectangle((xmin, ymin), length, width, linewidth=1,
                                 edgecolor='k', facecolor='none')
        ax.add_patch(rect)
        plt.colorbar(p)
        plt.show()
    return abs(omega), area

if __name__ == "__main__":

    # numerical parameters
    nu = 64/8.5e3      # divides the circulation below to give the Reynolds number
    n_slices = 43      # slices along the last axis of vort to average over
    # (xmin, ymin, xmax, ymax) in grid-index units
    box = np.array([70, 175, 72, 177])
    grow = np.array([-4, -4, 4, 4])  # per-step enlargement of the box

    # get data; only the vector field is used below
    vort, _, _ = read_vti('vortF.1.pvti')

    # Diverging colour scale with white at zero.  The limits live only in the
    # norm: Matplotlib >= 3.5 raises a ValueError if norm is combined with
    # vmin/vmax in the plotting call.
    vmin = -0.3
    vmax = 0.3
    norm = mcolors.TwoSlopeNorm(vmin=vmin, vcenter=0, vmax=vmax)

    # Component 2 of vort averaged over the first n_slices slices, transposed
    # so the first grid index runs along the horizontal axis.
    omega_mean = vort[2, :, :, :n_slices].mean(axis=-1).T

    fig, ax = plt.subplots(1)
    ax.pcolor(np.round(omega_mean, 4), cmap='bwr', norm=norm)

    # Draw progressively larger boxes, one line style each, and report the
    # slice-averaged circulation-based Reynolds number for every box.
    for s, style in enumerate(ls):
        _box = box + grow * s
        xmin, ymin, xmax, ymax = _box
        rect = patches.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                 linewidth=0.5, edgecolor='k',
                                 linestyle=style, facecolor='none')
        ax.add_patch(rect)
        circulations = [get_circulation(vort, _box, slice=z, debug=False)[0]
                        for z in range(n_slices)]
        mean_circulation = sum(circulations) / n_slices
        print('BDIM Circulation-based Re : %.2f' % (mean_circulation / nu))

    # Tick labels are formatted explicitly; raw floats would render as "0.0".
    ytick_pos = [0, 32, 64, 96, 128, 160]
    ytick_val = [0.00, 0.25, 0.50, 0.75, 1.00, 1.25]
    xtick_pos = [0, 32, 64, 96, 128, 160, 192, 224, 256]
    xtick_val = [-0.50, -0.25, 0.00, 0.25, 0.50, 0.75, 1.00, 1.25, 1.50]
    plt.ylim(0, 160)
    plt.yticks(ytick_pos, ['%.2f' % v for v in ytick_val])
    plt.xlim(0, 256)
    plt.xticks(xtick_pos, ['%.2f' % v for v in xtick_val])
    plt.xlabel(r'$x/R$')
    plt.ylabel(r'$y/R$')
    plt.show()