import numpy as np
import Func
from scipy.interpolate import griddata
from scipy.interpolate import CloughTocher2DInterpolator
from matplotlib import use
use('Agg')
#   *******************************************
#   ***Setup constants and initial variables***
#   *******************************************

# Every helper call in this script goes through the name `simfunc`, but the
# only project import at the top of the file is `import Func`, so without a
# binding the first simfunc.TxtRead call raises NameError.
# NOTE(review): assumes Func is the module that provides TxtRead, RunSim,
# SaveFiles and RemoveFiles - confirm against the project layout.
simfunc = Func

#Restart? Set to True to resume from the state files written by a previous run
Restart = False

#Constants
U_inf = 0.028 #Free stream velocity (normalises the L1 error)
Nit = 100000 #Maximum iteration number
Kp = 1e-4 #Proportional gain applied to the velocity mismatch
D = 1 #Normalisation for Kp

#Variable initialisation
# NOTE(review): TxtRead presumably returns the requested number of columns
# (2 here) from the given text file - confirm in the helper module.
Exp_x, Exp_y = simfunc.TxtRead('Points.txt',2)
Exp_U, Exp_V = simfunc.TxtRead('Velocity.txt',2)
Sim_x, Sim_y = simfunc.TxtRead('SimPoints.txt',2)

# Work with numpy arrays from here on
Sim_x = np.array(Sim_x)
Sim_y = np.array(Sim_y)
Exp_x = np.array(Exp_x)
Exp_y = np.array(Exp_y)
Exp_U = np.array(Exp_U)
Exp_V = np.array(Exp_V)

#Set restart vars
if Restart:
    #Load in previous case from the files the DA loop writes each iteration
    ItterationCount = int(np.loadtxt('ItterationCount.txt',dtype=int))
    FinalItterationList = np.loadtxt('FinalItterationList.txt')
    L1_mean_list = np.loadtxt('L1_mean_list.txt')
    f = np.loadtxt('f.txt')
    fExp = np.loadtxt('fExp.txt')

else:
    ItterationCount = 0 #Count of iterations
    L1_mean_list = np.array([]) #List for L1 values
    FinalItterationList = np.array([]) #List for number of Openfoam calls
    f = np.zeros((2,len(Sim_x))) #Sim forcing vectors (row 0: x, row 1: y)
    fExp = np.zeros((2,len(Exp_x))) #Exp forcing vectors (row 0: x, row 1: y)


#   *****************
#   ***Run DA loop***
#   *****************

# The point sets never change between iterations, so build the (n, 2)
# coordinate arrays once instead of re-zipping them for every interpolation.
Sim_pts = np.column_stack((Sim_x, Sim_y))
Exp_pts = np.column_stack((Exp_x, Exp_y))

# Loop-invariant gain applied to the velocity mismatch
Gain = Kp/D


def _InterpToExp(values):
    """Interpolate a simulation field onto the experimental points.

    Linear interpolation is used first; every point where it yields NaN is
    filled with the nearest-neighbour value instead.
    """
    lin = griddata(Sim_pts, values, Exp_pts, method='linear')
    near = griddata(Sim_pts, values, Exp_pts, method='nearest')
    holes = np.isnan(lin)
    lin[holes] = near[holes]
    return lin


# Continue from the stored count so that a restarted run still stops once
# Nit iterations have been performed in total.
while ItterationCount < Nit:

    #   *********************************
    #   ***Run and evaluate simulation***
    #   *********************************

    #Run simulation with the current forcing field
    Sim_U, Sim_V, OFit = simfunc.RunSim(f, 'OfFiles')

    #Write final iteration number reported by the simulation
    FinalItterationList = np.append(FinalItterationList,OFit)
    np.savetxt('FinalItterationList.txt',FinalItterationList)

    #Linearly interpolate simulation results to experimental domain with nearest as extrapolation
    Sim_U_Lin = _InterpToExp(Sim_U)
    Sim_V_Lin = _InterpToExp(Sim_V)

    #Velocity mismatch between simulation and experiment at the experimental points
    Err_U = Sim_U_Lin-Exp_U
    Err_V = Sim_V_Lin-Exp_V

    #Find L1 error on experimental domain, normalised by the free stream velocity
    L1 = (np.absolute(Err_U)+np.absolute(Err_V))/U_inf
    L1_m = np.mean(L1)
    L1_mean_list = np.append(L1_mean_list,L1_m)
    #Write error to file
    np.savetxt('L1_mean_list.txt',L1_mean_list)


    #   **************************************
    #   ***Regularise and interpolate error***
    #   **************************************

    #Accumulate forcing term in x and y (the mismatch is integrated over iterations)
    fExp[0,:] += Gain*Err_U
    fExp[1,:] += Gain*Err_V
    #Write forcing term
    np.savetxt('fExp.txt',fExp)

    #Interp forcing to Sim Domain; fill_value=0.0 is passed so points the
    #interpolant cannot evaluate receive zero forcing
    EFU = CloughTocher2DInterpolator(Exp_pts,fExp[0,:],fill_value=0.0)
    EFV = CloughTocher2DInterpolator(Exp_pts,fExp[1,:],fill_value=0.0)
    f[0,:] = EFU(Sim_x,Sim_y)
    f[1,:] = EFV(Sim_x,Sim_y)
    #Write forcing term
    np.savetxt('f.txt',f)

    #Update the iteration count
    ItterationCount += 1
    #Write iteration count so a restart can pick up from here
    with open('ItterationCount.txt','w') as ITf:
        ITf.write('%d' % ItterationCount)

    #   *************************
    #   ***Clean and Plot Data***
    #   *************************

    #Clean openfoam run; on the first iteration and every multiple of 5 save the files
    if ItterationCount % 5 == 0 or ItterationCount == 1:
        #Plotting is currently disabled:
        #simfunc.PlotData('Plot', ItterationCount,Exp_x,Exp_y,L1,fExp)
        simfunc.SaveFiles('OfFiles', 'Save', ItterationCount)
    else:
        simfunc.RemoveFiles('OfFiles')
    
