
#__author__ = "Chris Maidens"
#__copyright__ = "Copyright (C) 2020 Chris Maidens"
#__license__ = "No license granted"
#__version__ = "0.1"

# https://www.datacamp.com/tutorial/markov-chains-python-tutorial
# https://stackoverflow.com/questions/58048810/building-n-th-order-markovian-transition-matrix-from-a-given-sequence
# https://stats.stackexchange.com/questions/147164/fit-and-evaluate-a-second-order-transition-matrix-markov-process-in-r
# https://stats.stackexchange.com/questions/512221/creating-transition-states-for-a-second-order-markov-chain-for-attribution
import os
import numpy as np

import MAFpt_r_params as rpar

import MAFpt_ATTACK_DB_v2 as ATTACK

# Locate the run-parameter YAML file and load it.
# NOTE(review): the two environment lookups below currently have no effect,
# because both names are unconditionally reassigned further down before use.
YAML_Root = os.environ.get('MAFpt_YAML_ROOT')
YAML_File = os.environ.get('MAFpt_YAML_FILE')
YAML_File="MAFpt_RunParams.yaml"

# Final values actually used: the current working directory plus a fixed file
# name. The leading "/" in the file name is what separates the two parts when
# they are concatenated below.
YAML_Root = os.getcwd()
YAML_File="/MAFpt_ATTACK_DB_TEST_RunParams.yaml"

# Parameter reader built from the concatenated path; queried below through
# its MAFpt_r_read method.
p_obj = rpar.MAFpt_r_params(YAML_Root + YAML_File)

# Read every run parameter that is handed to the ATTACK database constructor.
# The reads are issued in the same order as before, through a single bound
# reference to the reader method.
_read_run_param = p_obj.MAFpt_r_read

DOWNLOAD_ATTACK = _read_run_param("RUN_DOWNLOAD_ATTACK")
REINDEX_ATTACK = _read_run_param("RUN_REINDEX_ATTACK")
ATTACK_LOCAL_FILE_ROOT = _read_run_param('RUN_ATTACK_LOCAL_FILE_ROOT')
ATTACK_TAXII_SERVER = _read_run_param('RUN_ATTACK_TAXII_SERVER')
ATTACK_LOCAL_COPY = _read_run_param('RUN_ATTACK_LOCAL_COPY')
ATTACK_CVE_SEARCH = _read_run_param('RUN_ATTACK_CVE_SEARCH')

# Index-related parameters.
ATTACK_MAIN_INDEX = _read_run_param('RUN_ATTACK_MAIN_INDEX')
ATTACK_SUB_INDEX = _read_run_param('RUN_ATTACK_SUB_INDEX')
ATTACK_CVE_REF_INDEX = _read_run_param('RUN_ATTACK_CVE_REF_INDEX')
ATTACK_TTP_INDEX = _read_run_param('RUN_ATTACK_TTP_INDEX')
ATTACK_TACTIC_INDEX = _read_run_param('RUN_ATTACK_TACTIC_INDEX')
ATTACK_TECH_TO_TACTIC_INDEX = _read_run_param('RUN_ATTACK_TECH_TO_TACTIC_INDEX')
ATTACK_REL_INDEX = _read_run_param('RUN_ATTACK_REL_INDEX')
ATTACK_TACTIC_BIN_INDEX = _read_run_param('RUN_ATTACK_TACTIC_BIN_INDEX')
ATTACK_TTP_BIN_INDEX = _read_run_param('RUN_ATTACK_TTP_BIN_INDEX')

# Positional arguments for the ATTACK database object, in the exact order the
# constructor is called with.
_attack_db_args = (
    DOWNLOAD_ATTACK,
    ATTACK_TAXII_SERVER,
    ATTACK_LOCAL_FILE_ROOT,
    ATTACK_LOCAL_COPY,
    REINDEX_ATTACK,
    ATTACK_CVE_SEARCH,
    ATTACK_MAIN_INDEX,
    ATTACK_SUB_INDEX,
    ATTACK_CVE_REF_INDEX,
    ATTACK_TTP_INDEX,
    ATTACK_TACTIC_INDEX,
    ATTACK_TECH_TO_TACTIC_INDEX,
    ATTACK_REL_INDEX,
    ATTACK_TACTIC_BIN_INDEX,
    ATTACK_TTP_BIN_INDEX,
)
ATTACK_obj = ATTACK.MAFpt_ATTACK_DB(*_attack_db_args)

# Tactic bin index as returned by the database object.
TacticBinDF = ATTACK_obj.GetTacticBinIndex()
                         

from MAFpt_ATTACK_DB_ATTACK_GRAPHS_DATA_AUTO import ThisAttackTechList 

# ---------------------------------------------------------------------------
# Module-level working state
# ---------------------------------------------------------------------------
AttackSeqTacticList = []
AttackSeqTacticCount = []
IDList = []
count = 0

# Tactic sequence pairs (intended to hold dicts).
T1givenT0 = []

# Technique list and the list of DEPTH-long technique tuples.
TechList = []
NDepthList = []

# Test cases that could not be predicted / were too short to test.
FailList = []
ShortFailList = []

# ---------------------------------------------------------------------------
# Experiment settings
# ---------------------------------------------------------------------------
# Number of preceding techniques that form the context for a prediction.
DEPTH = 1
# Number of leading techniques of a test attack that are treated as observed.
OBSERVATION_LEN = 3

# Single-attack selection; not referenced elsewhere in this part of the file.
ONLY_USE = ['ZAPT33_001']

# The attack-ID lists are cumulative: each one is the previous list plus a
# further batch of IDs. Every list is a separate list object.
ONLY_USE_LIST_ORIG = [
    'admin@338_001', 'Lazarus_Group_001', 'Lazarus_Group_002', 'APT32_001',
    'MuddyWater_001', 'MuddyWater_002', 'Mustang_Panda_001', 'Sandworm_001',
    'Tropic_Trooper_001', 'APT28_001', 'APT28_002', 'APT28_003', 'APT28_004',
    'APT29_001', 'APT29_002', 'APT29_003', 'APT29_004', 'APT41_001',
    'APT41_002', 'menuPass_001', 'Carbanak_001', 'APT37_001',
    'WizardSpider_001', 'OilRig_001', 'FIN7_001', 'APT3_001',
]

ONLY_USE_LIST_NEXT = ONLY_USE_LIST_ORIG + [
    'Ajax_Security_Team_001', 'Andariel_001', 'APT38_001',
]

ONLY_USE_LIST_FULL = ONLY_USE_LIST_NEXT + [
    'ZAPT33_001', 'ZAPT19_001', 'ZSandworm_002', 'ZAPT28_005', 'ZAPT32_001',
    'ZAPT29_005',
]
                           
#####################################################
#   'Train' MC Transition Matrix based on the test data.
####################################################
# Running tallies over all test cases: correct predictions, wrong predictions,
# and cases where no prediction could be made.
SuccPred = FailPred = UnableToPredict = 0

# One flag per attack in ONLY_USE_LIST_FULL; a flag is set to 1 when that
# attack is too short to be tested.
ONLY_USE_LIST_SHORT_FLAGS = [0 for _ in ONLY_USE_LIST_FULL]

print("MC Predictions for depth {}".format(DEPTH))
print("Using {} attacks".format(len(ONLY_USE_LIST_FULL)))

def _attack_techniques(attack, wanted_ids):
    """Return ``(attack_id, techniques)`` for one attack, or ``None`` to skip it.

    The first node of *attack* is its ID line (a mapping with an ``'ID'``
    key). Every later node contributes its ``'Tech'`` value unless its
    ``'SG'`` value is ``'G'``. ``None`` is returned when the attack has no
    nodes or when its ID is not in *wanted_ids*.
    """
    nodes = iter(attack)
    header = next(nodes, None)
    if header is None or header['ID'] not in wanted_ids:
        return None
    return header['ID'], [node['Tech'] for node in nodes if not node['SG'] == 'G']


def _most_probable_column(matrix, contexts, context):
    """Return the column index of the most probable next technique, or ``None``.

    *contexts* lists the technique tuples that label the rows of *matrix*.
    ``None`` means that no prediction is possible: either *context* was never
    seen in training, or its row holds no transition data (all zeros). Ties
    are resolved in favour of the lowest column index.
    """
    if context not in contexts:
        return None
    row = matrix[contexts.index(context)]
    if not row.any():
        return None
    return int(np.argmax(row))


# Leave-one-out evaluation: each attack in ONLY_USE_LIST_FULL is held out in
# turn, the transition matrix is trained on the remaining attacks, and the
# technique following the first OBSERVATION_LEN techniques of the held-out
# attack is predicted from the last DEPTH observed techniques.
for remove, held_out in enumerate(ONLY_USE_LIST_FULL):
    ONLY_USE_LIST = [name for position, name in enumerate(ONLY_USE_LIST_FULL)
                     if position != remove]
    TEST_CASE = [held_out]
    print("===========   THE TEST CASE IS " + held_out)

    # Technique sequences of the training attacks, in data-file order.
    training = []
    for ThisAttack in ThisAttackTechList:
        extracted = _attack_techniques(ThisAttack, ONLY_USE_LIST)
        if extracted is not None:
            training.append(extracted)
    IDList = [attack_id for attack_id, _ in training]

    # Column labels (distinct techniques) and row labels (distinct DEPTH-long
    # technique tuples), both in order of first appearance.
    TechList = []
    NDepthList = []
    for _, techs in training:
        for tech in techs:
            if tech not in TechList:
                TechList.append(tech)
        for start in range(len(techs) - DEPTH + 1):
            window = techs[start:start + DEPTH]
            if window not in NDepthList:
                NDepthList.append(window)

    # Transition counts: row = preceding DEPTH techniques, column = next one.
    M = np.zeros((len(NDepthList), len(TechList)))
    for _, techs in training:
        for pos in range(DEPTH, len(techs)):
            M[NDepthList.index(techs[pos - DEPTH:pos]), TechList.index(techs[pos])] += 1

    # Turn counts into probabilities. Rows without any outgoing transition
    # stay all-zero instead of being divided by zero.
    RowSums = M.sum(axis=1)
    has_data = RowSums != 0
    M[has_data] = M[has_data] / RowSums[has_data, np.newaxis]

    # Evaluate the held-out attack(s).
    for ThisAttack in ThisAttackTechList:
        extracted = _attack_techniques(ThisAttack, TEST_CASE)
        if extracted is None:
            continue
        AttackName, techs = extracted

        AttackExtract = techs[:OBSERVATION_LEN]
        # The technique to predict is the first non-empty one after the extract.
        NextTechAfterExtract = next(
            (tech for tech in techs[OBSERVATION_LEN:] if not tech == ""), "")

        if NextTechAfterExtract == "":
            # Nothing follows the observed extract: the attack is too short.
            ONLY_USE_LIST_SHORT_FLAGS[ONLY_USE_LIST_FULL.index(AttackName)] = 1
            ShortFailList.append(AttackName)
            continue

        # Get the last DEPTH techs from the extract
        AttackExtractToDepth = AttackExtract[-DEPTH:]
        print("The attack extract for " + AttackName + "is " + str(AttackExtract))
        print("The depth items are" + str(AttackExtractToDepth) + " (for depth " + str(DEPTH) + ")")

        MaxCol = _most_probable_column(M, NDepthList, AttackExtractToDepth)
        if MaxCol is None:
            UnableToPredict += 1
            print("<< WARN>> Cannot predict for " + str(AttackExtractToDepth) + " in " + AttackName)
            FailList.append({'Attack': AttackName, 'Depth Tuple': AttackExtractToDepth})
            continue

        print("The predicted technique is " + str(TechList[MaxCol]))
        print("The real technique is " + str(NextTechAfterExtract))
        if TechList[MaxCol] == NextTechAfterExtract:
            SuccPred += 1
        else:
            FailPred += 1
    #print("")
#print("ID LIST IS " +str(IDList))
# Final report for the whole leave-one-out run.
TooShort = sum(1 for NextFlag in ONLY_USE_LIST_SHORT_FLAGS if NextFlag == 1)
TotalTests = len(ONLY_USE_LIST_FULL)
# Tests that actually produced a prediction (successful or not).
Runnable = TotalTests - UnableToPredict - TooShort

print("Total tests is " + str(TotalTests))
print("Final result ( " + str(DEPTH) + " ) is success=" + str(SuccPred) + "/failed=" + str(FailPred) + "/unable=" + str(UnableToPredict) + "/too short=" + str(TooShort))

# Guard both ratios: an empty test list, or a run in which every case was
# unpredictable or too short, must not raise ZeroDivisionError and thereby
# suppress the fail lists printed below.
if TotalTests > 0:
    print ("Overall accuracy % is " + str(SuccPred/TotalTests*100))
else:
    print ("Overall accuracy % is n/a (no tests)")
if Runnable > 0:
    print ("Accuracy % against runnable is " + str(SuccPred/Runnable*100))
else:
    print ("Accuracy % against runnable is n/a (no runnable tests)")

print("Fail list is "  +  str(FailList)  )
print("Short fail list is "  +  str(ShortFailList)  )
