#Script to get direct product of groups of mappings from each conformation in set
#also includes some functions/data for prior approach - using direct product of point groups

##################################
#Imports
###############################


from pymatgen import io
from pymatgen.io import xyz
from pymatgen import symmetry
from pymatgen.symmetry import analyzer
from zipfile import ZipFile
import numpy as np
import os
from itertools import chain,combinations


##########################
#Powerset function
###########################
def powerset(set_name):
    """Return every subset of ``set_name`` as a set of tuples.

    The empty tuple ``()`` is part of the result.  Inside each tuple the
    elements keep the iteration order of ``set_name``.
    """
    items = list(set_name)
    subsets = set()
    # one batch of combinations per subset size, from 0 up to len(items)
    for size in range(len(items) + 1):
        subsets.update(combinations(items, size))
    return subsets


################################################
#Functions to get mappings for a single molecule
###############################################
def read_molecule(underlying_mol):
    """Load the XYZ file ``underlying_mol`` and return its ``molecule``.

    The file is parsed with ``io.xyz.XYZ.from_file`` (pymatgen); the value
    returned is the ``molecule`` attribute of the parsed object.
    """
    return io.xyz.XYZ.from_file(underlying_mol).molecule

def write_mol(mol_xyz,mol_name):
    """Write ``mol_xyz`` to the file ``mol_name``.

    ``mol_xyz`` must provide a ``write_file(filename)`` method (the callers in
    this file pass an ``io.xyz.XYZ`` object).  Nothing is returned.
    """
    save = mol_xyz.write_file
    save(mol_name)


def get_centered(underlying_mol,to_centre):
    """Centre a molecule and write it to a ``centered_<name>.xyz`` file.

    Args:
        underlying_mol: Path of the XYZ file the molecule came from.  It is
            only used to derive the name of the output file.
        to_centre: Object providing ``get_centered_molecule()`` (presumably
            the pymatgen molecule returned by ``read_molecule``).

    Returns:
        Tuple ``(centered, mol_name)``: the centred molecule and the path of
        the XYZ file it was written to.

    The output file is placed next to the input file.  For a plain name such
    as ``mol.xyz`` the result is ``centered_mol.xyz``.
    """
    centered=to_centre.get_centered_molecule()
    centered_xyz=io.xyz.XYZ(centered)
    # Prefix only the file name, never the directory part, and strip only the
    # final extension.  Splitting the whole path on the first '.' turned
    # './mol.xyz' into 'centered_.xyz' and 'dir/mol.xyz' into a path inside a
    # non-existent 'centered_dir' folder.
    folder,filename=os.path.split(underlying_mol)
    stem=os.path.splitext(filename)[0]
    mol_name=os.path.join(folder,'centered_' + stem + '.xyz')
    write_mol(centered_xyz,mol_name)
    return (centered,mol_name)

def get_group(centered_mol):
    """Return the point group found by ``PointGroupAnalyzer`` for ``centered_mol``.

    The value is whatever ``PointGroupAnalyzer.get_pointgroup()`` returns;
    ``overall_group`` reads its ``sch_symbol`` attribute.
    """
    return symmetry.analyzer.PointGroupAnalyzer(centered_mol).get_pointgroup()


def get_symm_one(centered_mol):
    """Return ``PointGroupAnalyzer(centered_mol).symmetrize_molecule()``.

    The result is passed through unchanged from pymatgen.
    """
    return symmetry.analyzer.PointGroupAnalyzer(centered_mol).symmetrize_molecule()



def overall_group(underlying_mol):
    """Read an XYZ file and report the symbol of its point group.

    The molecule is analysed exactly as read (it is not centred first).

    Args:
        underlying_mol: Path of the XYZ file.

    Returns:
        Tuple ``(underlying_mol, symbol)`` where ``symbol`` is the
        ``sch_symbol`` attribute of the point group from ``get_group``.
    """
    pymat_mol=read_molecule(underlying_mol)
    group=get_group(pymat_mol)
    return (underlying_mol,group.sch_symbol)



def get_operators(centered_mol):
    """Return the symmetry operations ``PointGroupAnalyzer`` finds for ``centered_mol``.

    The result of ``get_symmetry_operations()`` is returned unchanged; this
    file indexes it and reads ``rotation_matrix`` / ``translation_vector``
    from each element.
    """
    return symmetry.analyzer.PointGroupAnalyzer(centered_mol).get_symmetry_operations()


def transform_atom(atom_vector,operator):
    """Apply one symmetry operator to a single atom.

    Args:
        atom_vector: Sequence ``[element, x, y, z]``; everything after the
            first item is treated as the coordinate vector.
        operator: Object exposing ``rotation_matrix`` and
            ``translation_vector``.

    Returns:
        A new list ``[element, x', y', z']`` where the coordinates are
        ``rotation_matrix @ coords + translation_vector``.
    """
    position = np.array(atom_vector[1:])
    # rotate first, then translate
    rotated = np.matmul(operator.rotation_matrix, position)
    shifted = rotated + operator.translation_vector
    return [atom_vector[0], *shifted]


def get_atom_vectors(centered_file):
    """Parse the atom records of an XYZ file.

    The first two lines (atom count and comment line) are skipped.  Every
    following non-blank line is read as ``<element> <number> <number> ...``.

    Args:
        centered_file: Path of the XYZ file to read.

    Returns:
        Tuple ``(atom_vectors, mol_atoms)``.  ``atom_vectors`` is a list of
        ``[element, x, y, z]`` lists (element as ``str``, the rest as
        ``float``) and ``mol_atoms`` is the number of atoms read.  A file
        without atom lines yields ``([], 0)``.

    Raises:
        ValueError: If a field after the element symbol is not a number.
    """
    atom_vectors=[]
    with open(centered_file,'r') as f:
        # the slice drops the two XYZ header lines (atom count, comment)
        for line in f.readlines()[2:]:
            fields=line.split()
            # blank lines (e.g. a trailing newline) carry no atom
            if not fields:
                continue
            atom_vectors.append([fields[0]] + [float(x) for x in fields[1:]])
    # counted after the loop so it is defined even when no atom was read
    return (atom_vectors,len(atom_vectors))


def transform_molecule(atom_vectors,operator):
    """Apply ``operator`` to every atom vector of a molecule.

    Each entry of ``atom_vectors`` is passed to ``transform_atom``; the
    transformed atoms are returned as a new list in the same order.
    """
    return [transform_atom(atom, operator) for atom in atom_vectors]



#writing transformed
def write_trans(new_vecs,new_filename):
    """Append a transformed molecule to ``new_filename`` in XYZ layout.

    Args:
        new_vecs: List of atom vectors ``[element, x, y, z]``.
        new_filename: Path of the file to append to (opened in ``'a'`` mode,
            so repeated calls add further frames).

    The frame written is: the atom count, an empty comment line, then one
    space-separated line per atom.  The comment line keeps the layout
    consistent with ``get_atom_vectors``, which skips the first two lines.
    """
    # The original called an undefined ``join`` and appended the list to the
    # string instead of the string to the list, so it could never run.
    atom_lines=[' '.join(str(element) for element in new_vec) + '\n'
                for new_vec in new_vecs]
    with open(new_filename,'a') as f:
        f.write(str(len(new_vecs)) + '\n')
        f.write('\n')  # XYZ comment line (left empty)
        f.writelines(atom_lines)


def write_match(match_list,map_string,i,j):
    """Record that atom ``i`` is matched by atom ``j``.

    Args:
        match_list: List of matches collected so far; it is extended in place.
        map_string: String key describing the matches collected so far.
        i: Index of the atom in the original molecule.
        j: Index of the matching atom in the transformed molecule.

    Returns:
        Tuple ``(match_list, map_string)``.  A fixed atom (``i == j``) is
        stored as ``(i,)`` and adds ``"<i>_"`` to the key; a moved atom is
        stored as ``(i, j)`` and adds ``"<i>-<j>_"``.

    The ``-`` between the two indices keeps the key unambiguous.  Without it,
    (1, 23) and (12, 3) both produced ``"123_"``, so two different mappings
    could share a key and one of them be dropped as a duplicate.
    """
    if i==j:
        match=(i,)
        token=str(i)
    else:
        match=(i,j)
        token=str(i) + '-' + str(j)
    match_list.append(match)
    map_string=map_string + token + '_'
    return (match_list,map_string)

def gather_matches(all_maps,all_map_strings,match_list,map_string):
    """Store a mapping unless one with the same string key is already stored.

    ``map_string`` is looked up in ``all_map_strings``; when absent, the key
    and ``match_list`` are appended to their lists (both modified in place).

    Returns:
        Tuple ``(all_maps, all_map_strings)``.
    """
    already_known = map_string in all_map_strings
    if already_known:
        return (all_maps, all_map_strings)
    all_map_strings.append(map_string)
    all_maps.append(match_list)
    return (all_maps, all_map_strings)

def get_matches(atom_vectors,new_vecs,tolerance):
    """Find which transformed atoms coincide with which original atoms.

    An original atom ``i`` and a transformed atom ``j`` match when their
    element labels are equal and every coordinate differs by at most
    ``tolerance``.  Each match is recorded through ``write_match``.

    Returns:
        Tuple ``(match_list, map_string)`` as built up by ``write_match``,
        starting from an empty list and the key ``'string-'``.
    """
    match_list = []
    map_string = 'string-'
    for i, original in enumerate(atom_vectors):
        for j, image in enumerate(new_vecs):
            # only atoms of the same element can be mapped onto each other
            if original[0] != image[0]:
                continue
            offsets = abs(np.array(original[1:]) - np.array(image[1:]))
            if all(delta <= tolerance for delta in offsets):
                match_list, map_string = write_match(match_list, map_string, i, j)
    return (match_list, map_string)


def overall_map(underlying_mol,tolerance):
    """Collect the distinct atom mappings induced by a molecule's symmetry operators.

    The molecule is read, centred (which also writes a ``centered_*.xyz``
    file, see ``get_centered``), and every operator from ``get_operators`` is
    applied to the atom vectors parsed back from that file.  The matches
    between original and transformed atoms form one mapping per operator;
    mappings with an already-seen key are kept only once.

    Args:
        underlying_mol: Path of the XYZ file of the molecule.
        tolerance: Largest per-coordinate difference for two atoms to count
            as matching (passed to ``get_matches``).

    Returns:
        Tuple ``(all_maps, molecule_atoms)``: the list of distinct mappings
        and the number of atoms read from the centred file.
    """
    all_maps=[]
    all_map_strings=[]
    pymat_mol=read_molecule(underlying_mol)
    # Centre once: each get_centered call re-centres the molecule and
    # rewrites the output file, so calling it per return value was wasteful.
    centered_mol,centered_file=get_centered(underlying_mol,pymat_mol)
    operators=get_operators(centered_mol)
    atom_vectors,molecule_atoms=get_atom_vectors(centered_file)
    for operator in operators:
        new_vecs=transform_molecule(atom_vectors,operator)
        match_list,map_string=get_matches(atom_vectors,new_vecs,tolerance)
        # keep the mapping only if no earlier operator produced the same one
        all_maps,all_map_strings=gather_matches(all_maps,all_map_strings,match_list,map_string)
    return (all_maps,molecule_atoms)



#########################################################################
#Eliminating subgroups functions - needed to simplify direct product task
########################################################################

#Powerset function - to get all subgroups of the mapping groups
def powerset(set_name):
    """Return all subsets of ``set_name`` (including ``()``) as a set of tuples.

    NOTE: this re-defines the identical ``powerset`` declared near the top of
    this file.
    """
    members = tuple(set_name)
    sizes = range(len(members) + 1)
    return {subset for size in sizes for subset in combinations(members, size)}




#check if group of mappings is subgroup of a group already covered
def check_sub(one_map,covered,max_list):
    """Register a group of mappings unless a known group already contains it.

    Args:
        one_map: Iterable of mappings; each mapping is an iterable of hashable
            entries (e.g. the tuples built by ``write_match``).
        covered: Set of frozensets used to decide whether a group is already
            accounted for.  Updated in place.
        max_list: List of the `master' groups (frozensets) needed for the
            direct product.  Updated in place.

    Returns:
        Tuple ``(covered, max_list)`` - the same two objects that were passed.

    A group counts as covered when it is a subset of some member of
    ``covered``.  When it is not, it is added to ``covered``, every master
    group that is a subset of it is dropped from ``max_list``, and it is
    appended to ``max_list``.

    Earlier versions stored every non-empty subset of each group in
    ``covered``, which needs 2**n entries for a group of n mappings and is not
    feasible for the larger point groups.  Only the registered groups are
    stored now; the subset tests give the same decisions, and a ``covered``
    set built the old way is still understood.
    """
    #Make set of mappings into a frozenset so that it can live inside a set
    ops=frozenset(tuple(i) for i in one_map)
    #nothing to do if the group is contained in one already registered
    if any(ops <= known for known in covered):
        return covered,max_list
    covered.add(ops)
    #remove any previously 'needed' master groups that are now a subgroup of the new one
    max_list[:]=[big for big in max_list if not big <= ops]
    #add new group of mappings to the list of `master' groups needed for direct product calculation
    max_list.append(ops)
    return covered,max_list



#Dictionary mapping each point group to the list of its subgroups (itself included) - not needed if using the groups-of-mappings method
# Keys are point group symbols; each value lists that group first, followed by
# its subgroups.  Every symbol used in a value is itself a key.
# Fix: the D4h entry contained the typo 'Cev' where 'C2v' was meant.
pg_dict={
    "C1":['C1'],
    "Ci":['Ci','C1'],
    "C2":['C2','C1'],
    "Cs":['Cs','C1'],
    "C2h":['C2h','C2','Cs','C1','Ci'],
    "D2":['D2','C2','C1'],
    "C2v":['C2v','C2','Cs','C1'],
    "D2h":['D2h','C2v','D2','C2h','C2','Cs','Ci','C1'],
    "C4":['C4','C2','C1'],
    "S4":['S4','C2','C1'],
    "C4h":['C4h','C4','S4','C2h','C2','Cs','Ci','C1'],
    "D4":['D4','C4','D2','C2','C1'],
    "C4v":['C4v','C4','C2v','C2','Cs','C1'],
    "D2d":['D2d','S4','C2v','D2','C2','Cs','C1'],
    "D4h":['D4h','D2d','C4v','D4','C4h','C4','S4','D2h','C2v','D2','C2h','Cs','C2','Ci','C1'],
    "C3":['C3','C1'],
    "C3i":['C3i','C3','Ci','C1'],
    "D3":['D3','C3','C2','C1'],
    "C3v":['C3v','C3','Cs','C1'],
    "D3d":['D3d','C3v','D3','C3i','C3','C2h','Cs','C2','Ci','C1'],
    "C6":['C6','C3','C2','C1'],
    "C3h":['C3h','C3','Cs','C1'],
    "C6h":['C6h','C3h','C6','C3i','C3','C2h','Cs','C2','Ci','C1'],
    "D6":['D6','C6','D3','C3','D2','C2','C1'],
    "C6v":['C6v','C6','C3v','C3','C2v','C2','Cs','C1'],
    "D3h":['D3h','C3h','C3v','D3','C3','C2v','C2','Cs','C1'],
    "D6h":['D6h','D3h','C6v','D6','C6h','D3d','C3h','C6','C3v','D3','C3i','C3','D2h','C2v','D2','C2h','C2','Cs','Ci','C1'],
    "T":['T','C3','D2','C2','C1'],
    "Th":['Th','T','C3i','C3','D2h','C2v','D2','C2h','C2','Cs','Ci','C1'],
    "O":['O','T','D3','D4','C4','C3','D2','C2','C1'],
    "Td":['Td','T','C3v','C3','D2d','S4','C2v','D2','C2','Cs','C1'],
    "Oh":['Oh','Td','O','Th','T','D3d','C3v','D3','C3i','C3','D4h','D2d','C4v','D4','C4h','S4','C4','D2h','C2v','D2','C2h','C2','Cs','Ci','C1'],
}



############################################################################
#Take direct product functions
###############################################################################


#Function to combine (compose) any two individual maps - single mappings, not groups of mappings
def combine_maps(map_1,map_2,molecule_atoms):
    """Compose two atom mappings: apply ``map_1`` first, then ``map_2``.

    Args:
        map_1: Iterable of entries ``(i,)`` (atom ``i`` stays in place) or
            ``(i, j)`` (atom ``i`` goes to ``j``).  Tuples and lists are both
            accepted.
        map_2: Second mapping, same layout as ``map_1``.
        molecule_atoms: Number of atoms; indices ``0 .. molecule_atoms-1``
            are combined.

    Returns:
        List with one entry per atom index, in order: ``[i]`` when the
        combined mapping leaves atom ``i`` in place, ``[i, k]`` when it sends
        atom ``i`` to ``k``.

    Raises:
        ValueError: If an atom index has no entry in ``map_1`` or its image
            has no entry in ``map_2``.  Previously this case either failed
            with an unbound local variable or silently re-used the image of
            the preceding atom.
    """
    def image_lookup(mapping):
        # entry (i,) means i -> i, entry (i, j) means i -> j; when an index
        # occurs more than once the last entry wins, as it did before
        return {entry[0]: (entry[1] if len(entry) > 1 else entry[0])
                for entry in mapping}

    first_images=image_lookup(map_1)
    second_images=image_lookup(map_2)
    combo_map=[]
    for atom_index in range(molecule_atoms):
        try:
            combined=second_images[first_images[atom_index]]
        except KeyError as missing:
            raise ValueError('no mapping entry for atom index '
                             + repr(missing.args[0])
                             + ' while combining atom '
                             + str(atom_index)) from None
        if atom_index!=combined:
            match=[atom_index,combined]
        else:
            match=[atom_index]
        combo_map.append(match)
    return combo_map
 





#get required groups of mappings to form product


#Path of the zip archive holding the conformer .xyz files.
#Replace the placeholder with the real path before running the script.
conformer_zip='<ZIPFILE OF CONFORMERS>'

maps_sets=[]
covered=set(())
max_list=[]
#number of atoms per conformer as reported by overall_map
#(assumes every conformer in the archive has the same atom count)
molecule_atoms=None


#Work out, conformer by conformer, which groups of mappings are needed
with ZipFile(conformer_zip) as conf_zip:
     names = conf_zip.namelist()
     for name in names:
         #directory entries of the archive hold no molecule
         if name.endswith('/'):
            continue
         conf_zip.extract(name)
         try:
             #0.3 is usually sufficent as a tolerance
             one_map,molecule_atoms = overall_map(name,0.3)
         finally:
             #always delete the extracted file, even if the analysis failed
             os.remove(name)
         covered,max_list=check_sub(one_map,covered,max_list)


if not max_list:
   raise ValueError('no conformers found in ' + conformer_zip)


#initialise list of groups to `add' to direct product, so can then loop through one by one
maps_sets=max_list.copy()
combo_maps=list(maps_sets[0])

print('length',len(maps_sets))
print(maps_sets)


#loop over groups to `add' forming direct product of previous direct product and new group
for i in maps_sets[1:]:
    #make list of `combined maps' defining this direct product and fill it up
    new_combo_maps=[]
    for first in combo_maps:
        for second in i:
            #use the real atom count of the conformers rather than a fixed 32
            combo = combine_maps(first,second,molecule_atoms)
            new_combo_maps.append(combo)
    #update `current' direct product
    combo_maps=new_combo_maps


#save final direct product of mappings
#The entries have different lengths ([i] versus [i, j]); dtype=object is needed
#because recent NumPy refuses to build a ragged array implicitly.
#Load the file again with np.load(..., allow_pickle=True).
combo_maps_array = np.array(combo_maps,dtype=object)
np.save('combined_maps.npy',combo_maps_array)

###############################
