import numpy as np
import math
import socket
import time
from mpi4py import MPI

import compute_path
import tools
import launcher_SUBP

def _decode_path(path):
    """Translate a path through the graph into the parameters it encodes.

    Args:
        path (sequence of int): the six indices chosen along the path, in the order
            [n1 index, n2 index, n3 index, cbx index, cby index, cbz index]

    Returns:
        list: [n1, n2, n3, cbx, cby, cbz]
    """
    n1 = 256 * (2**(path[0]-1))
    n2 = 256 * (2**(path[1]-1))
    n3 = 256 * (2**(path[2]-1))
    cbx = path[3]*16
    cby = path[4]
    cbz = path[5]
    return [n1, n2, n3, cbx, cby, cbz]


def ACO(Me, NbP, comm, alpha, rho, Q, nb_ants, tau_0, n_iter, n1_max=256, n2_max=256, n3_max=256,
        fancy_strategy='AS', sub_threshold=0.1, sup_threshold=10**9, nb_threads=8, reps=100):
    """ Ant colony optimization of the cache blocking parameters for the execution of
    the iso3dfd program.

    Every process runs nb_ants // NbP ants per cycle, so this function must be called
    collectively by all the processes of comm with the same arguments.

    Args:
        Me (int): index of process running the function
        NbP (int): number of processes
        comm (MPI object): mpi communication object
        alpha (float): hyperparameter alpha of ACO
        rho (float): evaporation rate between 0 and 1
        Q (float): quantity of pheromones deposited by an ant on an edge
        nb_ants (int): number of ants (must be at least NbP)
        tau_0 (float): initial quantity of pheromones on each edge
        n_iter (int): number of cycles done for ACO (must be at least 1)
        n1_max (int, optional): Maximal first dimension of the problem. Defaults to 256.
        n2_max (int, optional): Maximal second dimension of the problem. Defaults to 256.
        n3_max (int, optional): Maximal third dimension of the problem. Defaults to 256.
        fancy_strategy (string, optional): Strategy used to update tau, one of 'AS',
            'ElitistAS' or 'MMAS'. Defaults to 'AS'.
        sub_threshold (float, optional): Inferior threshold for Min-Max strategy. Defaults to 0.1.
        sup_threshold (float, optional): Superior threshold for Min-Max strategy. Defaults to 10**9.
        nb_threads (int, optional): Number of threads per MPI process. Defaults to 8.
        reps (int, optional): Repetition count forwarded to launcher_SUBP.deploySUBP.
            Defaults to 100.

    Returns:
        (tuple: list, float): optimal path [n1, n2, n3, cbx, cby, cbz] and the associated cost

    Raises:
        ValueError: if fancy_strategy is unknown, if n_iter < 1 or if nb_ants < NbP
        RuntimeError: if no ant ever produced a cost lower than infinity
    """
    # Fail early (and identically on every process) instead of crashing mid-run
    if fancy_strategy not in ('AS', 'ElitistAS', 'MMAS'):
        raise ValueError(f"Unknown fancy_strategy {fancy_strategy!r}: expected 'AS', 'ElitistAS' or 'MMAS'.")
    if n_iter < 1:
        raise ValueError("n_iter must be at least 1.")
    ants_per_process = nb_ants // NbP
    if ants_per_process < 1:
        raise ValueError("nb_ants must be at least NbP so that every process runs one ant.")

    # Initialisation of the graph and pheromone matrix
    n_cbx = n1_max//16
    n_cby = n2_max
    n_cbz = n3_max
    n1_size = int(1 + np.log2(n1_max/256))
    n2_size = int(1 + np.log2(n2_max/256))
    n3_size = int(1 + np.log2(n3_max/256))

    # The graph is layered: one source node followed by the n1, n2, n3, cbx, cby and cbz
    # layers. bounds[k] is the index of the first node of layer k.
    bounds = [0]
    for layer_size in (1, n1_size, n2_size, n3_size, n_cbx, n_cby, n_cbz):
        bounds.append(bounds[-1] + layer_size)
    # Nodes of the last layer have no outgoing edge, hence one layer less on the rows
    tau = np.zeros((bounds[-2], bounds[-1]), dtype="float64")
    for k in range(6):
        # Only edges going from a layer to the next one receive pheromones
        tau[bounds[k]:bounds[k+1], bounds[k+1]:bounds[k+2]] = tau_0

    cost_opti = math.inf
    path_opti = None
    # n_iter cycles of ants traveling through the graph
    for iteration in range(n_iter):
        paths = []
        costs = []
        for k in range(ants_per_process):
            path = compute_path.compute_path(tau, alpha, n1_size, n2_size, n3_size)
            n1, n2, n3, cbx, cby, cbz = _decode_path(path)
            cost = launcher_SUBP.deploySUBP(n1, n2, n3, nb_threads, reps, cbx, cby, cbz)
            print(f"Path followed at iteration {iteration} on process {Me} by ant {k} : {[n1, n2, n3, cbx, cby, cbz]} with cost equal to {cost}.")
            paths.append(path)
            # Costs are kept as a column so that paths and costs have one row per ant
            costs.append([cost])
        paths = np.array(paths, dtype="int32")
        costs = np.array(costs, dtype="float64")

        if fancy_strategy in ('AS', 'ElitistAS'):
            # We initialize the best path with the ants already in this process.
            # The copy matters: paths is overwritten in place by Sendrecv_replace below
            # and a plain row slice would be a view on that buffer.
            best_cost = np.amin(costs)
            best_p = paths[np.argmin(costs), :].copy()

            # We evaporate the pheromones
            tau = tau * (1-rho)

            # We compute the new pheromones with the ants already present
            tau = tools.add_pheromones(tau, paths, costs, Q, n1_size, n2_size, n3_size, n_cbx, n_cby)

            for _ in range(1, NbP):
                # We exchange the ants with a ring structure to compute tau locally
                comm.Sendrecv_replace(paths, dest=(Me+1)%NbP, source=(Me-1)%NbP)
                comm.Sendrecv_replace(costs, dest=(Me+1)%NbP, source=(Me-1)%NbP)
                tau = tools.add_pheromones(tau, paths, costs, Q, n1_size, n2_size, n3_size, n_cbx, n_cby)
                # We search for the best path
                new_cost = np.amin(costs)
                if new_cost < best_cost:
                    best_cost = new_cost
                    best_p = paths[np.argmin(costs), :].copy()

            if fancy_strategy == 'ElitistAS':
                # The best ant of the cycle deposits pheromones a second time
                tau = tools.add_pheromones(tau, np.array(best_p), np.array(best_cost), Q, n1_size, n2_size, n3_size, n_cbx, n_cby)

        else:  # 'MMAS'
            # We compute the best path locally
            best_cost = np.amin(costs)
            best_p = paths[np.argmin(costs), :].copy()

            # We compare the best paths of all processes and keep only the best one.
            # MINLOC is applied to a (value, rank) pair with the pickle-based allreduce,
            # which gives back the rank as an int usable as broadcast root.
            best_cost, best_rank = comm.allreduce((float(best_cost), Me), op=MPI.MINLOC)
            best_p = comm.bcast(best_p, root=best_rank)

            # We compute the pheromones using only the best path
            tau = tau * (1-rho)
            tau = tools.add_pheromones(tau, np.array(best_p), np.array(best_cost), Q, n1_size, n2_size, n3_size, n_cbx, n_cby)

            # Verification of the threshold constraint, applied to every entry of tau
            np.clip(tau, sub_threshold, sup_threshold, out=tau)

        # We update the best path over all iterations
        if best_cost < cost_opti:
            cost_opti = best_cost
            path_opti = best_p
        if Me == 0:
            # The sign is flipped when logging the cost
            with open("cost_opti.txt", 'a') as file:
                file.write(f"{-1*cost_opti}\n")

    if path_opti is None:
        raise RuntimeError("No ant produced a cost lower than infinity.")
    return _decode_path(path_opti), cost_opti


if __name__ == "__main__":
    # Retrieve the MPI context: communicator, rank of this process and number of processes
    comm = MPI.COMM_WORLD
    Me = comm.Get_rank()
    NbP = comm.Get_size()
    is_root = (Me == 0)
    print(f"Process {Me} launched on node {socket.gethostname()}.")

    # ACO hyperparameters and iso3dfd execution parameters, forwarded to ACO by keyword
    Q = 0.1
    aco_settings = {
        "alpha": 1,
        "rho": 0.1,
        "Q": Q,
        "nb_ants": 2,
        "tau_0": Q*100,
        "n_iter": 1,
        "n1_max": 256,
        "n2_max": 256,
        "n3_max": 256,
        "fancy_strategy": "AS",
        "nb_threads": 8,
        "reps": 100,
    }

    # Only rank 0 compiles the code; the barrier makes the other processes wait
    # until that call has returned
    if is_root:
        launcher_SUBP.compileSUBP(optimization="-O3", simd="avx512")
    comm.Barrier()

    # The total execution time is measured on rank 0 only
    begin = time.time() if is_root else None

    path_opti, cost_opti = ACO(Me, NbP, comm, **aco_settings)

    if is_root:
        elapsed = time.time() - begin
        print(f"Le temps d'exécution total est {elapsed} s.")
        print(f"Le chemin optimal est {path_opti}.")
        # The cost is printed with its sign flipped, as a throughput
        print(f"Le throughput associé est alors {-1*cost_opti} MPoints/s.")