########################################################################
#Adapted Kernel Generation Script
#incorporates code adapted from the original GCH libraries: https://github.com/andreanelli/GCH
#and the DirectionalConvexHull code in: https://github.com/scikit-learn-contrib/scikit-matter/blob/main/src/skmatter/sample_selection/_base.py 
#alongside unique code
#####################################################################
#IMPORTS SECTION
#####################################################################
import sys
import os
import time
from multiprocessing import Pool
from functools import partial
import argparse


import numpy as np
from pymatgen import io
from pymatgen.io import xyz
from pymatgen import symmetry
from pymatgen.symmetry import analyzer


import os
import ccdc
from ccdc import descriptors
import ase
import ase.io as aseio
from ase.spacegroup import crystal
from ase.symbols import Symbols
import rascal
from rascal.representations import SphericalInvariants as SOAP
from rascal.neighbourlist.structure_manager import (mask_center_atoms_by_id)
from zipfile import ZipFile


import itertools

import collections
import pandas as pd
import scipy.linalg as salg
from scipy.spatial import ConvexHull as chull


###################################################
#start timing code - for possible analysis purposes
#################################################
# Wall-clock reference taken when the module starts executing; the end of the
# script subtracts it from the finishing time to report the total run time.
start_time=time.time()



###########################################################
#Functions to find symmetry mappings of molecule
#########################################################

#READ MOLECULE FUNCTION

def read_molecule(underlying_mol):
    """Parse an xyz file and return the molecule it describes.

    underlying_mol : path of the xyz file, handed to ``io.xyz.XYZ.from_file``.

    Returns the ``molecule`` attribute of the parsed XYZ object.
    """
    return io.xyz.XYZ.from_file(underlying_mol).molecule

#WRITE TO FILE FUNCTION

def write_mol(mol_xyz,mol_name):
    """Write ``mol_xyz`` to disk by calling its ``write_file`` method.

    mol_xyz  : object exposing ``write_file(path)``
    mol_name : destination path passed straight to ``write_file``
    """
    destination = mol_name
    mol_xyz.write_file(destination)


#GET CENTERED MOLECULE FUNCTION 

def get_centered(underlying_mol,to_centre):
    """Centre a molecule and write the centred geometry to an xyz file.

    underlying_mol : name of the original xyz file; the text before its first
                     '.' is reused to build the output file name
    to_centre      : molecule object providing ``get_centered_molecule()``

    Returns a tuple ``(centred_molecule, output_file_name)``.
    """
    centred_molecule = to_centre.get_centered_molecule()
    centred_as_xyz = io.xyz.XYZ(centred_molecule)
    # The file on disk is what get_atom_vectors reads back later on.
    stem = underlying_mol.split('.')[0]
    output_file = 'centered_' + stem + '.xyz'
    write_mol(centred_as_xyz, output_file)
    return (centred_molecule, output_file)


#GET OPERATORS FUNCTION

def get_operators(centered_mol):
    """Return the symmetry operations found for ``centered_mol``.

    The molecule is wrapped in ``symmetry.analyzer.PointGroupAnalyzer`` and the
    result of its ``get_symmetry_operations()`` call is returned unchanged.
    """
    point_group = symmetry.analyzer.PointGroupAnalyzer(centered_mol)
    return point_group.get_symmetry_operations()


#APPLY TRANSFORMATION FUNCTIONS - in more optimal code this could be done
#more smoothly with less read/write

#getting vectors representing atom positions and species
def get_atom_vectors(centered_file):
    """Read the atoms of an xyz file as ``[element, x, y, z, ...]`` lists.

    The first two lines of the file (atom count and comment line in the xyz
    layout) are skipped. Every remaining non-blank line is split on
    whitespace: the first field is kept as the element label and the other
    fields are converted to float.

    centered_file : path of the xyz file to read

    Returns a tuple ``(atom_vectors, mol_atoms)`` where ``mol_atoms`` is the
    number of atom lines found (0 when the file holds no atoms).

    Raises ValueError if a coordinate field cannot be converted to float.
    """
    with open(centered_file, 'r') as f:
        lines = f.readlines()

    atom_vectors = []
    for line in lines[2:]:
        fields = line.split()
        # Blank lines (typically a trailing newline) carry no atom.
        if not fields:
            continue
        atom_vectors.append([fields[0]] + [float(x) for x in fields[1:]])

    # Counted after the loop so it is defined even for an atom-less file.
    return (atom_vectors, len(atom_vectors))
        
#applying a transformation to an atom vector
def transform_atom(atom_vector,operator):
    """Apply one symmetry operator to a single atom.

    atom_vector : list ``[element, x, y, z]``
    operator    : object exposing ``rotation_matrix`` and ``translation_vector``

    The coordinates are multiplied by the rotation matrix and the translation
    vector is then added. Returns a new list with the element label first,
    followed by the transformed coordinates.
    """
    element = atom_vector[0]
    position = np.array(atom_vector[1:])
    rotated = np.matmul(operator.rotation_matrix, position)
    displaced = rotated + operator.translation_vector
    return [element, *displaced]
              
#transforming all atoms in molecule
def transform_molecule(atom_vectors,operator):
    """Apply ``operator`` to every atom vector and return the new vectors.

    The input list is not modified; each entry is passed through
    ``transform_atom`` and the results are collected in the same order.
    """
    return [transform_atom(atom, operator) for atom in atom_vectors]

#writing transformed molecule file
def write_trans(new_vecs,new_filename):
    """Append a transformed molecule to ``new_filename``.

    The output is one line holding the number of atoms followed by one line
    per atom, with the fields of each atom vector converted to ``str`` and
    separated by single spaces.

    new_vecs     : list of atom vectors ``[element, x, y, z]``
    new_filename : file to append to (opened in append mode, so calling this
                   repeatedly accumulates blocks in the same file)

    NOTE(review): no comment line is written after the atom count, whereas
    get_atom_vectors skips two header lines when reading - confirm the
    intended layout before reading these files back with it.
    """
    # Build every line first so the file is touched only once the data is valid.
    atom_lines = [
        " ".join(str(element) for element in new_vec) + "\n"
        for new_vec in new_vecs
    ]
    with open(new_filename, 'a') as f:
        f.write(str(len(new_vecs)) + '\n')
        f.writelines(atom_lines)


#FIND MATCHUPS FUNCTIONS - to identify atom-atom mappings between original and transformed

#set up method for writing the match lists
def write_match(match_list,map_string,i,j):
    """Record one atom-to-atom match and extend the mapping key string.

    match_list : list of matches, appended to in place
    map_string : key string built up so far for the current mapping
    i, j       : index of the original atom and of the transformed atom

    A self-match (``i == j``) is stored as ``[i]`` and adds ``"i_"`` to the
    key; any other match is stored as ``[i, j]`` and adds ``"i-j_"``. The
    separator between the two indices keeps the key unambiguous: plain
    concatenation would give the same text for (1, 12) and (11, 2).

    Returns ``(match_list, map_string)`` with the updated values.
    """
    if i == j:
        match_list.append([i])
        token = str(i)
    else:
        match_list.append([i, j])
        token = str(i) + '-' + str(j)
    return (match_list, map_string + token + '_')

#set up method to gather operator matches and check for duplication
def gather_matches(all_maps,all_map_strings,match_list,map_string):
    """Store a mapping unless one with the same key string is already known.

    all_maps        : list of mappings collected so far (extended in place)
    all_map_strings : key strings of those mappings (extended in place)
    match_list      : candidate mapping
    map_string      : key string of the candidate

    Returns ``(all_maps, all_map_strings)``.
    """
    already_known = map_string in all_map_strings
    if already_known:
        return (all_maps, all_map_strings)
    all_map_strings.append(map_string)
    all_maps.append(match_list)
    return (all_maps, all_map_strings)

#find the matches - based on shared coordinates and atom species
def get_matches(atom_vectors,new_vecs,tolerance):
    """Pair original atoms with transformed atoms that sit in the same place.

    Two atoms match when their element labels are equal and every coordinate
    differs by at most ``tolerance``. Each match is recorded via
    ``write_match``.

    atom_vectors : original atom vectors ``[element, x, y, z]``
    new_vecs     : transformed atom vectors in the same layout
    tolerance    : largest accepted absolute difference per coordinate

    Returns ``(match_list, map_string)`` for this transformation.
    """
    matches = []
    key = 'string-'
    for i, original in enumerate(atom_vectors):
        for j, moved in enumerate(new_vecs):
            # Different species can never be mapped onto each other.
            if original[0] != moved[0]:
                continue
            deviation = abs(np.array(original[1:]) - np.array(moved[1:]))
            if all(component <= tolerance for component in deviation):
                matches, key = write_match(matches, key, i, j)
    return (matches, key)
         
#OVERALL FUNCTION TO RUN THE MAPPINGS SECTION

def overall_map(underlying_mol,tolerance):
    """Find the distinct atom-index mappings produced by the molecule's symmetry.

    Steps: read the molecule, centre it (which also writes the centred xyz
    file), obtain its symmetry operators, apply each operator to the centred
    atoms and record which original atom each transformed atom coincides
    with. Mappings that repeat an already seen key string are dropped.

    underlying_mol : path of the molecule xyz file
    tolerance      : per-coordinate tolerance used when matching atoms

    Side effect: saves the mappings to ``all_maps.npy`` as an object array.

    Returns ``(all_maps, molecule_atoms)`` - the list of mappings and the
    number of atoms read from the centred file.
    """
    all_maps = []
    all_map_strings = []

    pymat_mol = read_molecule(underlying_mol)
    # Centre once: get_centered writes a file, so a second call would only
    # repeat the centring and the write.
    centered_mol, centered_file = get_centered(underlying_mol, pymat_mol)
    operators = get_operators(centered_mol)
    atom_vectors, molecule_atoms = get_atom_vectors(centered_file)

    for operator in operators:
        new_vecs = transform_molecule(atom_vectors, operator)
        match_list, map_string = get_matches(atom_vectors, new_vecs, tolerance)
        all_maps, all_map_strings = gather_matches(
            all_maps, all_map_strings, match_list, map_string
        )

    # dtype=object because the entries are lists of differing lengths.
    np.save('all_maps.npy', np.array(all_maps, dtype=object))
    return (all_maps, molecule_atoms)





####################################################################################
#SPACEGROUP/SETTINGS FINDER SECTION - needed for setting up atoms objects
####################################################################################
#FUNCTION TO GET SPACEGROUPS
def get_spacegroups(filename):
    """Return the spacegroup number and setting of the first crystal in a file.

    The file is opened with ``ccdc.io.CrystalReader`` and the
    ``spacegroup_number_and_setting`` attribute of its first entry is
    returned. If the setting (second element) is greater than 2, a message is
    printed and the program exits.
    """
    reader = ccdc.io.CrystalReader(filename)
    number_and_setting = reader[0].spacegroup_number_and_setting
    setting_supported = not (number_and_setting[1] > 2)
    if setting_supported:
        return number_and_setting
    print('ASE cannot handle this setting - conversion needed')
    sys.exit()

###################################################################################
#MAKING INPUTS FROM RES FILES ETC SECTION - Uses res files as outlined in the documentation
#also obtains energy and asymmetric unit length for later use
####################################################################################
def file_neatener(filename):
    """Read a res file, extract its energy and tidy its coordinate lines.

    * The energy is the third whitespace-separated field of the first line,
      returned as a float (no unit conversion is applied here).
    * The coordinate section is every line after the first line containing
      'SFAC'. In each such line the first column is reduced to its
      alphabetic characters (the element symbol) and two extra '0' columns
      are appended.
    * Blank lines in the coordinate section are left untouched and are not
      counted.

    filename : path of the res file

    Returns ``(lines, atom_energy, asymm_length)`` where ``lines`` is the full
    list of (edited) file lines and ``asymm_length`` is the number of
    coordinate lines that were edited.

    Raises ValueError if the file contains no 'SFAC' line.
    """
    with open(filename, "r") as f:
        lines = f.readlines()

    energy_line = lines[0].split()
    atom_energy = float(energy_line[2])

    # The coordinate section starts on the line after the first SFAC line.
    start = None
    for position, line in enumerate(lines):
        if 'SFAC' in line:
            start = position + 1
            break
    if start is None:
        raise ValueError("no 'SFAC' line found in " + str(filename))

    asymm_length = 0
    for i in range(start, len(lines)):
        col_list = lines[i].split()
        if not col_list:
            continue
        # Keep only the letters of the atom label, e.g. 'C12' -> 'C'.
        col_list[0] = "".join(char for char in col_list[0] if char.isalpha())
        col_list.extend(['0', '0'])
        lines[i] = " ".join(col_list) + "\n"
        asymm_length += 1

    return (lines, atom_energy, asymm_length)

    
    
#FUNCTION TO MAKE ATOMS OBJECT
def make_atoms(lines,spacegroup):
    """Build an atoms object for a structure from tidied res-file lines.

    The lines are written to a temporary file ``correct_res.res`` in the
    working directory, read back with ``aseio.read`` and then passed to
    ``crystal`` together with the spacegroup number and setting, with
    ``pbc=True``.

    lines      : list of text lines, as returned by file_neatener
    spacegroup : sequence ``(number, setting)``

    Returns the object produced by ``crystal``.
    """
    temp_name = 'correct_res.res'
    try:
        with open(temp_name, 'w') as f:
            f.writelines(lines)
        atom_struc = aseio.read(temp_name)
    finally:
        # Always clean up, so a failed read does not leave a stale file that
        # could be mistaken for the next structure.
        if os.path.exists(temp_name):
            os.remove(temp_name)

    spacegrp = spacegroup[0]
    print('spacegroup detected')
    print(spacegrp)
    setting = spacegroup[1]
    # Add the spacegroup information to the atoms that were read in.
    return crystal(symbols=atom_struc, spacegroup=spacegrp, setting=setting, pbc=True)




#FUNCTION TO DO OVERALL MAKING INPUTS - start with zipfile of res structures
def make_inputs(structures_zip_name):
    """Turn a zip archive of res files into atoms objects, energies and lengths.

    Every archive member is extracted into the working directory, its
    spacegroup is looked up, its lines are tidied with ``file_neatener``, the
    extracted file is deleted again and an atoms object is built with
    ``make_atoms``.

    structures_zip_name : path of the zip archive

    Returns ``(atoms_list, energy_list, asymm_list)``, three lists in archive
    order.
    """
    atoms_list, energy_list, asymm_list = [], [], []
    with ZipFile(structures_zip_name, 'r') as structures_zip:
        for member in structures_zip.namelist():
            structures_zip.extract(member)
            spacegroup = get_spacegroups(member)
            lines, energy, asymm_length = file_neatener(member)
            os.remove(member)
            atoms_list.append(make_atoms(lines, spacegroup))
            energy_list.append(energy)
            asymm_list.append(asymm_length)
    return (atoms_list, energy_list, asymm_list)


            
########################################################################
# RE-ORDERING FUNCTION - needed later to reorder energies properties based on 
#Z', as structures are reordered that way
##########################################################################

def reorderer(property_list,asymm_list,Z_prime_list,molecule_atoms):
    """Regroup per-structure properties into Z' blocks.

    Each structure's Z' is ``asymm_list[i] / molecule_atoms``. Properties are
    collected per Z' (keeping their original relative order) and the groups
    are concatenated in the order in which the Z' values appear in
    ``Z_prime_list``. kernel_combiner walks the same list in the same order
    when assembling the kernel, so following the list order - rather than
    sorting the Z' values - keeps the properties aligned with the kernel rows
    even when the list is not ascending.

    property_list  : one property value per structure
    asymm_list     : asymmetric-unit atom count per structure
    Z_prime_list   : Z' values defining the block order
    molecule_atoms : number of atoms in one molecule

    Returns a 1-D numpy array of the regrouped properties. Prints a message
    and exits if a structure has a Z' that is not in ``Z_prime_list``.
    """
    # Plain dicts keep insertion order, i.e. the order of Z_prime_list.
    prop_lists = {k: [] for k in Z_prime_list}
    for i, asymm in enumerate(asymm_list):
        z_prime = asymm / molecule_atoms
        if z_prime not in prop_lists:
            print('reordering error-structure is of unfamilar z_prime')
            sys.exit()
        prop_lists[int(z_prime)].append(property_list[i])
    return np.concatenate(list(prop_lists.values()))



############################################################################################
#KERNEL CALCULATION SECTION - has functions needed for SOAP descriptor and kernel calculation
##############################################################################################
#FUNCTION TO 'SETUP KERNEL' WITH IMPORTANT PARAMETERS OF CHOICE
def initialise_kernel(cut_off):
    """Create the SOAP representation and the kernel object built on it.

    cut_off : value used for the 'interaction_cutoff' hyper-parameter; all
              other hyper-parameters are fixed below.

    Returns ``(kernel, features)``.
    """
    hypers = {
        'soap_type': 'PowerSpectrum',
        'interaction_cutoff': cut_off,
        'max_radial': 8,
        'max_angular': 6,
        'gaussian_sigma_constant': 0.3,
        'gaussian_sigma_type': 'Constant',
        'cutoff_smooth_width': 0.5,
        'radial_basis': 'GTO',
        'inversion_symmetry': True,
        'normalize' : True,
    }
    soap_features = SOAP(**hypers)
    structure_kernel = rascal.models.Kernel(
        soap_features, kernel_type='Full', target_type='Structure', zeta=1
    )
    return (structure_kernel, soap_features)


#FUNCTION TO SORT STRUCTURES BY Z_PRIME -
def sort_z_prime(atoms_list,asymm_lengths,mol_name,molecule_atoms):
    """Group structures by Z' and write one xyz file per group.

    Z' of a structure is its asymmetric-unit length divided by
    ``molecule_atoms``; a non-integer value prints a message and exits.
    Groups are created in order of first appearance and each is written to
    ``z_prime<Z'>_<mol_name>.xyz``.

    Returns ``(z_prime_list, mini_lists)`` where ``mini_lists`` holds
    ``(z_prime, [structures])`` tuples in the same order as ``z_prime_list``.
    """
    z_prime_list = []
    mini_lists = []
    for position in range(len(atoms_list)):
        structure = atoms_list[position]
        ratio = asymm_lengths[position] / molecule_atoms
        z_prime = int(ratio)
        if z_prime != ratio:
            print('structure is of non-integer Z prime. This dataset cannot be handled')
            sys.exit()
        # Open a new group the first time a Z' value is met.
        if z_prime not in z_prime_list:
            z_prime_list.append(z_prime)
            mini_lists.append((z_prime, []))
        for label, members in mini_lists:
            if label == z_prime:
                members.append(structure)
    for label, members in mini_lists:
        aseio.write('z_prime' + str(label) + '_' + mol_name + '.xyz', members)
    return (z_prime_list, mini_lists)


#FUNCTION TO CALC NUMBER OF MOLECULES IN A STRUCTURE (Z) - needed for identifying atoms later
def get_num_mol(single_struc,molecule_atoms):
    """Return how many molecules a structure holds.

    Computed as ``len(single_struc) / molecule_atoms`` truncated to an int.
    """
    return int(len(single_struc) / molecule_atoms)


#FUNCTION TO TAKE THE ATOM INDEX AND RETURN THE ATOM IDS THAT NEED TO BE MASKED IN THE ORIGINAL STRUCTURE
def to_mask(atom_index,zp,num_mol,molecule_atoms):
    """Return the structure-wide atom ids for one molecular atom index.

    For each of the ``zp`` molecules of the asymmetric unit one id is
    produced: ``(atom_index + molecule_atoms*m) * int(num_mol/zp)``. The
    growing list is printed after every molecule.

    atom_index     : index of the atom inside a single molecule
    zp             : number of molecules in the asymmetric unit (Z')
    num_mol        : number of molecules in the structure
    molecule_atoms : number of atoms in one molecule

    Returns the list of ids (length ``zp``).
    """
    copies_of_unit = int(num_mol / zp)
    selected = []
    for molecule in range(zp):
        atom_id = int((atom_index + molecule_atoms * molecule) * copies_of_unit)
        selected = selected + [atom_id]
        print('masking in A')
        print(selected)
    return selected



#FUNCTION TO TAKE THE ATOM INDICES (FROM MAPPING LISTS)  AND  RETURN THE ATOM IDS THAT NEED TO BE MASKED IN THE TRANSFORMED  STRUCTURE

def to_mask_adv(atom_index_list,zp,num_mol,molecule_atoms):
    """Return structure-wide atom ids, one molecular atom index per molecule.

    Works like ``to_mask`` except that molecule ``m`` of the asymmetric unit
    uses ``atom_index_list[m]`` instead of one shared index. Two sanity checks
    print a warning when the list looks shorter or longer than ``zp``; they do
    not stop execution.

    atom_index_list : sequence of molecular atom indices, expected length zp
    zp              : number of molecules in the asymmetric unit (Z')
    num_mol         : number of molecules in the structure
    molecule_atoms  : number of atoms in one molecule

    Returns the list of ids (length ``zp``).
    """
    selected = []
    copies_of_unit = int(num_mol / zp)

    # Probe the last expected position: fails when the list is too short.
    try:
        atom_index_list[zp - 1]
    except IndexError:
        print('Atom_index_list is too short. It should have Z prime elements')

    # Probe a zp-long dummy list with the real length: fails when too long.
    try:
        reference = ['check'] * zp
        reference[len(atom_index_list) - 1]
    except IndexError:
        print('Atom_index_list is too long. It should have Z prime elements')

    for molecule in range(zp):
        atom_id = int(
            (atom_index_list[molecule] + molecule_atoms * molecule) * copies_of_unit
        )
        selected = selected + [atom_id]
        print('masking in B')
        print(selected)
    return selected


#FUNCTION TO GET Z' COMBINATIONS OF POSSIBLE MAPPINGS - Temporarily defunct as causes issues

def get_combos(atom_index,zp_max,symm_opp_list):
    """List the index combinations an atom can map to across ``zp_max`` molecules.

    For every mapping in ``symm_opp_list`` each entry whose first element is
    ``atom_index`` contributes one equivalent index: its second element when
    it has one (``[i, j]``), otherwise the atom itself (``[i]``).

    atom_index    : molecular atom index to look up
    zp_max        : number of molecules per combination (must be >= 1)
    symm_opp_list : list of mappings, each a list of ``[i]`` / ``[i, j]`` entries

    Returns, for ``zp_max == 1``, a list of one-element lists; for larger
    values, the list of ``zp_max``-tuples from the Cartesian product of the
    equivalents with themselves.

    Raises ValueError if ``zp_max`` is smaller than 1.
    """
    if zp_max < 1:
        raise ValueError('zp_max must be at least 1, got ' + str(zp_max))

    possible_equivalents = []
    for symm_opp in symm_opp_list:
        for entry in symm_opp:
            if entry[0] == atom_index:
                possible_equivalents.append(entry[1] if len(entry) > 1 else entry[0])

    if zp_max == 1:
        return [[equivalent] for equivalent in possible_equivalents]
    return list(itertools.product(possible_equivalents, repeat=zp_max))



#SET UP PARTIAL FUNCTION TO DO KERNEL CALCULATION FOR EACH MAPPING FOR SINGLE ORIGINAL ATOM INDEX (or Z' combination of mappings) 

def one_symm(symm_combo,struct_A,struct_B,file_A,file_B,molecule_atoms,zp_max,zp_min,atom_index,features,kernel):
    """Compute the kernel between two structure lists for one symmetry combination.

    Every structure in ``struct_A`` is wrapped and the ids returned by
    ``to_mask`` for ``atom_index`` are passed to ``mask_center_atoms_by_id``;
    every structure in ``struct_B`` gets the same treatment with the ids from
    ``to_mask_adv`` for ``symm_combo``. Both lists are then transformed with
    ``features`` and fed to ``kernel``.
    NOTE(review): presumably mask_center_atoms_by_id restricts which atoms act
    as centres for the descriptor - confirm against the librascal docs.

    symm_combo       : sequence of molecular atom indices, one per molecule of
                       the asymmetric unit of the B structures
    struct_A/B       : lists of structures (modified in place by wrap/masking)
    file_A/B         : kept for interface compatibility; not used. The
                       previous re-read of these files at the end only rebound
                       local names and had no effect on the caller.
    molecule_atoms   : number of atoms in one molecule
    zp_max, zp_min   : Z' of the B and A structures respectively
    atom_index       : molecular atom index selected in the A structures
    features, kernel : representation and kernel objects from initialise_kernel

    Returns whatever ``kernel`` returns for the two transformed lists.
    """
    combo = list(symm_combo)

    for s in struct_A:
        s.wrap(eps=1e-18)
        number_molecules = get_num_mol(s, molecule_atoms)
        to_masks = to_mask(atom_index, zp_min, number_molecules, molecule_atoms)
        mask_center_atoms_by_id(s, to_masks)

    for s in struct_B:
        s.wrap(eps=1e-18)
        number_molecules = get_num_mol(s, molecule_atoms)
        to_masks = to_mask_adv(combo, zp_max, number_molecules, molecule_atoms)
        mask_center_atoms_by_id(s, to_masks)

    managers_A = features.transform(struct_A)
    managers_B = features.transform(struct_B)
    print('show descriptors')
    print(managers_A.get_features(features))
    print(managers_B.get_features(features))
    results = kernel(managers_A, managers_B)
    print('done symm opp ')

    return results

#FUNCTION TO GET 'FULL' ADAPTED KERNEL FOR EACH Z' PAIR
def zprime_pair_kernel(zp_max,zp_min,file_A,file_B,symm_opp_list,molecule_atoms,features,kernel,cores_requested):
    """Build the kernel block between the structures of two Z' groups.

    For every atom index of the molecule the symmetry combinations are
    generated with ``get_combos`` and one kernel per combination is computed
    in a process pool via ``one_symm``. The per-combination kernels are summed
    over atom indices, divided by ``molecule_atoms`` and finally averaged over
    the combinations with ``get_mean_kernel``.

    zp_max, zp_min  : Z' of the structures in file_B and file_A respectively
    file_A, file_B  : xyz files holding the two structure lists
    symm_opp_list   : atom-index mappings from overall_map
    molecule_atoms  : number of atoms in one molecule
    features, kernel: objects from initialise_kernel
    cores_requested : upper limit on the pool size

    Returns the ``len(A) x len(B)`` kernel block.

    NOTE(review): the accumulation indexes the pool results with
    ``len(symm_opp_list)**zp_max`` entries, so it relies on get_combos
    returning at least that many combinations for every atom - confirm.
    """
    struct_A = aseio.read(file_A, index=':')
    struct_B = aseio.read(file_B, index=':')
    rows, cols = len(struct_A), len(struct_B)
    number_combos = len(symm_opp_list)**zp_max

    running_totals = [np.zeros([rows, cols])] * number_combos
    for atom_index in range(molecule_atoms):
        symm_combos = get_combos(atom_index, zp_max, symm_opp_list)
        # No point in starting more workers than there are combinations.
        cores_needed = min(len(symm_combos), cores_requested)
        print('got_combos')
        with Pool(int(cores_needed)) as pool:
            worker = partial(
                one_symm,
                struct_A=struct_A,
                struct_B=struct_B,
                file_A=file_A,
                file_B=file_B,
                zp_max=zp_max,
                zp_min=zp_min,
                molecule_atoms=molecule_atoms,
                atom_index=atom_index,
                features=features,
                kernel=kernel,
            )
            atom_kernels = pool.map(worker, symm_combos)

        for slot in range(number_combos):
            running_totals[slot] = running_totals[slot] + atom_kernels[slot]
        print('done ', atom_index)

    # Average over the atom-index contributions.
    per_combo_average = [np.zeros([rows, cols])] * number_combos
    for slot in range(len(symm_combos)):
        per_combo_average[slot] = running_totals[slot] / molecule_atoms

    # Mean over the symmetry combinations for every structure pair.
    return get_mean_kernel(per_combo_average, rows, cols)

#Function to get average of possibilities kernel (over different mappings)

def get_mean_kernel(kernel_possibilities,length_A,length_B):
    """Average a list of kernels element by element.

    kernel_possibilities : sequence of 2-D arrays, one per symmetry possibility
    length_A, length_B   : number of rows / columns to average over

    Returns a ``length_A x length_B`` array whose entry ``[A, B]`` is the mean
    of that entry over all possibilities. With no possibilities the result is
    filled with NaN (the mean of nothing).

    Raises IndexError if the kernels are smaller than the requested lengths.
    """
    if len(kernel_possibilities) == 0:
        return np.full([length_A, length_B], np.nan)

    stacked = np.stack([np.asarray(chance) for chance in kernel_possibilities])
    if stacked.shape[1] < length_A or stacked.shape[2] < length_B:
        raise IndexError('kernel possibilities are smaller than the requested size')

    # One vectorised reduction instead of a Python loop per structure pair.
    return np.mean(stacked[:, :length_A, :length_B], axis=0)

#function to run the kernel calculator to make each z' pair  kernel

def make_all_kernel_bits(Z_prime_list,mol_name,symm_opp_list,molecule_atoms,features,kernel,cores_requested):
    """Compute and save the kernel block of every Z' pair with zp1 >= zp2.

    For each such pair the block is computed by ``zprime_pair_kernel`` - with
    the zp2 structures as A and the zp1 structures as B - and saved to
    ``z_prime_<zp1>_<zp2>_<mol_name>.npy``. Nothing is returned.
    """
    def structures_file(z_prime):
        # File written by sort_z_prime for this Z' group.
        return 'z_prime' + str(z_prime) + '_' + mol_name + '.xyz'

    for zp1 in Z_prime_list:
        for zp2 in Z_prime_list:
            if zp1 >= zp2:
                file_B = structures_file(zp1)
                file_A = structures_file(zp2)
                block = zprime_pair_kernel(
                    zp1, zp2, file_A, file_B, symm_opp_list,
                    molecule_atoms, features, kernel, cores_requested,
                )
                np.save('z_prime_'+ str(zp1) + '_' + str(zp2) + '_' + mol_name + '.npy', block)

#############################################################################################
#KERNEL RECOMBINATION SECTION
#########################################################################################

def kernel_combiner(z_prime_list,mol_name):
    """Assemble the saved Z'-pair blocks into one normalised kernel.

    Blocks are loaded from ``z_prime_<a>_<b>_<mol_name>.npy`` (a >= b) and
    laid out in the order of ``z_prime_list`` along both axes; a block is
    transposed when its file stores the pair the other way round. The full
    matrix is then normalised as ``K[i,j] / sqrt(K[i,i] * K[j,j])`` and saved
    to ``final_average_possibilities_kernel_<mol_name>.npy``.

    z_prime_list : Z' values, in the order the blocks should be laid out
    mol_name     : system name used in the file names

    Returns the normalised kernel.

    Raises ValueError if ``z_prime_list`` is empty.
    """
    if len(z_prime_list) == 0:
        raise ValueError('z_prime_list is empty - no kernel sections to combine')

    columns = []
    for i in z_prime_list:
        parts = []
        for j in z_prime_list:
            if i <= j:
                filename = 'z_prime_' + str(j) +'_' + str(i) + '_' + mol_name + '.npy'
                part = np.transpose(np.load(filename))
            else:
                filename = 'z_prime_' + str(i) +'_' + str(j) + '_' + mol_name + '.npy'
                part = np.load(filename)
            parts.append(part)
        # Stack the blocks of one Z' column on top of each other ...
        columns.append(np.concatenate(parts, axis=0))
    # ... then place the columns side by side.
    full = np.concatenate(columns, axis=1)

    # Vectorised cosine-style normalisation by the diagonal entries.
    diagonal = np.diagonal(full)
    normalised = full / np.sqrt(np.outer(diagonal, diagonal))

    kernel_name = 'final_average_possibilities_kernel_' + mol_name + '.npy'
    np.save(kernel_name, normalised)
    return normalised


##############################################
#GETTING DRESSED ENERGIES FROM KERNEL SECTION
##############################################

##############################
#PROJECTION MAKING FUNCTIONS
##############################
#kpca function from original GCH repository (https://github.com/andreanelli/GCH)  - including possible centering step error
def kpca(kernel,ndim):
    """ Extracts the first ndim principal components in the space
    induced by the reference kernel (Will expect a square matrix)

    kernel : square ndarray of floats; it is copied, not modified
    ndim   : number of components to extract

    Returns an array of shape (len(kernel), ndim): the centred kernel
    projected onto its ``ndim`` leading eigenvectors, each scaled by
    ``1/sqrt(eigenvalue)``. The eigenvalues and eigenvectors are printed.
    """
    # Centering step (kept exactly as in the original GCH code).
    k = kernel.copy()
    cols = np.mean(k, axis=0)
    rows = np.mean(k, axis=1)
    mean = np.mean(cols)
    for i in range(len(k)):
        k[:,i] -= cols
        k[i,:] -= rows
    k += mean

    # Eigensystem step: only the ndim largest eigenpairs are requested.
    lowest, highest = len(k) - ndim, len(k) - 1
    try:
        evals, evecs = salg.eigh(k, subset_by_index=(lowest, highest))
    except TypeError:
        # Older SciPy releases only know the previous keyword name.
        evals, evecs = salg.eigh(k, eigvals=(lowest, highest))
    # eigh returns ascending order; flip so the largest comes first.
    evals = np.flipud(evals)
    evecs = np.fliplr(evecs)
    print(evals)
    print(evecs)
    pvec = evecs.copy()
    for i in range(ndim):
        pvec[:,i] *= 1./np.sqrt(evals[i])

    # Projection step
    return np.dot(k, pvec)
    
#function to take kernel, make projection, and join it with the energies of choice
def make_kpca_data(kernel,proj_size,energies):
    """Build the data table: energies first, then the kPCA projection.

    kernel    : square kernel matrix handed to ``kpca``
    proj_size : number of kPCA components
    energies  : one energy per structure

    Returns a numpy array made by joining the energies and the projection
    column-wise through pandas.
    """
    projection_frame = pd.DataFrame(kpca(kernel, proj_size))
    energy_frame = pd.DataFrame(energies)
    joined = pd.concat((energy_frame, projection_frame), axis=1, ignore_index=True)
    return joined.to_numpy()
    
#function to do overall kpca_generation - ideally don't need this as a seperate function
def do_kpca_process(kernel,proj_size,energy_array):
    """Thin wrapper: return ``make_kpca_data(kernel, proj_size, energy_array)``."""
    return make_kpca_data(kernel, proj_size, energy_array)

        
################################################## 
# HULL TAKING/DRESSED ENERGY DATA ETC FUNCTIONS  -these sections involve code adapted from: https://github.com/andreanelli/GCH   
#################################################

#function to cut down the data to desired dimensions and take required hull 
def take_hull(data,dimensions):
    """Take the convex hull of the first ``dimensions`` columns of ``data``.

    Returns ``(hull, points)`` where ``points`` is the column slice that was
    used and ``hull`` the scipy ConvexHull built from it.
    """
    leading_columns = data[:, :dimensions]
    return (chull(leading_columns), leading_columns)

#function to get chemically relevant hull - based on facet equations
def get_relevant(hull):
    """Keep only the hull facets whose first equation coefficient is not positive.

    hull : object exposing ``simplices`` and ``equations`` (one row per facet)

    Facets with ``equations[i, 0] > 0`` are discarded. Returns
    ``(equations, vlist)``: the equations of the kept facets and the union of
    the vertex indices of their simplices.
    """
    simplices = hull.simplices
    facet_equations = hull.equations
    discarded = []
    vlist = []
    for facet in range(len(simplices)):
        if facet_equations[facet, 0] > 0.:
            discarded.append(facet)
            continue
        vlist = np.union1d(vlist, simplices[facet])
    kept_equations = np.delete(facet_equations, discarded, axis=0)
    return (kept_equations, vlist)





#############################################################################
#functions to get dressed energies - includes functions adapted from DCH code:https://github.com/scikit-learn-contrib/scikit-matter/blob/main/src/skmatter/sample_selection/_base.py 
#functions get_all_distances and get_dressed taken/adapted from DCH in scikitmatter and are subject to the following copyright
"""Copyright (c) 2020 the sklearn-matter contributors
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
  this list of conditions and the following disclaimer in the documentation
  and/or other materials provided with the distribution.

* Neither the name of the copyright holder nor the names of its
  contributors may be used to endorse or promote products derived from
  this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE."""
############################################################################
def get_all_distances(points,equations):
    """
    Distance of every point to every plane, measured along the first dimension.

    points    : ndarray of shape (n_samples, n_dim)
                points to compute the directional distance from
    equations : ndarray of shape (n_facets, n_dim + 1)
                each row holds the plane coefficients followed by the offset,
                i.e. equations[i, :-1] . x + equations[i, -1] = 0 on the plane

    Returns
    -------
    directional_distance : ndarray of shape (n_samples, n_facets)
                the signed plane residual of each point divided by the plane's
                first coefficient
    """
    normals = equations[:, :-1]
    offsets = equations[:, -1:]
    first_coefficients = equations[:, :1]
    orthogonal_distances = -(points @ normals.T) - offsets.T
    return -orthogonal_distances / first_coefficients.T

#function to get dressed energies (closest distances to hull for all points)

def get_dressed(points,equations):
    """Return the dressed energy (directional distance to the hull) per point.

    points    : ndarray of shape (n_samples, n_dim)
    equations : plane equations of the relevant hull facets

    A point counts as below the hull when any of its directional distances is
    smaller than -0.000001. Points not below the hull get their smallest
    distance; points below it get the largest of their non-positive distances.
    """
    distances = get_all_distances(points, equations)
    below_hull = np.any(distances < -0.000001, axis=1)
    on_or_above = ~below_hull

    dressed = np.zeros(len(points))
    dressed[on_or_above] = np.min(distances[on_or_above], axis=1)

    # For points below the hull ignore the positive distances, so the maximum
    # picks the negative distance closest to zero.
    non_positive_only = distances.copy()
    non_positive_only[distances > 0] = -np.inf
    dressed[below_hull] = np.max(non_positive_only[below_hull], axis=1)
    return dressed






#function to get dressed energies if starting from a projection
def get_dressed_energies_analysis(projection,dimensions):
    """Compute dressed energies from a ready-made projection.

    The hull is taken over the first ``dimensions`` columns of ``projection``,
    reduced to its relevant facets, and the directional distance of every
    point to those facets is returned.
    """
    hull, points = take_hull(projection, dimensions)
    equations, _vertices = get_relevant(hull)
    return get_dressed(points, equations)




#function to get dressed energies if starting from kernel
def get_dressed_energies_full(kernel,proj_size,energy_array,dimensions,mol_name):
    """Compute dressed energies starting from a kernel.

    The energies and the kPCA projection of ``kernel`` are joined into one
    table, which is saved to
    ``kPCA_projection_for_average_possibilities_<mol_name>.npy`` and then
    passed to ``get_dressed_energies_analysis``.
    """
    table = do_kpca_process(kernel, proj_size, energy_array)
    np.save('kPCA_projection_for_average_possibilities_'+ mol_name + '.npy', table)
    return get_dressed_energies_analysis(table, dimensions)

#####################################
#FUNCTION TO RUN COMPLETE PROCESS
###################################
def from_start_to_end(molecule,mol_name,tolerance,struc_zip,proj_size,cut_off,dimensions,job_type,projection_file,cores_requested):
    """Run the complete workflow and return the dressed energies.

    job_type 'full'     : symmetry mappings -> structure inputs -> per-Z'
                          kernels -> combined kernel -> kPCA -> dressed energies
    job_type 'analysis' : load ``projection_file`` and compute the dressed
                          energies from it

    molecule        : xyz file of the molecule
    mol_name        : system name used in all output file names
    tolerance       : tolerance for the atom-index mappings
    struc_zip       : zip archive of res files
    proj_size       : requested number of kPCA components (capped at the
                      kernel size)
    cut_off         : SOAP cut-off, also used in the output file name
    dimensions      : number of columns used for the hull
    job_type        : 'full' or 'analysis'
    projection_file : .npy projection, used by the 'analysis' job only
    cores_requested : upper limit on worker processes

    Side effects: saves the dressed energies to
    ``dressed_energies_<mol_name>_<cut_off>.npy``; the 'full' job also saves
    the reordered energies and the intermediate files of the called steps.

    Raises ValueError for an unknown ``job_type``.
    """
    if job_type not in ('full', 'analysis'):
        raise ValueError(
            'unknown job_type ' + repr(job_type) + ' - expected "full" or "analysis"'
        )

    if job_type == 'full':
        symm_opp_list, molecule_atoms = overall_map(molecule, tolerance)
        atoms_list, energy_list, asymm_list = make_inputs(struc_zip)
        kernel, features = initialise_kernel(cut_off)

        z_prime_data = sort_z_prime(atoms_list, asymm_list, mol_name, molecule_atoms)
        Z_prime_list = z_prime_data[0]
        make_all_kernel_bits(Z_prime_list, mol_name, symm_opp_list, molecule_atoms, features, kernel, cores_requested)

        final_kernel = kernel_combiner(Z_prime_list, mol_name)
        # Cannot extract more components than the kernel has rows.
        proj = min([proj_size, final_kernel.shape[0]])

        # Bring the energies into the same order as the kernel rows.
        correct_energy_list = reorderer(energy_list, asymm_list, Z_prime_list, molecule_atoms)
        np.save('correct_order_energies_' + mol_name + '.npy', correct_energy_list)
        values = get_dressed_energies_full(final_kernel, proj, correct_energy_list, dimensions, mol_name)
    else:
        projection = np.load(projection_file)
        values = get_dressed_energies_analysis(projection, dimensions)

    dressed_name = 'dressed_energies_' + mol_name + '_' + str(cut_off) + '.npy'
    np.save(dressed_name, values)
    return values



##########################################################################################
# Command-line interface. The parser is built at module level; parsing the
# arguments and running the workflow happen only when the file is executed as
# a script (see the guard below).
parser = argparse.ArgumentParser(
                    prog='GCH Dressed Energy Generator',
                    description='Takes inputs of the molecular geometry xyz and a zip-file of res files of predicted structures and returns an array of hull energies for each structure (ordered by original order split into Z prime sections)')

parser.add_argument('-sz','--struc_zip',default='none',help='Zip file of predicted structure res files.')
parser.add_argument('-mf','--mol_file',default='none',help='xyz file of undelrying molecular geometry.For now can only be one (Rigid CSP)')
parser.add_argument('-mn', '--mol_name',default='none',help='name you want to give to the system')
parser.add_argument('-tol', '--tolerance',type=float,default=0.3,help='Tolerance value for calculation of atom index mappings (default=0.3)')
parser.add_argument('-ps', '--proj_size',type=int,default=32,help='Number of kPCA components to be calculated (default=32, unlikely to need bigger)')
parser.add_argument('-co','--cut_off',type=float,default=4,help='SOAP cut-off for descriptor calculation in Angstroms (default=4)')
parser.add_argument('-d','--dimensions',type=int,default=2,help='Number of dimensions - including energy dimension- desired for hull construction (default=2)')
parser.add_argument('-jt','--job_type',type=str,default='full',help='Type of job to run - either "full" (full process) or "analysis" - (just dressed energy calculations - you must provide projection) ')
parser.add_argument('-pj','--projection_file',type=str,default='none',help='File containing the ready-made projection - only for use if job-type is analysis')
parser.add_argument('-j','--core_req',type=int,default=1,help='Number of requested cores to parallelise over - in actual usage terms - only number of cores equal to eventual number of symmetry operator combinations will be used - so limit this to something sensible e.g 20 to avoid too much wastage')


#############
#RUN PROCESS
##############

# The workflow starts multiprocessing pools. On platforms that start workers
# by re-importing the main module, unguarded top-level code would be executed
# again in every worker, so the run is restricted to direct script execution.
if __name__ == '__main__':
    args = parser.parse_args()

    answer = from_start_to_end(args.mol_file,args.mol_name,args.tolerance,args.struc_zip,args.proj_size,args.cut_off,args.dimensions,args.job_type,args.projection_file,args.core_req)
    print(answer)

    end_time=time.time()
    run_time=end_time-start_time
    print('actual python took:', run_time, 'seconds')

########################################################################################

#THE END
