"""

5Th July 2022

Jack Doyle

This script takes a pkl file which holds all molecules whose centroid is contained in the primitive cell of a given
crystal structure. We expand the cell to contain the N molecules closest to the centre of the fragment; if there are fewer than
N molecules in the fragment we expand into a bigger fragment. From there we obtain a set of properties for each of the N molecules
and write these to a file. Typically we will pull the molecular centroid and some vector which quantifies the direction of
the molecule, e.g. a principal axis of inertia or the dipole moment.

Here we alter the script so that a molecule can be instantiated with built-in properties

29 Jul 2022 - changed the inertia vector method to loop through the data correctly, also dealt with the case of flipping
the inertia vector when there is an equal mass on each side of the decision boundary

"""

import argparse
import ast

import numpy as np
import pickle
import scipy.spatial
from ast import literal_eval




def _parse_property_specs(specs):
    """Split ``property[:argument:value...]`` strings into names and kwargs.

    Returns two parallel lists: the property names, and for each name either
    ``None`` (no keywords given) or a dict of keyword arguments. Keyword
    values are kept as the raw strings typed on the command line. A trailing
    keyword without a value is stored with the value ``None``.
    """
    names = []
    options = []
    for spec in specs:
        # separate any option keywords from the attribute keyword
        name, *rest = spec.split(':')
        names.append(name)
        if not rest:
            options.append(None)
            continue
        unpaired = len(rest) % 2 == 1
        if unpaired:
            print('Warning - Unpaired Argument')
        # zip stops at the shorter slice, so a dangling keyword is skipped here
        kwargs = dict(zip(rest[0::2], rest[1::2]))
        if unpaired:
            # the unpaired argument is given the label None
            kwargs[rest[-1]] = None
        options.append(kwargs)
    return names, options


def main():
    """Command-line entry point.

    Reads the pickled primitive cell named by ``-f``, optionally strips atom
    types not listed in ``-include_only``, builds the crystal fragment, and
    saves the requested properties of the ``-n`` closest molecules.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-f', required=True, help='Name of primitve cell pkl file to be read')
    parser.add_argument('-include_only', default='all', nargs="*",
                        help="list of atom types to include when building "
                             "crystal structure for example if we use "
                             "C N we only use carbon and nitrogen atoms "
                             "of the underlying molecules so exclude other "
                             "atoms in the calculation of cnentroids and vectors.")
    parser.add_argument('-n', default=50, type=int, help="Number of molecules to return the properties of ")
    parser.add_argument('-R_list', nargs='*', default=[2, 2, 2], type=int,
                        help='Supercell ((-R1, -R2, -R3) to (R1, R2, R3))'
                             ' size to be used in molecule searching')
    parser.add_argument('-p', default=['centroid'], nargs='*',
                        help='Mol propeties to calculate; must be properties '
                             'of Molecule class. If function kwargs are given '
                             'use syntax property:argument:value '
                             'e.g. centroid:weighted:True')

    args = parser.parse_args()
    # reject a malformed R-list up front instead of failing later on an
    # undefined molecule list
    if len(args.R_list) not in (1, 3):
        parser.error(f"Unexpected R-list has length {len(args.R_list)}")

    filename = args.f
    print(filename)
    property_list, option_list = _parse_property_specs(args.p)

    # NOTE(security): pickle.load can execute arbitrary code; only use this
    # script on pkl files from a trusted source.
    with open(filename, 'rb') as f:
        input_cell = pickle.load(f)

    # the default is the bare string 'all'; an explicit "-include_only all"
    # arrives as the list ['all'] and must mean the same thing
    if args.include_only in ('all', ['all']):
        crys = Crystal(input_cell)
    else:
        new_cell = RemoveAtoms(input_cell, args.include_only)
        crys = Crystal(new_cell)

    if len(args.R_list) == 1:
        R = args.R_list[0]
        mol_list = target_mols(crys, args.n, R)
    else:
        mol_list = target_mols(crys, args.n, R_list=args.R_list)
    output_properties = attr_from_mols(mol_list, attr=property_list, options=option_list)
    save_file(filename, output_properties)


# define atom, molecule and crystal objects

class Atom:
    """A single atom: element symbol, xyz coordinates and approximate mass.

    ``coords`` is any length-3 sequence. ``all_coords`` holds it as a numpy
    array, while ``x``/``y``/``z`` hold the individual components.

    Raises:
        KeyError: if ``symbol`` has no entry in the mass table.
    """

    # Approximate atomic masses in atomic mass units, shared by all instances.
    # F and N previously held their atomic numbers (9 and 7) rather than
    # their masses, which skewed every mass-weighted quantity.
    _mass_dict = {'C': 12, 'F': 19, 'N': 14, 'H': 1, 'O': 16, 'Cl': 35.5, 'S': 32, 'Br': 80, 'I': 127}

    def __init__(self, symbol, coords):
        self.symbol = symbol
        self.all_coords = np.array(coords)
        self.x = coords[0]
        self.y = coords[1]
        self.z = coords[2]
        try:
            self.mass = self._mass_dict[symbol]
        except KeyError:
            raise KeyError(f"No mass defined for atom symbol {symbol!r}") from None

class Molecule:
    """A molecule composed of atoms, with geometric and inertial descriptors.

    Instantiated from a list of ``(symbol, coords)`` tuples. Any extra keyword
    arguments are kept in ``_raw_props`` and also set as instance attributes.

    Method flags may be passed as strings (the command-line option syntax
    delivers every value as a string); the spellings in ``_FALSE_STRINGS`` are
    treated as False, any other string as True.
    """

    # optional direction-vector properties handled by norm_and_flip
    _VECTOR_PROPS = ('x_v', 'y_v', 'z_v')
    # string spellings of a flag that mean False
    _FALSE_STRINGS = ('false', '0', 'no', 'none', '')

    def __init__(self, atom_tups, **props):
        # unpack symbol and coords for each atom in provided atom tuples
        self.atoms = [Atom(item[0], item[1]) for item in atom_tups]
        self._raw_props = props
        for key, value in props.items():
            setattr(self, key, value)
        self.c_atoms = [a for a in self.atoms if a.symbol == 'C']
        self._atom_tups = atom_tups
        # Coordinates relative to the centre of mass. The centre is computed
        # once, and only when there are atoms, so that an empty selection does
        # not trigger a mean over nothing.
        if self.atoms:
            centre = self.centroid(weighted=True)
            self.internal_coods = [(atom.symbol, atom.all_coords - centre) for atom in self.atoms]
        else:
            self.internal_coods = []
        self.all_coords = [(atom.symbol, atom.all_coords) for atom in self.atoms]
        if self.c_atoms:
            c_centre = self.centroid_c_only()
            self.internal_coods_c_only = [(atom.symbol, atom.all_coords - c_centre) for atom in self.c_atoms]
        else:
            self.internal_coods_c_only = []
        self.all_coords_c_only = [(atom.symbol, atom.all_coords) for atom in self.c_atoms]
        # combined mass_dict
        self._mass_dict = {}
        for atom in self.atoms:
            self._mass_dict.update(atom._mass_dict)

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _as_bool(flag):
        """Interpret a flag that may be a real bool or a string such as 'False'."""
        if isinstance(flag, str):
            return flag.strip().lower() not in Molecule._FALSE_STRINGS
        return bool(flag)

    @staticmethod
    def _parse_axes(axes_list):
        """Return the requested axis indices as a list (accepts str, int or iterable)."""
        if isinstance(axes_list, str):
            axes_list = literal_eval(axes_list)
        if isinstance(axes_list, (int, np.integer)):
            return [int(axes_list)]
        return list(axes_list)

    @staticmethod
    def _unit(vec):
        """Return ``vec`` scaled to unit length; a zero vector is returned unscaled.

        Leaving a zero vector alone avoids the NaNs a 0/0 division produces for
        an atom that sits exactly at the reference point.
        """
        vec = np.asarray(vec, dtype=float)
        length = np.linalg.norm(vec)
        return vec/length if length > 0 else vec

    def _principal_axes(self, c_only=False):
        """Eigen-decomposition of the inertia tensor taken about the centroid.

        Returns ``(vals, vecs)`` sorted from highest to lowest eigenvalue, with
        ``vecs[k]`` the eigenvector belonging to ``vals[k]``.
        """
        if self._as_bool(c_only):
            tensor = self.inertia_tensor(self.centroid_c_only(), c_only=True)
        else:
            tensor = self.inertia_tensor(self.centroid(weighted=True))
        # The tensor is real and symmetric, so eigh applies. Eigenvectors are
        # the COLUMNS of the returned matrix, so transpose to get one per row.
        vals, vecs = np.linalg.eigh(tensor)
        order = np.argsort(vals)[::-1]
        return vals[order], vecs[:, order].T

    def _mass_projections(self, v, c_only=False):
        """Return (coords, masses, projections of unit position vectors onto v)."""
        coords = self.internal_coods_c_only if self._as_bool(c_only) else self.internal_coods
        masses = np.array([self._mass_dict[item[0]] for item in coords], dtype=float)
        projections = np.array([np.dot(self._unit(item[1]), v) for item in coords], dtype=float)
        return coords, masses, projections

    # --------------------------------------------------------------- properties

    def centroid(self, weighted = False):
        """Return the mass-weighted (``weighted=True``) or plain mean atomic position."""
        if self._as_bool(weighted):
            return np.sum(np.array([item.mass*item.all_coords for item in self.atoms]), axis = 0)/np.sum(np.array([item.mass for item in self.atoms]))
        return np.mean(np.array([item.all_coords for item in self.atoms]), axis = 0)

    def centroid_c_only(self):
        """Return the plain mean position of the carbon atoms."""
        return np.mean(np.array([item.all_coords for item in self.c_atoms]), axis = 0)

    def inertia_tensor(self, pos, c_only = False):
        """Return the 3x3 inertia tensor about the point ``pos``.

        If ``c_only`` is true only carbon atoms contribute.
        """
        atoms_to_use = self.c_atoms if self._as_bool(c_only) else self.atoms
        pos = np.asarray(pos)
        I = np.zeros(shape = (3,3))
        identity = np.eye(3)
        for item in atoms_to_use:
            # vector from the atom to the point where the tensor is calculated
            R = pos - item.all_coords
            I += item.mass*(np.linalg.norm(R)**2*identity - np.outer(R, R))
        return I

    def inertia_axes(self, axes_list=(0,), flip = True, c_only = False, flip_with_c_only = False):
        """Return eigenvectors of the inertia tensor centred at the centroid.

        ``axes_list`` denotes which eigenvectors to return (highest eigenvalue to
        lowest eigenvalue), e.g. [0,1] means return the first and second inertia
        axes. By default the eigenvector with the biggest eigenvalue is returned.
        The result is a list of vectors, one per requested index.

        If ``flip`` is true, every eigenvector is flipped so that it faces in the
        direction of increasing atomic mass: a plane is placed perpendicular to
        the eigenvector at the centre of mass, and if the total mass on the side
        the vector points to is less than the mass on the other side the sign of
        the vector is changed (else it is left unchanged).

        If ``c_only`` is true the inertia tensor is found using carbon atoms only;
        this gives similar eigenvectors for molecules with different substituents
        but the same skeleton. The flipping algorithm still uses all atoms, as it
        relies on the asymmetric mass distribution that the substituents give.
        This can be switched off with ``flip_with_c_only = True``.
        """
        axes = self._parse_axes(axes_list)
        _, vecs = self._principal_axes(c_only)
        if self._as_bool(flip):
            vecs = [self._flip_vector(vec, c_only=flip_with_c_only) for vec in vecs]
        return [vecs[a] for a in axes]

    def best_inertia_vector(self, flip = True, c_only = True, flip_with_c_only = False, N_vecs_to_return = 1):
        """Return the inertia eigenvector(s) best aligned with the body of the molecule.

        All three eigenvectors are computed (flipping and carbon-only tensor are
        optional). For each one we sum the magnitudes of its dot products with
        every normalised atomic position vector (all atoms, relative to the centre
        of mass). The vector with the biggest sum, i.e. the one least orthogonal to
        the body of the molecule, is returned. With ``N_vecs_to_return`` > 1 the
        next best vectors follow; the result is a flat array.
        """
        N_vecs_to_return = int(N_vecs_to_return)
        axes = self.inertia_axes(axes_list = [0,1,2], flip = flip, c_only = c_only, flip_with_c_only = flip_with_c_only)
        unit_positions = [self._unit(item[1]) for item in self.internal_coods]
        sums = [np.sum([abs(np.dot(c, axis)) for c in unit_positions]) for axis in axes]
        # need to negate sums as argsort sorts smallest to biggest
        ixs = np.argsort([-s for s in sums])[:N_vecs_to_return]
        return np.array([axes[ix] for ix in ixs]).flatten()

    def inertia_eignv(self, axes_list=(0,), c_only = False):
        """Return eigenvalues of the inertia tensor centred at the centroid.

        ``axes_list`` denotes which eigenvalues to return (highest to lowest),
        e.g. [0,1] means return the first and second. The result is an array.
        """
        vals, _ = self._principal_axes(c_only)
        return vals[self._parse_axes(axes_list)]

    def norm_and_flip(self, norm=True, flip=True):
        """Normalise and/or flip whichever of ``x_v``, ``y_v``, ``z_v`` exist.

        The updated vectors are written back both to the attributes and to
        ``_raw_props``; vectors that are absent are skipped and any other stored
        properties are left untouched.
        """
        norm = self._as_bool(norm)
        flip = self._as_bool(flip)
        for name in self._VECTOR_PROPS:
            if not hasattr(self, name):
                continue
            vec = getattr(self, name)
            if norm:
                vec = np.array(vec)/np.linalg.norm(np.array(vec))
            if flip:
                vec = self._flip_vector(np.array(vec))
            setattr(self, name, vec)
            self._raw_props[name] = vec

    def size(self,use_coords = 'xyz', c_only = False):
        """Find the "radius" of the molecule.

        By default R = sqrt(delta_x**2 + delta_y**2 + delta_z**2), where each
        delta is the spread (max - min) of the atoms along that axis. By altering
        ``use_coords`` we can find just delta_x or sqrt(delta_x**2 + delta_z**2)
        for example. If ``c_only`` is true only carbon atoms are used, giving the
        length scale of the carbon skeleton.
        """
        axis_index = {'x': 0, 'y': 1, 'z': 2}
        use_coords = [item for item in use_coords]
        # choose the right atoms; if c_only we are only interested in carbon atoms
        atoms_to_use = self.c_atoms if self._as_bool(c_only) else self.atoms
        coords_to_use = np.array([item.all_coords for item in atoms_to_use])
        # ensure that use_coords is valid
        assert set(use_coords).issubset({'x', 'y', 'z'}), f"{use_coords} is not a subset of {'x', 'y', 'z'}"
        all_spacing = []
        # use set to remove duplicates e.g. ['x', 'x', 'x']; sorted for a fixed summation order
        for c in sorted(set(use_coords)):
            values = [item[axis_index[c]] for item in coords_to_use]
            all_spacing.append((max(values) - min(values))**2)
        return np.sqrt(sum(all_spacing))

    def _mass_moments(self, coords):
        """Return the mass moments (m2 - m1)(r2 - r1) for every ordered pair.

        ``coords`` has the form [(symbol, (x,y,z))]; r1 and r2 are the unit
        position vectors of the two atoms.
        """
        mass_moms = []
        for p1 in coords:
            for p2 in coords:
                delta_m = self._mass_dict[p2[0]] - self._mass_dict[p1[0]]
                d_ij = self._unit(p2[1]) - self._unit(p1[1])
                mass_moms.append(delta_m*d_ij)
        return mass_moms

    def _flip_vector(self, v, c_only = False):
        """Return ``v`` or ``-v``, whichever points towards the heavier side.

        The dot product of each atomic position vector (starting at the centre)
        with ``v`` tells which side of the perpendicular plane the atom is on.
        Atoms on the plane, including an atom exactly at the centre, count as
        being on the side ``v`` points to.
        """
        v = np.asarray(v)
        coords_to_use, masses, projections = self._mass_projections(v, c_only)
        total_mass = masses.sum()
        if total_mass == 0:
            # nothing to decide with (no atoms selected)
            return v
        frac_mass_with_vector = masses[projections >= 0].sum()/total_mass
        if frac_mass_with_vector < 0.5:
            return -v
        if frac_mass_with_vector > 0.5:
            return v
        # if there is equal mass on each side of the vector decide which way to
        # point from the sum of mass moments on each side
        coords_with = [coords_to_use[ix] for ix in np.where(projections >= 0)[0]]
        coords_against = [coords_to_use[ix] for ix in np.where(projections < 0)[0]]
        decision_with = sum([abs(np.dot(m, v)) for m in self._mass_moments(coords_with)])
        decision_against = sum([abs(np.dot(m, v)) for m in self._mass_moments(coords_against)])
        if decision_against > decision_with:
            return -v
        return v

    def mass_frac(self, v):
        """Return the fraction of the total mass on the side of the plane ``v`` points to."""
        _, masses, projections = self._mass_projections(v)
        return masses[projections >= 0].sum()/masses.sum()

    def CN_vec(self, norm = True):
        """Return the vector from the central C atom to the N atom.

        This method is designed for benzylated fluoroanilines only. There must be
        exactly one N atom. The C atom is found from the identities of its three
        nearest neighbours, as found with a KD tree; the expected neighbours are
        {'C', 'H', 'N'}. An assertion fails if there is not exactly one N atom or
        not exactly one carbon with that connectivity. Normalisation is optional.
        """
        # first find N atom
        n_list = [atom for atom in self.atoms if atom.symbol == 'N' ]
        #ensure there is only one N atom
        assert len(n_list) == 1, "There should only be one Nitrogen atom for molecules used in this method"
        target_n = n_list[0]
        # build kd tree - we assume in this case that the bonded atoms are the same as the nearest neighbours in Euclidean space
        tree = scipy.spatial.KDTree(np.array([item[1] for item in self.all_coords]).reshape(-1,3))
        # make list of all atoms with 3 nearest neighbours C, H, N - all C atoms should be sp2 here
        c_list = []
        N = []
        for atom in self.c_atoms:
            # query the tree; need 4 nearest neighbours as first will be atom itself
            pos_ix = tree.query(np.array(atom.all_coords), k = 4)[1]
            neighbours = [self.all_coords[ix][0] for ix in pos_ix[1:]]
            N.append(neighbours)
            if {'C', 'H', 'N'}.issubset(neighbours):
                c_list.append(atom)
        # check we only have one atom with this connectivity
        assert len(c_list) == 1, f'There should only be one carbon atom with connectivity CHN for molecules in this method.There are currently{len(c_list)} such atoms.{N} '
        target_c = c_list[0]
        cn_vector = np.array(target_n.all_coords) - np.array(target_c.all_coords)
        if self._as_bool(norm):
            cn_vector = cn_vector/np.linalg.norm(cn_vector)
        return cn_vector

    def _5A_vecs(self):
        """Return the component-wise absolute N-N vector along the long axis of molecule 5A.

        This method is designed for one specific molecule as a test case to
        compare with the inertia eigenvectors. The molecule is labelled 5A and the
        structure can be seen in
        https://eprints.soton.ac.uk/412025/1/structure_prediction_azapentacene.pdf.
        The first pair of nitrogen atoms whose separation lies between 3.9 and 5
        is used; the string 'nan' is returned if no such pair exists.
        """
        # find all nitrogen atoms in the molecule
        N_atoms = [atom.all_coords for atom in self.atoms if atom.symbol == 'N']
        for xyz1 in N_atoms:
            for xyz2 in N_atoms:
                # we know approximately the distance between the pairs of nitrogen atoms we are interested in
                separation = scipy.spatial.distance.euclidean(xyz1, xyz2)
                if 3.9 < separation < 5:
                    # we only need one example, as the (3) vectors will all be approximately the same
                    return abs(xyz1 - xyz2)
        return 'nan'

    def _norm_vec(self):
        """Return a unit vector normal to the plane of a planar molecule.

        We assume all C atoms lie in the plane of the molecule. The normal is the
        cross product of the vectors from the first C atom to the second and to
        the third.
        """
        c_atoms = np.array([atom.all_coords for atom in self.c_atoms]).reshape(-1,3)
        diff_1 = c_atoms[1] - c_atoms[0]
        diff_2 = c_atoms[2] - c_atoms[0]
        n = np.cross(diff_1, diff_2)
        return n/np.linalg.norm(n)


class Crystal:
    """Generate molecules of the infinite crystal from those of the primitive cell.

    ``input_params`` is a list whose first two items hold the cell lengths and
    the cell angles in degrees (the values are read from index 1 of each item).
    The remaining items are the molecules: either a list of ``(symbol, coords)``
    tuples, or a dict with an ``'atoms'`` entry of that form plus any additional
    molecular properties.
    """

    # direction-vector properties that Molecule.norm_and_flip acts on
    _VECTOR_PROPS = ('x_v', 'y_v', 'z_v')

    def __init__(self, input_params):
        self.a, self.b, self.c = input_params[0][1]
        self.alpha, self.beta, self.gamma = [np.pi/180*item for item in input_params[1][1]]
        # turn molecule entries into molecule objects
        self.primitive_mols = []
        for entry in input_params[2:]:
            if isinstance(entry, dict):
                # work on a copy so the caller's dict keeps its 'atoms' entry
                props = dict(entry)
                atoms = props.pop('atoms')
                molecule = Molecule(atoms, **props)
                # Normalise/flip only the molecule just built, and only when it
                # actually carries direction vectors.
                if any(name in props for name in self._VECTOR_PROPS):
                    molecule.norm_and_flip(norm=True, flip=True)
                self.primitive_mols.append(molecule)
            else:
                self.primitive_mols.append(Molecule(entry))
        cos_a, cos_b, cos_g = np.cos(self.alpha), np.cos(self.beta), np.cos(self.gamma)
        sin_g = np.sin(self.gamma)
        self.cell_volume = self.a*self.b*self.c*np.sqrt(1 - cos_a**2 - cos_b**2 - cos_g**2
                                                        + 2*cos_a*cos_b*cos_g)
        # define cell vectors: v1 along x, v2 in the xy plane
        self.v1 = self.a*np.array([1, 0, 0])
        self.v2 = self.b*np.array([cos_g, sin_g, 0])
        self.v3 = self.c*np.array([cos_b,
                                   (cos_a - cos_b*cos_g)/sin_g,
                                   self.cell_volume/(self.a*self.b*self.c*sin_g)])

    def AverageSize(self, how = 'radius', c_only = False):
        """Return the mean "size" of the primitive-cell molecules.

        See the ``size`` method of the molecule class; ``c_only`` is passed on
        to it. ``how`` selects what is averaged:
        1) 'radius' : sqrt(d_x**2 + d_y**2 + d_z**2),
        2) 'max' : max(d_x, d_y, d_z),
        3) 'min' : min(d_x, d_y, d_z).
        For any other value a message is printed and None is returned.
        """
        if how == 'radius':
            # radius is the default option in the size method
            mol_sizes = [mol.size(c_only = c_only) for mol in self.primitive_mols]
        elif how in ('max', 'min'):
            pick = max if how == 'max' else min
            mol_sizes = [pick(mol.size(use_coords = [axis], c_only = c_only) for axis in 'xyz')
                         for mol in self.primitive_mols]
        else:
            print(f"Unknown Mehtod {how}")
            return None
        return np.mean(mol_sizes)

    def BuildNewCell(self, l, m, n):
        """Return the primitive-cell molecules translated by l*v1 + m*v2 + n*v3."""
        trans_vec = l*self.v1 + m*self.v2 + n*self.v3
        new_mols = []
        for mol in self.primitive_mols:
            # translate the atom positions but keep the symbol
            new_atoms = [(atom[0], atom[1] + trans_vec) for atom in mol._atom_tups]
            # the stored properties are passed on unchanged
            new_mols.append(Molecule(new_atoms, **mol._raw_props))
        return new_mols

    def BuildSupercell(self, L, M, N):
        """Return the primitive cell translated by every lattice vector in range.

        The multipliers run over the half-open ranges ``range(-L + 1, L)``,
        ``range(-M + 1, M)`` and ``range(-N + 1, N)``, i.e. from
        (-L+1, -M+1, -N+1) to (L-1, M-1, N-1) inclusive, so a value of 1 gives
        only the untranslated cell along that axis.
        """
        new_mols = []
        for i in range(-L + 1, L):
            for j in range(-M + 1, M):
                for k in range(-N + 1, N):
                    new_mols.extend(self.BuildNewCell(i, j, k))
        return new_mols


def target_mols(crys, N, R = 2, R_list = None):
    """Return the N molecules of a crystal fragment closest to the origin.

    The fragment is built with ``crys.BuildSupercell(R, R, R)``. If ``R_list``
    is given (three values) it overrides ``R`` and the fragment is built with
    ``crys.BuildSupercell(R_list[0], R_list[1], R_list[2])``.

    If there are fewer than N molecules in the fragment, a bigger fragment is
    generated until there are at least N. With ``R_list`` every extent is
    incremented except those equal to 1 (so 2x2x1 goes to 3x3x1); this matters
    for cases where the periodicity along one axis is restricted.

    If the fragment cannot grow (every extent is fixed at 1, or the crystal has
    no molecules) a warning is printed and the molecules found are returned,
    which may be fewer than N. Duplicate removal can also leave fewer than N.
    """
    restricted = bool(R_list)
    extents = list(R_list) if restricted else [R, R, R]
    frag = crys.BuildSupercell(extents[0], extents[1], extents[2])
    if len(frag) < N:
        print("crystal fragment too small generating new fragment...")
    while len(frag) < N:
        # increment fragment sizes; with an explicit R_list a size of 1 stays at 1
        grown = [r if (restricted and r == 1) else r + 1 for r in extents]
        if grown == extents:
            print(f"Warning - fragment cannot grow beyond {extents}; returning {len(frag)} molecules")
            break
        bigger = crys.BuildSupercell(grown[0], grown[1], grown[2])
        if not bigger and min(grown) >= 1:
            # a supercell that includes the untranslated cell is still empty,
            # so the crystal has no molecules and growing further is pointless
            print("Warning - crystal contains no molecules")
            break
        extents, frag = grown, bigger
    # find centroid distances from origin
    origin = np.array([0, 0, 0])
    dist_list = [scipy.spatial.distance.euclidean(mol.centroid(weighted=True), origin) for mol in frag]
    # sort fragment mols by distance
    sorted_mols = [mol for _, mol in sorted(zip(dist_list, frag), key=lambda pair: pair[0])]
    # clean mols to remove duplicates, only do this for 2*N mols as this takes time, if less than N remain try again
    # with all mols
    if 2*N < len(sorted_mols):
        mols_out = clean_mols(sorted_mols[:2*N])
        if len(mols_out) < N:
            mols_out = clean_mols(sorted_mols)
    else:
        mols_out = clean_mols(sorted_mols)
    return mols_out[:N]

def clean_mols(mol_list, min_dist = 1.0):
    """Remove duplicate molecules from ``mol_list``.

    In some cases the translation of the molecules of the primitive cell can
    result in duplicate molecules. A molecule is dropped when any molecule
    later in the list has its mass-weighted centroid closer than ``min_dist``,
    so the last member of each group of duplicates is the one kept. The order
    of the remaining molecules is preserved.
    """
    # compute each centroid once rather than once per pair
    centroids = np.array([mol.centroid(weighted=True) for mol in mol_list], dtype=float).reshape(-1, 3)
    new_mol_list = []
    for i, mol in enumerate(mol_list):
        # only need to look at every possible pair once, so compare with later mols
        later = centroids[i + 1:]
        if later.size and (np.linalg.norm(later - centroids[i], axis=1) < min_dist).any():
            continue
        new_mol_list.append(mol)
    return new_mol_list

def attr_from_mols(target_mols, attr = ('centroid',), options = (None,)):
    """Collect the requested attributes of every molecule in ``target_mols``.

    ``attr`` is a sequence of attribute names. ``options`` is a sequence of the
    same length holding, for each attribute, either None or a dict of keyword
    arguments ({arg1a : val1a, arg1b : val1b}, {arg2 : val2}). If the lengths
    differ a message is printed and every method is called with its defaults.

    Returns a list with one dict per molecule:
    - a method called without options is stored under its name;
    - a method called with options is stored under ``name + str(options)``;
    - a non-callable attribute is stored under its name as-is;
    - a missing attribute is reported and stored as the string 'NaN'.
    """
    if len(attr) != len(options):
        print("Mol functions do not match options, ensure attr and options tuples have the same length. Passing "
              "default arguments")
        options = [None for _ in attr]
    props = []
    for mol in target_mols:
        line = {}
        for name, kwargs in zip(attr, options):
            if not hasattr(mol, name):
                print(f"Molecule has no attribute {name}")
                line[name] = 'NaN'
                continue
            value = getattr(mol, name)
            if not callable(value):
                # plain attribute rather than a method, so it has no options
                line[name] = value
            elif kwargs is None:
                line[name] = value()
            else:
                line[name + str(kwargs)] = value(**kwargs)
        props.append(line)
    return props

def save_file(name_in, output):
    """Save the list of mol properties to a pkl file.

    The output name is ``name_in`` with its extension (whatever its length, or
    none at all) replaced by ``_all_properties.pkl``.
    """
    import os

    # splitext instead of slicing off four characters, so names that do not
    # end in a three-letter extension are not truncated
    stem, _ = os.path.splitext(name_in)
    name_out = stem + '_all_properties.pkl'
    with open(name_out, 'wb') as f:
        pickle.dump(output, f)

def RemoveAtoms(prim_cell, atoms_to_keep):
    """Return a copy of the primitive cell keeping only the listed atom types.

    The first two items of ``prim_cell`` (cell lengths and angles) are kept
    as they are. Every following item is a molecule, given either as a list of
    ``(symbol, coords)`` pairs or as a dict whose ``'atoms'`` entry is such a
    list. Only pairs whose symbol is in ``atoms_to_keep`` are retained; for
    dict molecules the other entries are carried over unchanged. The input
    cell and its molecules are not modified.
    """
    # keep the lengths and angles i.e. first two items of prim cell
    new_cell = list(prim_cell[:2])
    for mol in prim_cell[2:]:
        if isinstance(mol, dict):
            kept = [at for at in mol['atoms'] if at[0] in atoms_to_keep]
            # new dict so the caller's molecule keeps its full atom list
            new_cell.append({**mol, 'atoms': kept})
        else:
            new_cell.append([at for at in mol if at[0] in atoms_to_keep])
    return new_cell



# Run the command-line workflow only when this file is executed as a script,
# not when it is imported.
if __name__ == '__main__':
    main()