import argparse
import re
import math

#--------------------------------------------------------------------------
# Command-line parsing
#--------------------------------------------------------------------------

def cmdLineParsing(argv=None):
    """Parse the benchmark's command-line options.

    Args:
        argv (list[str] | None): Arguments to parse. ``None`` (the default)
            lets argparse read ``sys.argv[1:]``, so existing callers keep
            parsing the real command line; pass a list to parse explicitly.

    Returns:
        tuple[int, int, int, int, int, int, int, int]:
            ``(n1, n2, n3, thds, reps, cbx, cby, cbz)`` as parsed integers.
    """
    # Defaults used when an option is not given.
    n1 = 256
    n2 = 256
    n3 = 256
    nb_threads = 8
    reps = 100
    cbx = 32
    cby = 32
    cbz = 32

    parser = argparse.ArgumentParser()
    parser.add_argument("--n1", help="n1", default=n1, type=int)
    parser.add_argument("--n2", help="n2", default=n2, type=int)
    parser.add_argument("--n3", help="n3", default=n3, type=int)
    parser.add_argument("--thds", help="Number of threads", default=nb_threads, type=int)
    parser.add_argument("--reps", help="Number of reps", default=reps, type=int)
    parser.add_argument("--cbx", help="n1 Thread block size", default=cbx, type=int)
    parser.add_argument("--cby", help="n2 Thread block size", default=cby, type=int)
    parser.add_argument("--cbz", help="n3 Thread block size", default=cbz, type=int)

    args = parser.parse_args(argv)

    return args.n1, args.n2, args.n3, args.thds, args.reps, args.cbx, args.cby, args.cbz


#-----------------------------------------------------------------
# Benchmark-output extraction
#-----------------------------------------------------------------

def commandLineExtract(commandLineOutput):
    """Extract time, throughput and GFlops values from a benchmark run's output.

    Every line containing "time", "throughput" or "flops" contributes the
    first decimal number found on it to the corresponding list. A line
    containing "flops" also closes the current run: the values collected
    since the previous "flops" line are appended to ``byRun`` as one entry.

    Args:
        commandLineOutput: Object with a ``stdout`` attribute (e.g. the result
            of ``subprocess.run(..., capture_output=True)``). ``stdout`` may be
            UTF-8 ``bytes`` or an already-decoded ``str``.

    Returns:
        list[float]: Execution times.
        list[float]: Execution throughputs.
        list[float]: Execution GFlops.
        list[list[float]]: One ``[time, throughput, gflop]`` list per run
            (in the order the values appeared in the output).

    Raises:
        ValueError: A keyword line contains no decimal number.
    """
    # Decimal number with an optional exponent, so "1.5e-03" is not truncated
    # to its mantissa. Raw string avoids invalid-escape warnings.
    number_re = re.compile(r"\d+\.\d+(?:[eE][-+]?\d+)?")

    def first_number(line):
        """Return the first decimal number on *line* as a float."""
        match = number_re.search(line)
        if match is None:
            raise ValueError(f"no decimal number found in line: {line!r}")
        return float(match.group())

    stdout = commandLineOutput.stdout
    # subprocess yields bytes unless text mode was requested; accept both.
    text = stdout if isinstance(stdout, str) else str(stdout, "utf-8")

    # Initialize outputs
    times = []
    throughputs = []
    gflops = []
    byRun = []

    # Values of the run currently being read
    run = []

    for line in text.split("\n"):
        if "time" in line:
            time = first_number(line)
            times.append(time)
            run.append(time)

        if "throughput" in line:
            throughput = first_number(line)
            throughputs.append(throughput)
            run.append(throughput)

        if "flops" in line:
            gflop = first_number(line)
            gflops.append(gflop)
            run.append(gflop)
            # The flops line terminates a run: store it and start a new one.
            byRun.append(run)
            run = []

    return times, throughputs, gflops, byRun


#-----------------------------------------------------------------
# Add pheromones function
#-----------------------------------------------------------------

def add_pheromones(tau, paths, costs, Q, n1_size, n2_size, n3_size, n_cbx, n_cby):
    """Deposit pheromones on every edge of every ant's path.

    Nodes of ``tau`` are laid out in six consecutive layers. The node for
    path component ``k`` is ``paths[i, k]`` plus the summed sizes of the
    layers before it (n1, n2, n3, cbx, cby). Each ant adds ``-Q * cost`` to
    the edge from node 0 to its layer-0 node, then to each edge between its
    consecutive layer nodes (six edges in total).

    Args:
        tau (np.array): Pheromone matrix indexed ``tau[from_node, to_node]``;
            modified in place.
        paths (np.array): 2-D array; row ``i`` holds ant ``i``'s six per-layer
            choices ``(p0, p1, p2, p3, p4, p5)``.
        costs (np.array): Column array; ``costs[i, 0]`` is the cost of path ``i``.
            One row per ant.
        Q (float): Quantity of pheromones scaling each deposit.
        n1_size (int): First dimension of the problem (size of layer 0).
        n2_size (int): Second dimension of the problem (size of layer 1).
        n3_size (int): Third dimension of the problem (size of layer 2).
        n_cbx (int): First dimension of the cache (size of layer 3).
        n_cby (int): Second dimension of the cache (size of layer 4).

    Returns:
        (np.array): ``tau`` itself, after the in-place update.
    """
    # NOTE(review): deposit is -Q * cost, so lower (more negative) costs add
    # more pheromone — presumably costs are negated performance; confirm.
    # NOTE(review): row 0 is the source of the first edge and is also the
    # index of layer-0 choice 0 — confirm this overlap is intended.

    # Cumulative node-index offset of each of the six layers (hoisted once).
    offsets = (
        0,
        n1_size,
        n1_size + n2_size,
        n1_size + n2_size + n3_size,
        n1_size + n2_size + n3_size + n_cbx,
        n1_size + n2_size + n3_size + n_cbx + n_cby,
    )

    for i, cost_row in enumerate(costs):
        path = paths[i]
        deposit = -Q * cost_row[0]
        previous = 0  # source node of the first edge
        for layer, offset in enumerate(offsets):
            node = path[layer] + offset
            tau[previous, node] += deposit
            previous = node

    return tau