import math
from abaqus import *
from abaqusConstants import *
from interaction import *
import regionToolset
from SpindleAssembly.AddComponents import return_assembly


def create_RP(**kwargs):
    """
    Create a reference point in the assembly returned by ``return_assembly``

    The location is taken from kwargs, in this order of precedence:

    * if ``xpos``, ``ypos`` and ``zpos`` are all present, the point is created
      at the coordinates (xpos, ypos, zpos);
    * otherwise, if ``verts`` is present, the point is created at
      ``verts[1]`` (the entry at index 1, e.g. of a part's vertex array).

    :param kwargs: model parameters; forwarded unchanged to ``return_assembly``
        and searched for ``xpos``/``ypos``/``zpos`` or ``verts``

    :return: reference point object returned by ``ReferencePoint``

    :raises ValueError: if neither a full set of coordinates nor ``verts`` is
        supplied
    """
    a = return_assembly(**kwargs)

    # Explicit coordinates win over vertices when both are supplied
    if all(key in kwargs for key in ('xpos', 'ypos', 'zpos')):
        point = (kwargs['xpos'], kwargs['ypos'], kwargs['zpos'])
    elif 'verts' in kwargs:
        # NOTE(review): index 1 is hard-coded; which end of the part this
        # vertex corresponds to is not visible here -- confirm against geometry
        point = kwargs['verts'][1]
    else:
        raise ValueError('Either xpos, ypos, zpos or valid vertices should be '
                         'specified to create a reference point')
    return a.ReferencePoint(point=point)


def pick_region(verts_index, regionType, collectionName, position, **kwargs):
    """
    Pick a region to assign a connection to

    :param verts_index: only used for ``regionType == 'vertice'``. When it is
        0, the first vertex of every matching instance is picked; for any other
        value the whole vertex array of every matching instance is picked

    :param regionType: 'vertice', 'edge', 'centrosome' -> defines the type of the picked region

    :param collectionName: key in ``kwargs`` whose value is the collection of
        instance names to pick vertices from (only used for 'vertice')

    :param position: 'centrosome-right', 'centrosome-left' -> defines to which pole the region belongs
        (only used for 'centrosome'; it is also the name of the instance whose faces are picked)

    :param kwargs: model parameters, forwarded to ``return_assembly``

    :return: for 'vertice' and 'edge', a list with one vertex/edge entry per
        matching instance (in ``a.instances.keys()`` order); for 'centrosome',
        the assembly Surface created on the centrosome instance

    :raises ValueError: if ``regionType`` is unknown, or if ``position`` is not
        a known centrosome for ``regionType == 'centrosome'``
    """
    a = return_assembly(**kwargs)

    # Create region from specified vertices
    if regionType == 'vertice':
        collection = kwargs[collectionName]
        names = [vI for vI in a.instances.keys() if vI in collection]
        if verts_index == 0:
            return [a.instances[vI].vertices[verts_index] for vI in names]
        # NOTE(review): a non-zero verts_index returns the whole vertex array
        # of each instance, not the vertex at that index -- confirm intended
        return [a.instances[vI].vertices for vI in names]

    # Create region from specified edges
    if regionType == 'edge':
        # Currently available for ipMTs only: every instance whose name
        # contains 'ipMT' and is not a connector
        return [a.instances[eI].edges for eI in a.instances.keys()
                if 'ipMT' in eI and 'connector' not in eI]

    # Create centrosome region
    if regionType == 'centrosome':
        surface_names = {
            'centrosome-right': 'RightCentrosomeCouplingRegion',
            'centrosome-left': 'LeftCentrosomeCouplingRegion',
        }
        if position not in surface_names:
            # Previously this fell through and failed with UnboundLocalError
            raise ValueError("position must be 'centrosome-right' or "
                             "'centrosome-left' for a centrosome region")
        s1 = a.instances[position].faces
        side1Faces1 = s1.getSequenceFromMask(mask=('[#1 ]',), )
        return a.Surface(side1Faces=side1Faces1, name=surface_names[position])

    raise ValueError('Only vertice, edge and centrosome regions are currently '
                     'supported')


def sum_regions(verts_index, regionType, collectionName, separate='True', **kwargs):
    """
    Combine the sub-regions picked by ``pick_region`` into larger regions

    ``pick_region`` is always called with ``position='False'``, so only the
    'vertice' and 'edge' region types are meaningful here.

    :param verts_index: forwarded to ``pick_region``

    :param regionType: 'vertice', 'edge' -> defines the type of the picked region

    :param collectionName: name of the entity specified by the region

    :param separate: 'True' -> split the picked sub-regions by position:
        even indices (0, 2, 4, ...) form the right region, odd indices
        (1, 3, 5, ...) form the left region. Any other value combines all
        sub-regions into a single region

    :param kwargs: model parameters, forwarded to ``pick_region``

    :return: ``(combined_region_right, combined_region_left)`` when
        ``separate == 'True'``, otherwise ``combined_region``

    :raises IndexError: if there are too few sub-regions (fewer than two when
        separating, none otherwise)
    """
    # select a master region
    region = pick_region(verts_index, regionType, collectionName, 'False', **kwargs)

    # Accumulate with '+' (not '+=') so the first element of pick_region's
    # result is never modified in place, in case its type supports in-place add.
    if separate == 'True':
        # Each side is seeded with its first element, so the loops must start
        # at that side's *second* element to avoid adding the seed twice.
        combined_region_right = region[0]
        combined_region_left = region[1]
        for element in region[2::2]:
            combined_region_right = combined_region_right + element
        for element in region[3::2]:
            combined_region_left = combined_region_left + element
        return combined_region_right, combined_region_left

    # combine all sub-regions into one; region[0] is the seed
    combined_region = region[0]
    for element in region[1:]:
        combined_region = combined_region + element
    return combined_region


def coupling_constraint(region1, region2, influenceRadius, couplingType,
                        weightingMethod, name, **kwargs):
    """
    Add a coupling constraint named ``name`` to the model ``kwargs['modelname']``

    All six degrees of freedom (u1, u2, u3, ur1, ur2, ur3) are coupled.

    :param region1: control (master) region, passed as ``controlPoint``;
        in this file it is an assembly Set holding a reference point

    :type region1: object

    :param region2: coupled (slave) region, passed as ``surface``

    :type region2: object

    :param influenceRadius: radius around the control point within which the
        coupled region is affected (WHOLE_SURFACE is used elsewhere in this
        file to include the entire region)

    :type influenceRadius: float or symbolic constant

    :param couplingType: Abaqus coupling type constant, e.g. DISTRIBUTING

    :param weightingMethod: Abaqus weighting constant, e.g. UNIFORM

    :param name: name of the coupling constraint

    :type name: str

    :param kwargs: model parameters; must contain ``modelname``

    :type kwargs: dict

    :return: None
    """
    model = mdb.models[kwargs['modelname']]
    model.Coupling(
        name=name,
        controlPoint=region1,
        surface=region2,
        influenceRadius=influenceRadius,
        couplingType=couplingType,
        weightingMethod=weightingMethod,
        localCsys=None,
        u1=ON, u2=ON, u3=ON,
        ur1=ON, ur2=ON, ur3=ON,
    )


def attach_spring(region, dof, name, springType='Ground', **kwargs):
    """
    Add an undamped spring either from a region to ground or between point pairs

    :param region: for 'Ground', the region passed as ``region`` to
        ``SpringDashpotToGround``; for 'Pair', the sequence passed as
        ``regionPairs`` to ``TwoPointSpringDashpot``

    :type region: object

    :param dof: degree of freedom of a grounded spring (ignored for 'Pair')

    :type dof: int

    :param name: name of the spring feature

    :type name: str

    :param springType: 'Ground' or 'Pair'

    :type springType: str

    :param kwargs: model parameters; must contain ``modelname`` plus
        ``aMTsSpring`` (Ground) or ``groundSpring`` (Pair)

    :type kwargs: dict

    :return: None

    :raises ValueError: if ``springType`` is neither 'Ground' nor 'Pair'
    """
    modelname = kwargs['modelname']
    if springType not in ('Ground', 'Pair'):
        raise ValueError('Only Ground spring or Pair spring is available')

    # NOTE(review): grounded springs read 'aMTsSpring' while pair springs read
    # 'groundSpring'; the keys look swapped relative to the spring types --
    # confirm this is intentional before relying on it.
    if springType == 'Ground':
        stiffness = kwargs['aMTsSpring']
        features = mdb.models[modelname].rootAssembly.engineeringFeatures
        features.SpringDashpotToGround(
            name=name, region=region, orientation=None, dof=dof,
            springBehavior=ON, springStiffness=stiffness,
            dashpotBehavior=OFF, dashpotCoefficient=0.0)
        return

    # Remaining case: springType == 'Pair'
    stiffness = kwargs['groundSpring']
    features = mdb.models[modelname].rootAssembly.engineeringFeatures
    features.TwoPointSpringDashpot(
        name=name, regionPairs=region, axis=NODAL_LINE,
        springBehavior=ON, springStiffness=stiffness,
        dashpotBehavior=OFF, dashpotCoefficient=0.0)


def couple_nearest_aMTs(i, **kwargs):
    """
    Tie the i-th astral microtubule (aMT) to a new reference point

    A reference point is created from the aMT instance's vertices (via
    ``create_RP``) and stored in the set ``'aMTsRP<i>'``. All edges of the
    instance go into the set ``'aMTconnectRegionRight<i>'``, and a
    DISTRIBUTING / UNIFORM / WHOLE_SURFACE coupling constraint links the two.
    The constraint is named ``'RightAMTcoupling<i>'`` for even ``i`` and
    ``'LeftAMTcoupling<i>'`` for odd ``i``.

    :param i: index into ``kwargs['aMTnames']``

    :type i: int

    :param kwargs: model parameters; must contain ``aMTnames`` plus whatever
        ``return_assembly``, ``create_RP`` and ``coupling_constraint`` read

    :type kwargs: dict

    :return: ``(refPoint, regionRP)`` where ``refPoint`` is a 1-tuple holding
        the new reference point and ``regionRP`` is the set containing it

    :rtype: tuple
    """
    assembly = return_assembly(**kwargs)
    mt_instance = assembly.instances[kwargs['aMTnames'][i]]

    # create_RP places the point at verts[1].
    # NOTE(review): create_RP gives xpos/ypos/zpos precedence, so if those
    # keys are in kwargs the point is NOT placed on the aMT -- confirm callers
    # never pass them here.
    kwargs['verts'] = mt_instance.vertices
    new_rp = create_RP(**kwargs)
    refs = assembly.referencePoints
    rp_key = new_rp.id
    rp_set = assembly.Set(referencePoints=(refs[rp_key],),
                          name='aMTsRP' + str(i))

    # NOTE(review): the edge set is named "...Right" for every i, including
    # odd (left) aMTs -- confirm whether that naming is intended.
    mt_edges = mt_instance.edges
    edge_set = assembly.Set(edges=mt_edges,
                            name='aMTconnectRegionRight' + str(i))

    coupling_name = ('RightAMTcoupling' if i % 2 == 0
                     else 'LeftAMTcoupling') + str(i)
    coupling_constraint(rp_set, edge_set, WHOLE_SURFACE, DISTRIBUTING,
                        UNIFORM, coupling_name, **kwargs)
    return (refs[rp_key],), rp_set


def find_nearest(position, regions):
    """
    Pair up regions in order of increasing gap between consecutive points

    The Euclidean distance from each point in ``position`` to the next one is
    attached to the region with the same index in ``regions``. If the two
    sequences differ in length, the extra entries are ignored. Regions are
    then ranked by that distance; ties keep their original order. Each pair of
    neighbours in this ranking is returned.

    :param position: sequence of (x, y, z) points, e.g. aMT growing ends

    :type position: list or tuple of tuples

    :param regions: regions associated with the gaps between consecutive points

    :type regions: list of objects

    :return: list of ``(region_a, region_b)`` tuples of neighbouring ranked
        regions; empty when fewer than two regions can be ranked

    :rtype: list of tuple
    """
    gaps = []
    for prev, curr in zip(position[:-1], position[1:]):
        gaps.append(math.sqrt((prev[0] - curr[0]) ** 2 +
                              (prev[1] - curr[1]) ** 2 +
                              (prev[2] - curr[2]) ** 2))
    # Sort on the distance alone (stable), so region objects are never compared
    ranked = [reg for _, reg in sorted(zip(gaps, regions),
                                       key=lambda item: item[0])]
    return list(zip(ranked[:-1], ranked[1:]))
