import matplotlib.pyplot as plt
import math

###########################################################################
## Test stability of the code
###########################################################################

# Plot, for each of the six batches, the value recorded on every line of its
# cost file against the line's 1-based position, all on a single figure.
for batch_id in range(1, 7):
    cost_path = f"16np_32ants_400iter/cost_opti_16np_32ants_400iter_{batch_id}.txt"
    with open(cost_path, 'r') as file:
        # One numeric value per line. Blank lines (e.g. a trailing empty
        # line) are skipped so that float() does not raise on them.
        throughputs = [float(line) for line in file if line.strip()]
    # The file is closed before plotting; iterations are numbered from 1.
    iterations = list(range(1, len(throughputs) + 1))
    plt.plot(iterations, throughputs, '*-', label=f"Batch {batch_id}")

plt.title("Résulats de plusiseurs batchs pour 400 générations de 32 fourmis sur 16 process.")
plt.xlabel("Iteration")
plt.ylabel("Throughput optimal (MPoint/s)")
plt.legend()
plt.show()

###########################################################################
## Test stability of path at last iteration
###########################################################################

batches = [1, 2, 3, 4, 5, 6]
fig, axs = plt.subplots(2, 3)
max_nb_paths = 0
width = 0.5
# Iteration number whose "Path" lines are analysed.
# NOTE(review): presumably the 0-based index of the last of the 400
# iterations named in the file paths -- confirm against the log producer.
last_iteration = 399


def _axis_for(batch):
    """Return the subplot of the 2x3 grid assigned to *batch*.

    Batches fill the grid column by column: 1 -> (0, 0), 2 -> (1, 0),
    3 -> (0, 1), and so on.
    """
    return axs[(batch - 1) % 2, (batch - 1) // 2]


for batch in batches:
    # path string -> [number of occurrences, sum of costs], in first-seen order.
    stats = {}
    with open(f"16np_32ants_400iter/result_16np_32ants_400iter_{batch}.out", "r") as file:
        for line in file:
            if not line.startswith("Path"):
                continue
            # The fifth space-separated token is parsed as the iteration number.
            if int(line.split(" ")[4]) != last_iteration:
                continue
            # The path is the text between the first '[' and the next ']'.
            path = (line.split("[")[1]).split("]")[0]
            # The last token is a number, possibly followed by '.' and a
            # newline; its sign is flipped before accumulation.
            # NOTE(review): presumably the log stores the negated throughput
            # -- confirm.
            cost = -float((line.split(" ")[-1]).strip(".\n"))
            entry = stats.setdefault(path, [0, 0.0])
            entry[0] += 1
            entry[1] += cost

    # Sort by decreasing mean cost; ties fall back to path, then count,
    # through plain tuple comparison.
    ranked = sorted(
        ((total / occurrences, path, occurrences)
         for path, (occurrences, total) in stats.items()),
        reverse=True,
    )
    counts = [occurrences for _, _, occurrences in ranked]
    max_nb_paths = max(max_nb_paths, len(ranked))

    ax = _axis_for(batch)
    # A log without any matching line leaves the subplot empty instead of
    # aborting the whole figure.
    if counts:
        ax.bar(range(len(counts)), counts, width=width)
    ax.set_title(f"Batch {batch}")

# Share the same x-range across subplots so bar positions are comparable.
for batch in batches:
    _axis_for(batch).set_xlim(-2*width, max_nb_paths + 2*width)
axs[0, 0].set_ylabel("Nombre de fourmis par chemin")
axs[0, 0].set_xlabel("Chemins triés par throughput décroissant")
fig.suptitle("Chemins suivis à la dernière itération pour différents batchs.")
fig.tight_layout(pad=0.5)
plt.show()
# Release the figure created above (the original closed figure number 10,
# which this script never creates).
plt.close(fig)