import random as rd
import numpy as np

# tau: pheromone matrix of shape (1 + n1_size + n2_size + n3_size + n1_max//16 + n2_max) x (1 + n1_size + n2_size + n3_size + n1_max//16 + n2_max + n3_max)

def proba(i, alpha, tau, n1_size, n2_size, n3_size, n_cbx, n_cby, n_cbz, *, normalize=False):
    """Return the nodes an ant may move to from node ``i`` and their weights.

    The pheromone graph is made of consecutive layers of nodes, in this order:
    the start node (index 0), the n1 choices (``n1_size`` nodes), the n2
    choices (``n2_size`` nodes), the n3 choices (``n3_size`` nodes), the cbx
    choices (``2**(n1_size-1) * 16`` nodes), the cby choices
    (``2**(n2_size-1) * 256`` nodes) and finally the cbz choices. From any node
    of a layer the ant may only move to the next layer; for the three cache
    layers only the first ``n_cbx`` / ``n_cby`` / ``n_cbz`` nodes are offered.

    Args:
        i (int): Index of the current node (row of ``tau``).
        alpha (float): Exponent applied to the pheromone values.
        tau (np.array): Pheromone matrix.
        n1_size (int): Number of n1 choices.
        n2_size (int): Number of n2 choices.
        n3_size (int): Number of n3 choices.
        n_cbx (int): Number of cbx values offered (used when choosing cbx).
        n_cby (int): Number of cby values offered (used when choosing cby).
        n_cbz (int): Number of cbz values offered (used when choosing cbz).
        normalize (bool): If True, rescale the weights so they sum to 1
            (they are left unscaled when they sum to 0). Defaults to False.

    Returns:
        (tuple: np.array, np.array): candidate node indices (sequence) and
        their weights ``tau[i, sequence] ** alpha`` (weights).

    Raises:
        ValueError: If ``i`` is not a node with outgoing edges (negative, or
            on the cbz layer or beyond).
    """
    # Sizes of the cbx / cby layers: large enough for the biggest possible
    # n1 // 16 and n2 (problem sizes are 256 * 2**(k-1), see compute_path).
    n_cbx_layer = (2**(n1_size-1)) * 16
    n_cby_layer = (2**(n2_size-1)) * 256

    # (size of the layer the ant may be on, number of choices offered on the next layer)
    layers = (
        (1, n1_size),          # start node -> choose n1
        (n1_size, n2_size),    # n1 layer   -> choose n2
        (n2_size, n3_size),    # n2 layer   -> choose n3
        (n3_size, n_cbx),      # n3 layer   -> choose cbx
        (n_cbx_layer, n_cby),  # cbx layer  -> choose cby
        (n_cby_layer, n_cbz),  # cby layer  -> choose cbz
    )

    start = 0
    for layer_size, n_choices in layers:
        end = start + layer_size  # index of the first node of the next layer
        if start <= i < end:
            sequence = np.arange(end, end + n_choices)
            break
        start = end
    else:
        # Previously this fell through every branch and died with UnboundLocalError.
        raise ValueError(f"node {i} has no outgoing edges (valid nodes: 0..{start - 1})")

    # Index with the candidate array itself: works for an empty candidate set
    # and fails loudly if a candidate lies outside tau.
    weights = np.power(np.asarray(tau[i])[sequence], alpha)
    if normalize:
        total = np.sum(weights)
        if total > 0:
            weights = weights / total

    return (sequence, weights)

def compute_path(tau, alpha, n1_size, n2_size, n3_size):
    """Let one ant walk the layered pheromone graph and return its choices.

    Starting from node 0, the ant makes six successive weighted random choices
    (candidates and weights come from ``proba``): n1, n2, n3, then the
    cache-block sizes cbx, cby, cbz. Once n1, n2 and n3 are known, the cache
    choices are bounded by ``n1 // 16``, ``n2`` and ``n3`` respectively, where
    each problem size is ``256 * 2**(k - 1)`` for its 1-based choice index k.

    Args:
        tau (np.array): Pheromone matrix.
        alpha (float): Exponent applied to pheromone values when weighting choices.
        n1_size (int): Number of n1 choices.
        n2_size (int): Number of n2 choices.
        n3_size (int): Number of n3 choices.

    Returns:
        (list): ``[k1, k2, k3, cbx, cby, cbz]`` where k1..k3 are the 1-based
        indices of the chosen n1, n2, n3 options and cbx, cby, cbz are the
        chosen cache-block sizes (each starting at 1). Entries are the numpy
        integer scalars picked from the candidate arrays.
    """
    # The cache-block bounds depend on n1, n2 and n3, so they stay at zero
    # until those three choices have been made.
    cache_bounds = (0, 0, 0)
    current = 0
    visited = []

    for step in range(6):
        candidates, w = proba(current, alpha, tau, n1_size, n2_size, n3_size, *cache_bounds)
        current = rd.choices(candidates, w)[0]
        visited.append(current)
        if step == 2:
            # n1, n2, n3 have just been fixed: derive the cache-block bounds.
            n1 = 256 * (2**(visited[0] - 1))
            n2 = 256 * (2**(visited[1] - n1_size - 1))
            n3 = 256 * (2**(visited[2] - n1_size - n2_size - 1))
            cache_bounds = (n1 // 16, n2, n3)

    # Node index of the last node before each layer; subtracting it turns a
    # global node index into a 1-based position within its layer.
    n123 = n1_size + n2_size + n3_size
    cbx_layer = (2**(n1_size-1)) * 16
    cby_layer = (2**(n2_size-1)) * 256
    offsets = (
        0,
        n1_size,
        n1_size + n2_size,
        n123,
        n123 + cbx_layer,
        n123 + cbx_layer + cby_layer,
    )
    return [idx - off for idx, off in zip(visited, offsets)]

if __name__ == "__main__":
    # Largest problem sizes explored; each is assumed to be 256 * 2**k.
    n1_max = n2_max = n3_max = 1024

    # Upper bounds for the cache-block dimensions.
    n_cbx, n_cby, n_cbz = n1_max // 16, n2_max, n3_max

    # Number of candidate sizes per dimension (256, 512, ..., n_max).
    n1_size, n2_size, n3_size = (int(1 + np.log2(m / 256)) for m in (n1_max, n2_max, n3_max))

    tau_0 = 10  # initial pheromone level on every admissible edge

    # Layers of the graph in traversal order, start node first; bounds[k] is
    # the index of the first node of layer k.
    layer_sizes = [1, n1_size, n2_size, n3_size, n_cbx, n_cby, n_cbz]
    bounds = [0]
    for size in layer_sizes:
        bounds.append(bounds[-1] + size)

    # Rows: every node with outgoing edges (all but the cbz layer);
    # columns: every node.
    tau = np.zeros((bounds[-2], bounds[-1]), dtype="float64")
    for layer in range(len(layer_sizes) - 1):
        # Fully connect each layer to the next one.
        tau[bounds[layer]:bounds[layer + 1], bounds[layer + 1]:bounds[layer + 2]] = tau_0

    path = compute_path(tau, 1, n1_size, n2_size, n3_size)
    print(path)
