########################################################################
#Average Kernel Generation Script
#incorporates code adapted from the original GCH libraries: https://github.com/andreanelli/GCH
#and the DirectionalConvexHull code in: https://github.com/scikit-learn-contrib/scikit-matter/blob/main/src/skmatter/sample_selection/_base.py 
#alongside unique code
#####################################################################
#IMPORTS SECTION
#####################################################################
import sys
import os
import time
from multiprocessing import Pool
from functools import partial
import argparse


import numpy as np
from pymatgen import io
from pymatgen.io import xyz
from pymatgen import symmetry
from pymatgen.symmetry import analyzer


import os
import ccdc
from ccdc import descriptors
import ase
import ase.io as aseio
from ase.spacegroup import crystal
from ase.symbols import Symbols
import rascal
from rascal.representations import SphericalInvariants as SOAP
from rascal.neighbourlist.structure_manager import (mask_center_atoms_by_id)
from zipfile import ZipFile


import itertools

import collections
import pandas as pd
import scipy.linalg as salg
from scipy.spatial import ConvexHull as chull


###################################################
#start timing code - for possible analysis purposes
#################################################
# Wall-clock reference taken once the imports above have completed; the run
# time printed at the end of the script is measured relative to this value.
start_time=time.time()

#####################################################################################################################################################################################################
#SPACEGROUP/SETTINGS FINDER SECTION
########################################################################################################################################################################################################

#FUNCTION TO GET SPACEGROUPS - #use CSD API for now - can make another way if necessary for later use alongside cspy
def get_spacegroups(filename):
    """Return the space-group number and setting of the first crystal in a file.

    The file is opened with the CSD Python API ``ccdc.io.CrystalReader`` and
    only its first entry is inspected.

    Parameters
    ----------
    filename : str
        Path of the structure file to read.

    Returns
    -------
    The ``spacegroup_number_and_setting`` value of that crystal. Elsewhere in
    this file element 0 is used as the space-group number and element 1 as
    the setting.

    Raises
    ------
    SystemExit
        With exit status 1 when the setting (element 1) is greater than 2.
    """
    crystal_reader = ccdc.io.CrystalReader(filename)
    # Named first_crystal so the ASE ``crystal`` import is not shadowed here
    first_crystal = crystal_reader[0]
    spacegroup = first_crystal.spacegroup_number_and_setting
    if spacegroup[1] > 2:
        print('ASE cannot handle this setting - conversion needed')
        # Non-zero status so shells and job schedulers register the failure
        sys.exit(1)
    return spacegroup



###################################################################################
#MAKING INPUTS FROM RES FILES ETC SECTION - Uses res files as outlined in documentation
#also obtains energy and asymmetric unit length for later use
####################################################################################
def file_neatener(filename):
    """Read a res file and rewrite its coordinate lines for later parsing.

    The energy is taken from the third whitespace-separated field of the
    first line. Every non-blank line after the first line containing
    ``'SFAC'`` is treated as a coordinate line: the first column is reduced
    to its alphabetic characters (e.g. ``C12`` -> ``C``) and two extra
    columns of ``0`` are appended. Blank lines in that section are left
    untouched and are not counted.

    Parameters
    ----------
    filename : str
        Path of the res file.

    Returns
    -------
    tuple
        ``(lines, atom_energy, asymm_length)`` where ``lines`` is the edited
        list of file lines, ``atom_energy`` is the energy as a float and
        ``asymm_length`` is the number of coordinate lines that were edited.

    Raises
    ------
    ValueError
        If no line containing ``'SFAC'`` is present (or the energy field is
        not a number).
    """
    with open(filename, "r") as f:
        lines = f.readlines()

    # Energy: third field of the first line
    atom_energy = float(lines[0].split()[2])

    # The coordinate section begins on the line after the first 'SFAC' line
    sfac_index = next(
        (idx for idx, line in enumerate(lines) if 'SFAC' in line), None
    )
    if sfac_index is None:
        raise ValueError("no 'SFAC' line found in {}".format(filename))
    start = sfac_index + 1

    asymm_length = 0
    for i in range(start, len(lines)):
        fields = lines[i].split()
        if not fields:
            # Blank/whitespace-only line: nothing to edit, not an atom
            continue
        # Keep only the letters of the atom label
        fields[0] = "".join(ch for ch in fields[0] if ch.isalpha())
        # Append the two columns of zeros
        lines[i] = " ".join(fields + ["0", "0"]) + "\n"
        asymm_length += 1

    return (lines, atom_energy, asymm_length)

    
    
#FUNCTION TO MAKE ATOMS OBJECT
def make_atoms(lines,spacegroup):
    """Build an atoms object with space-group symmetry applied from res lines.

    The lines are written to the scratch file ``correct_res.res`` in the
    current directory, read back with ``ase.io.read`` and passed to
    ``ase.spacegroup.crystal`` together with the space-group number and
    setting. The scratch file is removed afterwards, including when reading
    or symmetry expansion fails.

    Parameters
    ----------
    lines : list of str
        Res-file lines to write out.
    spacegroup : sequence
        Element 0 is used as the space-group number, element 1 as the setting.

    Returns
    -------
    The object returned by ``crystal(...)`` with ``pbc=True``.
    """
    scratch_name = 'correct_res.res'
    with open(scratch_name, 'w') as f:
        f.writelines(lines)
    try:
        atom_struc = aseio.read(scratch_name)
        spacegrp = spacegroup[0]
        print('spacegroup detected')
        print(spacegrp)
        setting = spacegroup[1]
        # add the spacegroup info and use crystal to apply the bulk/crystal
        # info to the atoms that were read
        atom_struc = crystal(symbols=atom_struc, spacegroup=spacegrp, setting=setting, pbc=True)
    finally:
        # Always clean up the scratch file, even if ASE raised
        os.remove(scratch_name)
    return atom_struc




#FUNCTION TO DO OVERALL MAKING INPUTS - start with zipfile of res structures
def make_inputs(structures_zip_name):
    """Turn a zip archive of res files into atoms, energies and atom counts.

    Each archive member is extracted into the current directory, inspected
    with ``get_spacegroups`` and ``file_neatener``, deleted again, and then
    converted with ``make_atoms``. Directory entries in the archive are
    skipped. The extracted file is removed even if parsing it fails.

    Parameters
    ----------
    structures_zip_name : str
        Path of the zip archive.

    Returns
    -------
    tuple
        ``(atoms_list, energy_list, asymm_list)``, three lists in archive
        order holding the ``make_atoms`` result, the energy and the
        coordinate-line count for each structure.
    """
    atoms_list = []
    energy_list = []
    asymm_list = []
    with ZipFile(structures_zip_name, 'r') as structures_zip:
        for member in structures_zip.namelist():
            if member.endswith('/'):
                # Directory entry, not a structure file
                continue
            # extract() returns the path actually written on disk
            extracted_path = structures_zip.extract(member)
            try:
                spacegroup = get_spacegroups(extracted_path)
                lines, energy, asymm_length = file_neatener(extracted_path)
            finally:
                os.remove(extracted_path)
            struc_atoms = make_atoms(lines, spacegroup)
            # Append each entry to the corresponding list
            atoms_list.append(struc_atoms)
            energy_list.append(energy)
            asymm_list.append(asymm_length)
    return (atoms_list, energy_list, asymm_list)



####################################################################################################################################################################################################
#KERNEL CALCULATION SECTION
#####################################################################################################################################################################################################
#FUNCTION TO 'SETUP KERNEL' WITH IMPORTANT PARAMETERS OF CHOICE
def initialise_kernel(cut_off):
    """Create the SOAP representation and the structure kernel built on it.

    Parameters
    ----------
    cut_off : float
        Value passed as the ``interaction_cutoff`` hyper-parameter.

    Returns
    -------
    tuple
        ``(kernel, features)``: the ``rascal.models.Kernel`` instance and the
        ``SphericalInvariants`` representation it was built from.
    """
    soap_settings = {
        'soap_type': 'PowerSpectrum',
        'interaction_cutoff': cut_off,
        'max_radial': 8,
        'max_angular': 6,
        'gaussian_sigma_constant': 0.3,
        'gaussian_sigma_type': 'Constant',
        'cutoff_smooth_width': 0.5,
        'radial_basis': 'GTO',
        'inversion_symmetry': True,
        'normalize': True,
    }
    features = SOAP(**soap_settings)
    structure_kernel = rascal.models.Kernel(
        features,
        kernel_type='Full',
        target_type='Structure',
        zeta=1,
    )
    return (structure_kernel, features)


def make_kernel(atoms_list,kernel,features,mol_name):
    """Compute, normalise and save the structure-structure kernel matrix.

    Every structure has ``wrap(eps=1e-18)`` called on it, the list is passed
    through ``features.transform`` and the result is given to ``kernel``.
    The matrix is then normalised as ``K[i, j] / sqrt(K[i, i] * K[j, j])``
    and written to ``Normalised_basic_kernel_<mol_name>.npy``.

    Parameters
    ----------
    atoms_list : list
        Structures exposing a ``wrap`` method. ``wrap`` is called on the
        objects themselves - presumably modifying them in place, as ASE
        ``Atoms.wrap`` does; confirm if callers reuse the list.
    kernel : callable
        Called with the transformed structures; expected to return a square
        array.
    features : object
        Representation providing ``transform``.
    mol_name : str
        Label inserted into the output file name.

    Returns
    -------
    numpy.ndarray
        The normalised kernel matrix.
    """
    for structure in atoms_list:
        structure.wrap(eps=1e-18)
    managers = features.transform(atoms_list)
    results = kernel(managers)
    # Vectorised K_ij / sqrt(K_ii * K_jj) instead of a Python double loop
    self_similarity = np.diag(results)
    normalised_results = results / np.sqrt(np.outer(self_similarity, self_similarity))
    np.save('Normalised_basic_kernel_' + mol_name + '.npy', normalised_results)
    return normalised_results

##############################################
#GETTING DRESSED ENERGIES FROM KERNEL SECTION
##############################################

##############################
#PROJECTION MAKING FUNCTIONS
##############################
#kpca function from original GCH repository (https://github.com/andreanelli/GCH)  - including possible centering step error
def kpca(kernel,ndim):
    """Extract the first ndim principal components in the space induced by
    the reference kernel (expects a square matrix).

    Parameters
    ----------
    kernel : numpy.ndarray
        Square floating-point kernel matrix; it is copied, not modified.
    ndim : int
        Number of leading components to return.

    Returns
    -------
    numpy.ndarray
        Array with one row per kernel row and ``ndim`` columns, ordered from
        the largest eigenvalue downwards.

    Notes
    -----
    The centring loop is kept exactly as in the original GCH code.
    NOTE(review): each column is scaled by ``1/sqrt(eigenvalue)``; if a
    requested eigenvalue is zero or negative (which looks possible when
    ``ndim`` equals the matrix size) that column will not be finite -
    confirm callers only rely on the leading components.
    """
    # Centering step
    k = kernel.copy()
    cols = np.mean(k, axis=0)
    rows = np.mean(k, axis=1)
    mean = np.mean(cols)
    for i in range(len(k)):
        k[:,i] -= cols
        k[i,:] -= rows
    k += mean

    # Eigensystem step: only the ndim largest eigenpairs are requested
    index_range = (len(k) - ndim, len(k) - 1)
    try:
        eigenvalues, eigenvectors = salg.eigh(k, subset_by_index=index_range)
    except TypeError:
        # Older SciPy releases only accept the legacy ``eigvals`` keyword
        eigenvalues, eigenvectors = salg.eigh(k, eigvals=index_range)
    # eigh returns ascending order; flip so the largest comes first
    eigenvalues = np.flipud(eigenvalues)
    eigenvectors = np.fliplr(eigenvectors)
    print(eigenvalues)
    pvec = eigenvectors.copy()
    print(pvec)
    for i in range(ndim):
        pvec[:,i] *= 1./np.sqrt(eigenvalues[i])

    # Projection step
    return np.dot(k, pvec)
    
#function to take kernel, make projection, and join it with the energies of choice
def make_kpca_data(kernel,proj_size,energies):
    """Join the energies with the kPCA projection of a kernel.

    Parameters
    ----------
    kernel : numpy.ndarray
        Kernel matrix handed to ``kpca``.
    proj_size : int
        Number of kPCA components to compute.
    energies : sequence
        One energy per structure; becomes the first column of the output.

    Returns
    -------
    numpy.ndarray
        Energies in column 0 followed by the projection columns.
    """
    components = pd.DataFrame(kpca(kernel, proj_size))
    energy_column = pd.DataFrame(energies)
    combined = pd.concat((energy_column, components), axis=1, ignore_index=True)
    return combined.to_numpy()
    
#function to do overall kpca_generation - ideally don't need this as a separate function
def do_kpca_process(kernel,proj_size,energy_array):
    """Thin wrapper: return ``make_kpca_data(kernel, proj_size, energy_array)``."""
    return make_kpca_data(kernel, proj_size, energy_array)



##################################################
# HULL TAKING/DRESSED ENERGY DATA ETC FUNCTIONS  -these sections involve code adapted from: https://github.com/andreanelli/GCH
#################################################

#function to cut down the data to desired dimensions and take required hull
def take_hull(data,dimensions):
    """Build a convex hull from the leading columns of ``data``.

    Parameters
    ----------
    data : numpy.ndarray
        Two-dimensional array, one row per point.
    dimensions : int
        Number of leading columns to keep.

    Returns
    -------
    tuple
        ``(hull, points)``: the ``scipy.spatial.ConvexHull`` of the retained
        columns and the retained columns themselves.
    """
    leading_columns = data[:, :dimensions]
    return (chull(leading_columns), leading_columns)

#function to get chemically relevant hull - based on facet equations
def get_relevant(hull):
    """Select the hull facets whose first normal component is not positive.

    Parameters
    ----------
    hull : object
        Must expose ``simplices`` and ``equations`` (as
        ``scipy.spatial.ConvexHull`` does).

    Returns
    -------
    tuple
        ``(equations, vlist)``: the facet equations with the rejected rows
        removed, and the union of the vertex indices of the retained facets
        (an empty list when no facet is retained).
    """
    facet_vertices = hull.simplices
    facet_equations = hull.equations

    # Facets whose first coefficient is strictly positive are discarded.
    # NOTE(review): a coefficient of exactly 0 is kept, and get_all_distances
    # later divides by this coefficient - confirm that is intended.
    rejected = [
        idx for idx in range(len(facet_vertices))
        if facet_equations[idx, 0] > 0.
    ]
    rejected_lookup = set(rejected)

    vlist = []
    for idx, simplex in enumerate(facet_vertices):
        if idx not in rejected_lookup:
            vlist = np.union1d(vlist, simplex)

    return (np.delete(facet_equations, rejected, axis=0), vlist)

#############################################################################
#functions to get dressed energies - includes functions adapted from DCH code:https://github.com/scikit-learn-contrib/scikit-matter/blob/main/src/skmatter/sample_selection/_base.py 
#functions get_all_distances and get_dressed taken/adapted from DCH in scikitmatter and are subject to the following copyright
"""Copyright (c) 2020 the sklearn-matter contributors
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
  this list of conditions and the following disclaimer in the documentation
  and/or other materials provided with the distribution.

* Neither the name of the copyright holder nor the names of its
  contributors may be used to endorse or promote products derived from
  this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE."""
############################################################################
def get_all_distances(points,equations):
    """
    Distance of every point to every plane, measured along the first dimension.

    For plane ``i`` with coefficients ``n = equations[i, :-1]`` and constant
    term ``c = equations[i, -1]`` the value returned for point ``p`` is
    ``(p . n + c) / n[0]``.

    points    : ndarray of shape (n_samples, n_dim)
                points to compute the directional distance from
    equations : ndarray of shape (n_facets, n_dim + 1)
                one plane per row: the coefficients followed by the
                constant term
    Returns
    -------
    directional_distance : ndarray of shape (n_samples, n_facets)
                distance of each point to each plane along the first
                dimension
    """
    plane_coefficients = equations[:, :-1]
    plane_constants = equations[:, -1:]
    first_coefficients = equations[:, :1]
    signed_offsets = -(points @ plane_coefficients.T) - plane_constants.T
    return -signed_offsets / first_coefficients.T

#function to get dressed energies (closest distances to hull for all points)

def get_dressed(points,equations,tolerance=1e-6):
    """Return one dressed energy per point from its directional plane distances.

    Distances come from ``get_all_distances``. A point counts as below the
    hull when any of its distances is smaller than ``-tolerance``.

    * Points not below the hull get the minimum of their distances.
    * Points below the hull get the largest of their non-positive distances
      (the one closest to zero).

    Parameters
    ----------
    points : ndarray of shape (n_samples, n_dim)
        Points to evaluate.
    equations : ndarray of shape (n_facets, n_dim + 1)
        Plane equations of the retained hull facets.
    tolerance : float, optional
        Threshold below zero past which a distance marks the point as lying
        below the hull. Defaults to ``1e-6``, the value previously
        hard-coded.

    Returns
    -------
    ndarray of shape (n_samples,)
        The dressed energies.
    """
    distances = get_all_distances(points, equations)
    # a point is below the hull if any plane distance is under the threshold
    below_directional_convex_hull = np.any(distances < -tolerance, axis=1)
    inside = ~below_directional_convex_hull

    dressed_energies = np.zeros(len(points))
    dressed_energies[inside] = np.min(distances[inside], axis=1)

    # Within tolerance some distances can still be negative; for points below
    # the hull mask out the positive ones and take the largest remaining value
    negative_directional_distances = distances.copy()
    negative_directional_distances[distances > 0] = -np.inf
    dressed_energies[below_directional_convex_hull] = np.max(
        negative_directional_distances[below_directional_convex_hull], axis=1
    )
    return dressed_energies






#function to get dressed energies if starting from a projection
def get_dressed_energies_analysis(projection,dimensions):
    """Compute dressed energies starting from a ready-made projection.

    Parameters
    ----------
    projection : numpy.ndarray
        Data array whose leading ``dimensions`` columns are used for the hull.
    dimensions : int
        Number of leading columns handed to ``take_hull``.

    Returns
    -------
    The array produced by ``get_dressed`` for the hull points and the facet
    equations retained by ``get_relevant``.
    """
    hull, points = take_hull(projection, dimensions)
    # keep only the facet equations of the relevant part of the hull
    equations, _vertices = get_relevant(hull)
    return get_dressed(points, equations)




#function to get dressed energies if starting from kernel
def get_dressed_energies_full(kernel,proj_size,energy_array,dimensions,mol_name):
    """Compute dressed energies starting from a kernel matrix.

    The energies-plus-projection array from ``do_kpca_process`` is saved to
    ``kPCA_projection_for_average_possibilities_<mol_name>.npy`` and then
    passed to ``get_dressed_energies_analysis``.

    Parameters
    ----------
    kernel : numpy.ndarray
        Kernel matrix.
    proj_size : int
        Number of kPCA components to compute.
    energy_array : sequence
        One energy per structure.
    dimensions : int
        Number of leading columns used for the hull.
    mol_name : str
        Label inserted into the saved file name.

    Returns
    -------
    The dressed energies returned by ``get_dressed_energies_analysis``.
    """
    kpca_data = do_kpca_process(kernel, proj_size, energy_array)
    np.save('kPCA_projection_for_average_possibilities_' + mol_name + '.npy', kpca_data)
    return get_dressed_energies_analysis(kpca_data, dimensions)



##########################################
#FUNCTION TO RUN COMPLETE PROCESS
###########################################
def from_start_to_end(mol_name,struc_zip,proj_size,cut_off,dimensions,job_type,projection_file):
    """Run the whole workflow and save the dressed energies.

    Parameters
    ----------
    mol_name : str
        Label used in the names of the saved files.
    struc_zip : str
        Zip archive of res files (used when ``job_type == 'full'``).
    proj_size : int
        Requested number of kPCA components; capped at the kernel size.
    cut_off : float
        SOAP cut-off; also inserted into the output file name.
    dimensions : int
        Number of leading data columns used for the hull.
    job_type : str
        ``'full'`` to start from the structures, ``'analysis'`` to start from
        a saved projection.
    projection_file : str
        ``.npy`` file loaded when ``job_type == 'analysis'``.

    Returns
    -------
    The dressed energies, which are also saved to
    ``dressed_energies_<mol_name>_<cut_off>.npy``.

    Raises
    ------
    ValueError
        If ``job_type`` is neither ``'full'`` nor ``'analysis'``.
    """
    # Reject unknown job types before any expensive work is started
    if job_type not in ('full', 'analysis'):
        raise ValueError(
            "Unknown job_type {!r}: expected 'full' or 'analysis'".format(job_type)
        )

    if job_type == 'full':
        atoms_list, energy_list, _asymm_list = make_inputs(struc_zip)
        kernel, features = initialise_kernel(cut_off)
        final_kernel = make_kernel(atoms_list, kernel, features, mol_name)
        # cannot ask for more components than there are structures
        proj = min([proj_size, final_kernel.shape[0]])
        values = get_dressed_energies_full(final_kernel, proj, energy_list, dimensions, mol_name)
    else:
        projection = np.load(projection_file)
        values = get_dressed_energies_analysis(projection, dimensions)

    dressed_name = 'dressed_energies_' + mol_name + '_' + str(cut_off) + '.npy'
    np.save(dressed_name, values)
    return values


##########################################################################################
def main():
    """Parse the command-line options, run the workflow and report timing.

    Prints the value returned by ``from_start_to_end`` followed by the
    elapsed time since the module-level ``start_time`` was recorded.
    """
    parser = argparse.ArgumentParser(
                        prog='GCH Dressed Energy Generator',
                        description='Takes inputs of the molecular geometry xyz and a zip-file of res files of predicted structures and returns an array of hull energies for each structure (ordered by original order split into Z prime sections)')

    parser.add_argument('-sz','--struc_zip',default='none',help='Zip file of predicted structure res files.')
    parser.add_argument('-mn', '--mol_name',default='none',help='name you want to give to the system')
    # NOTE(review): --tolerance is parsed but not passed on to from_start_to_end
    parser.add_argument('-tol', '--tolerance',type=float,default=0.3,help='Tolerance value for calculation of atom index mappings (default=0.3)')
    parser.add_argument('-ps', '--proj_size',type=int,default=32,help='Number of kPCA components to be calculated (default=32, unlikely to need bigger)')
    parser.add_argument('-co','--cut_off',type=float,default=4,help='SOAP cut-off for descriptor calculation in Angstroms (default=4)')
    parser.add_argument('-d','--dimensions',type=int,default=2,help='Number of dimensions - including energy dimension- desired for hull construction (default=2)')
    parser.add_argument('-jt','--job_type',type=str,default='full',help='Type of job to run - either "full" (full process) or "analysis" - (just dressed energy calculations - you must provide projection) ')
    parser.add_argument('-pj','--projection_file',type=str,default='none',help='File containing the ready-made projection - only for use if job-type is analysis')

    args = parser.parse_args()

    #############
    #RUN PROCESS
    #############

    answer = from_start_to_end(args.mol_name,args.struc_zip,args.proj_size,args.cut_off,args.dimensions,args.job_type,args.projection_file)
    print(answer)

    end_time=time.time()
    run_time=end_time-start_time
    print('actual python took:', run_time, 'seconds')


# Only run when executed as a script, so that importing this module does not
# parse sys.argv or start the calculation.
if __name__ == '__main__':
    main()

########################################################################################

#THE END
