# -*- coding: utf-8 -*-
# coding: utf-8
import numpy as np
import scipy as sc
import scipy.signal as ss
import os
from sklearn import mixture
import collections
import datetime
from GrASP import Params_GrASP


def Spectral_clustering_EQT(Stadis_array, Probability, Threshold, Normalized_mode='on'):
    """
    Categorization for tremor candidates via spectral clustering.

    Stations whose probability is below ``Threshold`` are disconnected from the
    permanent adjacency matrix. Remaining positive entries are binarized to 1,
    and the number of zero eigenvalues (after rounding to 7 decimals) of the
    graph Laplacian is taken as the number of station clusters.

    Parameters:
    ----------
    Stadis_array: Permanent adjacency matrix (square numpy array, indexed by station ID)
    Probability: Tremor and EQ probability (numpy array, one value per station)
    Threshold: Threshold for probability
    Normalized_mode: 'on' uses the normalized Laplacian D^-1/2 (D - A) D^-1/2,
        anything else uses D - A. default: 'on'

    Returns:
    ----------
    pred_output: Cluster ID.
        case 1: int array of zeros, one per station in Adjacency_ID_output;
        case 2: list with one such zero array per cluster;
        case 0: empty list.
    pred_keys_output: Key for Cluster ID (dict_keys; list of dict_keys in case 2)
    Adjacency_ID_output: Station ID (array in case 1, list of arrays in case 2)
    output_type: 0 if no cluster was extracted; in case 1, one cluster was
        extracted; in case 2, two or more clusters were extracted.
    """
    from scipy.linalg import eigh

    # Default result: "no cluster extracted". It is also returned when the
    # Laplacian has no zero eigenvalue, which can happen for a non-symmetric
    # adjacency matrix. Previously that case raised UnboundLocalError.
    pred_output = []
    pred_keys_output = []
    Adjacency_ID_output = []
    output_type = 0

    ##### Adjacency matrix #####
    Adjacency_matrix_All = Stadis_array.copy()
    below_threshold = Probability < Threshold
    Adjacency_matrix_All[below_threshold, :] = 0
    Adjacency_matrix_All[:, below_threshold] = 0
    Adjacency_matrix_All[Adjacency_matrix_All > 0] = 1
    # Stations that still have at least one connection in their row.
    Adjacency_ID = np.flatnonzero(np.sum(Adjacency_matrix_All, axis=1) > 0).astype(int)

    if len(Adjacency_ID) <= 1:
        return pred_output, pred_keys_output, Adjacency_ID_output, output_type

    # Sub-graph restricted to the connected stations.
    Adjacency_matrix = Adjacency_matrix_All[np.ix_(Adjacency_ID, Adjacency_ID)]

    ##### Degree matrix (column sums) #####
    degree = np.sum(Adjacency_matrix, axis=0).astype(float)
    ##### Laplacian matrix #####
    Laplacian_matrix = np.diag(degree) - Adjacency_matrix
    ##### Normalized Laplacian #####
    if Normalized_mode == 'on':
        Degree_matrix_inv = np.diag(degree ** (-0.5))
        Laplacian_matrix = Degree_matrix_inv @ Laplacian_matrix @ Degree_matrix_inv

    ##### Eigenvalue decomposition #####
    eigval, eigvec = eigh(Laplacian_matrix)
    # Rounding turns numerically-tiny eigenvalues into exact zeros.
    eigval = np.round(eigval, 7)
    eigvec = np.round(eigvec, 7)
    zero_number = np.count_nonzero(eigval == 0)

    ##### One cluster #####
    if zero_number == 1:
        pred_output = np.zeros(len(Adjacency_ID), dtype='int')
        pred_keys_output = collections.Counter(pred_output).keys()
        Adjacency_ID_output = Adjacency_ID
        output_type = 1

    ##### Two or more clusters #####
    elif zero_number > 1:
        # Spectral embedding: eigenvectors of the (sorted) zero eigenvalues.
        eigvec_mixture = eigvec[:, eigval <= eigval[zero_number - 1]]
        ##### Gaussian mixture model #####
        # NOTE(review): no random_state is passed, so the labelling may vary
        # between runs.
        GMM = mixture.GaussianMixture(n_components=zero_number, covariance_type='full', init_params='k-means++')
        GMM.fit(eigvec_mixture)
        pred = GMM.predict(eigvec_mixture)
        pred_count = collections.Counter(pred)
        output_type = 2
        ##### Each cluster #####
        # Clusters are reported in order of first appearance in `pred`. Each one
        # is a single sub-cluster labelled 0, one label per member station.
        for multi_id in pred_count.keys():
            Adjacency_ID_multi = Adjacency_ID[pred == multi_id]
            pred_multi = np.zeros(len(Adjacency_ID_multi), dtype='int')
            pred_output.append(pred_multi)
            pred_keys_output.append(collections.Counter(pred_multi).keys())
            Adjacency_ID_output.append(Adjacency_ID_multi)

    return pred_output, pred_keys_output, Adjacency_ID_output, output_type


def Apply_Categorization(Stadis_array, Data_T, Data_EQ):
    """
    Apply Tremor categorization

    For each tremor window (row of Data_T), the tremor probabilities are summed
    with those of every EQ window (row of Data_EQ) that has the same
    year/month/day/hour/minute. The merged probabilities are then clustered with
    Spectral_clustering_EQT.

    Parameters:
    ----------
    Stadis_array: Permanent adjacency matrix
    Data_T: Tremor Association result. Columns 0-4 are year, month, day, hour,
        minute; columns 5 onward are per-station probabilities.
    Data_EQ: EQ Association result, same column layout as Data_T.
        NOTE(review): the forward search assumes Data_EQ (and Data_T) are sorted
        by time; confirm with the caller.

    Return:
    ----------
    Save_result: Data_T with one extra column holding the category:
        0: fewer than Params_GrASP.Min_sta stations reach Params_GrASP.Threshold
        1: no EQ window at the same time, or the clustering of the merged
           probabilities gave no cluster / at least (matches + 1) clusters
        2: the merged clustering gave a single cluster, or fewer clusters
           than (number of matching EQ windows + 1)
    """

    def _window(row):
        # Minute-resolution timestamp from columns 0-4 (Y, M, D, H, m).
        return datetime.datetime(int(row[0]), int(row[1]), int(row[2]), int(row[3]), int(row[4]), 0)

    Tremor_category = []
    # Index in Data_EQ from which the next search starts.
    search_id = 0

    Probability_T = Data_T[:, 5:]
    Probability_EQ = Data_EQ[:, 5:]

    for event_T in range(np.shape(Data_T)[0]):
        Target_window_T = _window(Data_T[event_T])
        Tremor_prob = Probability_T[event_T, :]

        # Too few stations above threshold: category 0, no clustering.
        if np.count_nonzero(Tremor_prob >= Params_GrASP.Threshold) < Params_GrASP.Min_sta:
            Tremor_category.append(0)
            continue

        # Sum the probabilities of all EQ windows sharing this timestamp.
        Merge_prob = Tremor_prob
        Merge_count = 0
        for event_EQ in range(search_id, np.shape(Data_EQ)[0]):
            Target_window_EQ = _window(Data_EQ[event_EQ])
            if Target_window_T == Target_window_EQ:
                if Merge_count == 0:
                    # Later tremor windows restart the search at the first match.
                    search_id = event_EQ
                Merge_prob = Merge_prob + Probability_EQ[event_EQ, :]
                Merge_count += 1
            elif Target_window_T < Target_window_EQ:
                break

        EQ_merged = False
        if Merge_count > 0:
            ##### Spectral clustering #####
            pred_output, _, _, output_type = Spectral_clustering_EQT(Stadis_array, Merge_prob, Params_GrASP.Threshold, Normalized_mode='on')
            if output_type == 1:
                EQ_merged = True
            elif output_type == 2:
                # Fewer clusters than merged windows (tremor + EQs).
                EQ_merged = len(pred_output) < Merge_count + 1

        Tremor_category.append(2 if EQ_merged else 1)

    Tremor_category = np.array(Tremor_category)
    Save_result = np.hstack((Data_T, Tremor_category[:, np.newaxis]))

    return Save_result
