#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Feb 3rd 2022

This script takes a list of N atomic environments (xyz files) and uses a kernel density estimate (kde) of the
atomic positions (optionally weighted by van der Waals radius) to generate a density over a 3d grid. For each kde
density the resulting cubical complex is used to compute the persistence diagram using the gudhi library. Each of
the N resulting persistence diagrams is saved to its own .npy file.

@author: Jack Doyle
"""

import numpy as np
import os
from sklearn.neighbors import KernelDensity
import gudhi as gd
import argparse


#command line interface: every run parameter is supplied as a flag
parser = argparse.ArgumentParser()

#required locations: where the xyz files are read from and where results go
parser.add_argument(
    '-i',
    required=True,
    help='Input directory, all xyz files in directory are read into program',
)
parser.add_argument(
    '-o',
    required=True,
    help='Output directory to save npy files to',
)

#kernel density estimate settings
parser.add_argument(
    '-N',
    default=100,
    type=int,
    help='Controls the resolution of kde used for persistent homology calculation.\
                                                The kde is calculated over an NxNxN grid',
)
parser.add_argument(
    '-b',
    default=1.0,
    type=float,
    help='The bandwidth of the kernel used in the density estimation',
)
parser.add_argument(
    '-k',
    default='gaussian',
    choices=['gaussian', 'tophat', 'epanechnikov', 'exponential', 'linear', 'cosine'],
    help="kernel used in density estimation options are:\
                    ‘gaussian’, ‘tophat’, ‘epanechnikov’,‘exponential’, ‘linear’, ‘cosine'",
)

#boolean switches, both off unless the flag is given
parser.add_argument(
    '-w',
    action='store_true',
    help='If true use van der waals radius to weight \
                    atom positions when constructing kde',
)
parser.add_argument(
    '-p',
    action='store_true',
    help='If true apply periodic boundary conditions\
                                                to persistent homology calculation',
)

def main():
    """
    Command line entry point. Reads every structure from the input directory, builds a kde density on an
    NxNxN grid for each one, converts it to a filtered cubical complex, computes its persistence diagram and
    saves one .npy file per structure into the output directory.

    The output directory is created if it does not exist yet.

    Raises:
        KeyError: if -w is given and a structure contains an element with no entry in the weight table
    """
    args = parser.parse_args()
    #input dir
    path_to_xyz_dir = args.i
    #output dir
    out_dir = args.o
    #dimension of kde grid
    resolution = args.N
    #kde bandwidth
    bndwidth = args.b
    #kde kernel
    kernel = args.k
    #use weighted points, bool
    weighted = args.w
    #apply periodic boundary conditions, bool
    pbc = args.p
    all_symbols, all_strucs = read_xyz(path_to_xyz_dir)
    #maps position in the lists above to the file name prefix it came from
    file_dict = get_struc_keys(path_to_xyz_dir)
    if weighted:
        #convert symbols to weights using van der waals radii, expressed relative to carbon
        weight_dict = {'C' : 1, 'H' : 110/170, 'N' : 155/170, 'F' : 147/170}
        #check up front so an unsupported element gives one clear message instead of a bare KeyError
        #raised from the middle of the conversion below
        found = {symbol for symbol_list in all_symbols for symbol in symbol_list}
        unknown = sorted(found - set(weight_dict))
        if unknown:
            raise KeyError(f"no van der waals weight defined for element(s): {', '.join(unknown)}")
        all_weights = [[weight_dict[item] for item in symbol_list] for symbol_list in all_symbols]
    else:
        all_weights = [None for _ in all_symbols]
    #file labels for weighted and/or periodic outputs, the same for every structure so set once
    weight_label = 'w' if weighted else 'u'
    periodic_label = 'p' if pbc else 'n'
    #np.save does not create missing directories
    os.makedirs(out_dir, exist_ok=True)
    #find pd for each structure
    for i, (struc, weights) in enumerate(zip(all_strucs, all_weights)):
        #evaluate kde on a resolution x resolution x resolution grid
        kk = fit_kde_grid(struc, resolution,kernel=kernel, delta=bndwidth,weights = weights)
        #one filtration value per voxel, so each axis is one shorter than the grid
        kk3 = sublevel_complex(kk)
        #find pd
        diag = cubical_ph(kk3, periodic=pbc)
        #get filename
        name = file_dict[i]
        #np.save appends the .npy extension itself
        np.save(os.path.join(out_dir, f"emd_kde_{resolution}_{kernel}_{bndwidth}_{weight_label}_{periodic_label}_{name}"), diag, allow_pickle=True)


def _list_xyz_files(path):
    """
    List the names of the xyz files directly inside a directory, in sorted order.

    read_xyz and get_struc_keys both go through this helper so that the i-th structure and the i-th
    file name always refer to the same file. os.listdir gives no ordering guarantee, hence the sort.
    Sub-directories and files without a .xyz extension (case-insensitive) are ignored.

    Args:
        path: str: path to directory with xyz files
    Returns:
        names: list: sorted file names (not full paths)
    """
    return sorted(
        name for name in os.listdir(path)
        if name.lower().endswith('.xyz') and os.path.isfile(os.path.join(path, name))
    )

def read_xyz(path):
    """
    Reads all N xyz files from the directory given in path. For each xyz file we make a list of M chemical
    symbols and an Mx3 numpy array of atomic positions. Files are processed in sorted file name order.

    The first two lines of each file are treated as the header and skipped. Blank lines are ignored and only
    the first four columns (symbol, x, y, z) of each atom line are used, so trailing columns are tolerated.

    Args:
        path: str: path to directory with xyz files
    Returns:
        symbs_and_strucs: tuple: tuple containing list of symbol lists and list of Mx3 arrays of positions
    Raises:
        ValueError: if an atom line has fewer than four columns or a coordinate is not a number
    """
    all_strucs = []
    all_symbols = []
    for file in _list_xyz_files(path):
        with open(os.path.join(path, file), 'r') as f:
            #skip the first two (header) lines, the default stops a short file raising StopIteration
            for _ in range(2):
                next(f, None)
            struc = []
            symbols = []
            #atom lines start on line 3 of the file, the number is only used for error messages
            for line_no, line in enumerate(f, start=3):
                fields = line.split()
                if not fields:
                    #blank line, commonly found at the end of a file
                    continue
                if len(fields) < 4:
                    raise ValueError(f"{file}, line {line_no}: expected 'symbol x y z', got {line.strip()!r}")
                symbols.append(fields[0])
                struc.append([float(value) for value in fields[1:4]])
        all_symbols.append(symbols)
        #reshape keeps the array two dimensional even when no atoms were read
        all_strucs.append(np.array(struc, dtype=float).reshape(-1,3))
    symbs_and_strucs = all_symbols, all_strucs
    return symbs_and_strucs

def get_struc_keys(path):
    """
    Get dictionary relating read order to file name. Important so that we know which output persistence
    diagram corresponds to which initial structure. Uses the same file listing as read_xyz.

    Args:
        path:str:path to directory
    Returns:
        z:dict:dictionary mapping the index at which a file is read to its name without the extension
    """
    return {i: os.path.splitext(file)[0] for i, file in enumerate(_list_xyz_files(path))}

def fit_kde_grid(X, N, kernel = 'gaussian', delta = 0.75, weights = None):
    """
    Fit an sklearn kernel density estimate to a set of M points in 3d and evaluate it on a regular grid
    with N points per axis. The grid spans exactly the bounding box of the points: each axis runs from the
    smallest to the largest coordinate found in X.

    Note: np.meshgrid is called with its default indexing, for which NumPy documents that the first two
    output axes are swapped relative to the input order, so the first axis of the result follows y and the
    second follows x.

    Args:
        X: Mx3 numpy array: the sample 3d points
        N: int: grid resolution, we return kde over NxNxN grid
        kernel: str: kernel name passed to sklearn KernelDensity, e.g. 'gaussian' or 'tophat'
        delta: float: kde bandwidth
        weights: None or list of weights, length M: per point weights passed to the fit as sample_weight
    Returns:
        kk: NxNxN numpy array: kde evaluated on the 3d grid
    """
    estimator = KernelDensity(kernel = kernel, bandwidth = delta)
    estimator = estimator.fit(X, sample_weight = weights)
    #one evenly spaced axis per coordinate, from the minimum to the maximum of that coordinate
    axes = [np.linspace(X[:, d].min(), X[:, d].max(), N) for d in range(3)]
    mesh = np.meshgrid(*axes)
    #one row per grid point, columns are x, y, z
    points = np.vstack([component.ravel() for component in mesh]).T
    #score_samples returns the log of the density so we exponentiate
    density = np.exp(estimator.score_samples(points))
    #the values come back as a flat vector, restore the grid shape
    return density.reshape(N, N, N)

def sublevel_complex(fnx, levelsize=0.1, perseus=False):
    """
    Adapted from https://gitlab.com/delta-topology-public/deltapersistence/-/tree/master/deltapersistence

    Construct a sublevel filtered cubical complex from a function.

    Converts from a mesh of function values to a mesh of filtration
    times: every top-dimensional cell (the cube spanned by a mesh
    point and its neighbours at index +1 along each axis) is given the
    largest function value found on its corners. The result can be used
    as the top_dimensional_cells of a GUDHI cubical complex.

    The maximum is taken with whole-array operations over the 2**ndim
    shifted views of the mesh rather than with a Python loop over every
    cell, which matters for grids with around a million cells.

    Parameters
    ----------
    fnx : numpy ndarray
        Values of a function on a coordinate mesh. Every axis must have
        at least two points.
    levelsize : float, optional
        Resolution at which to filter. Only used when perseus is True.
        Default is 0.1.
    perseus : bool, optional.
        If true, replace every function value by floor(value/levelsize)
        before building the complex, so all filtration values are whole
        numbers. Default is False.

    Returns
    -------
    cmplx : numpy ndarray
        The filtration values of the cubical complex, stored as floats.
        Each axis is one shorter than the matching axis of fnx.

    Raises
    ------
    ValueError
        If any axis of fnx has fewer than two points.

    """

    fnx = np.asarray(fnx)
    if perseus:
        fnx = np.int_(np.floor(fnx/levelsize))
    if any(length < 2 for length in fnx.shape):
        raise ValueError(f"every axis needs at least 2 points to form a cell, got shape {fnx.shape}")
    cell_shape = tuple(length - 1 for length in fnx.shape)
    cmplx = None
    #each offset is a binary tuple selecting one corner of every cell at once
    for offset in np.ndindex(*([2] * fnx.ndim)):
        corner = fnx[tuple(slice(start, start + size) for start, size in zip(offset, cell_shape))]
        if cmplx is None:
            #first corner: take a float copy so the input is never modified
            cmplx = np.array(corner, dtype=float)
        else:
            np.maximum(cmplx, corner, out=cmplx)
    return cmplx

def _array_neighbors(n):
    """
    Not my function, taken from https://gitlab.com/delta-topology-public/deltapersistence/-/tree/master/deltapersistence

    Creates all binary tuples of length n.

    The sublevel complex needs to check the value on a cube against all
    of its neighbors, that is, all elements of the array at position
    i+1 in some component(s).

    Parameters
    ----------
    n : int
        Dimension of array.

    Returns
    -------
    nbrs : list
        All binary tuples of length n.

    Examples
    --------
    sublevelpersistence._array_neighbors(2)
    [[0, 0], [0, 1], [1, 0], [1, 1]]

    """

    dimn = str(n)
    nbrs = []
    for i in range(0,2**n):
        nbrs.append(list(format(i, '0'+dimn+'b')))
        nbrs[i] = [int(j) for j in nbrs[i]]
    return nbrs

def cubical_ph(sblv_sets, periodic = False):
    """
    Calculate sublevel set persistent homology using gudhi cubical complexes, optionally with periodic
    boundary conditions. The periodic branch marks exactly three axes as periodic, so it expects a
    3 dimensional input.

    Args:
        sblv_sets: NxNxN numpy array of filtration values, used as the top dimensional cells
        periodic:bool: whether to apply periodic boundary conditions
    Returns:
        diag: numpy object array built from the output of gudhi's persistence(): one row per point of the
        diagram, in the format (dim, (birth, death))
    """
    if not periodic:
        complex_ = gd.CubicalComplex(top_dimensional_cells = sblv_sets)
    else:
        #all three axes wrap around
        complex_ = gd.PeriodicCubicalComplex(top_dimensional_cells = sblv_sets, periodic_dimensions = [1,1,1])
    pairs = complex_.persistence()
    #convenient to have pd as numpy array, object dtype because each row mixes an int and a tuple
    return np.array(pairs, dtype=object)

#run the pipeline only when executed as a script, not when imported as a module
if __name__ == '__main__':
    main()