from enum import IntEnum, unique
import numpy as np
import itertools as it

@unique
class Action(IntEnum):
    UP = 1
    DOWN = 2
    LEFT = 3
    RIGHT = 4

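# Baseline agent: ignores the game state and picks a uniformly random move.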
class RandomAgent:
    def __init__(self):
        """
        Initialize your internal state
        """
        pass

    def act(self):
        """
        Choose action depending on your internal state
        """
        return Action(np.random.randint(4)+1)


    def update(self, next_state, reward):
        """
        Update your internal state
        """
        pass

class qLearningAgent:
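    """
    Q-learning agent for a 4x4 2048-style board using an n-tuple network as
    the value approximator: each n-tuple pattern owns a lookup table that maps
    tile-exponent combinations to per-action weights. Updates use eligibility
    traces (TD(lambda)) and per-feature visit counts N for step-size scaling.
    """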
    def __init__(self, mat, TD_lambda=0.95, alpha=0.05, gamma=0.95, epsilon=0.0, verbose=True):
        self.state_per_tile = 12
        self.alpha = alpha
        self.gamma = gamma
        self.epsilon = epsilon  # e-greedy exploration rate
        # self.TD_lambda = 1 - epsilon  # alternative: tie TD(lambda) to epsilon
        self.TD_lambda = TD_lambda  # TD(lambda) trace-decay parameter
        self.tuple = self._tuple()
        if verbose:
            print("n-tuples:", len(self.tuple))
        self.W = self._generate_dict()  # weights (theta): one table per n-tuple
        self.N = self._generate_dict()  # per-feature visit counts
        if verbose:
            print("table entries:", sum(len(w.keys()) for w in self.W))
        self.feature_size = sum(self.state_per_tile ** len(k) for k in self.tuple)
        self.set_state(mat)
        if verbose:
            print("feature size:", self.feature_size)
        self.verbose = verbose
        self.reset()

    # [[(0,0),(1,0),(2,0),(3,0)],\
    #                 [(0,1),(1,1),(2,1),(3,1)],\
    #                 [(0,1),(1,1),(2,1),(0,2),(1,2),(2,2)],\
    #                 [(0,2),(1,2),(2,2),(0,3),(1,3),(2,3)]]

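    # The n-tuple patterns: the 4 rows, the 4 columns, and the nine
    # overlapping 2x2 squares of the board, as (row, column) coordinates.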
    def _tuple(self):
        tuples = []
        for i in range(4):
            tuples += [[(i, j) for j in range(4)]]  # rows
        for i in range(4):
            tuples += [[(j, i) for j in range(4)]]  # columns
        for i in range(3):
            for j in range(3):
                tuples += [[(i, j), (i, j + 1), (i + 1, j), (i + 1, j + 1)]]  # 2x2 squares
        return tuples
        
    def reset(self):
        self._reset_trace()  # eligibility traces
        self.count = 0
        self.first_step = True  # used to avoid updating on the first call
        return

    def _reset_trace(self):
        self.trace = []
        for t in self.tuple:
            self.trace += [dict()]
        return

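    # One lookup table per n-tuple: every combination of tile exponents
    # (0 .. state_per_tile - 1) maps to a vector of per-action values.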
    def _generate_dict(self):
        container = [] #weight(Theta)
        for t in self.tuple:
            d = dict()
            l = list(range(self.state_per_tile))
            for k in list(it.product(l,repeat=len(t))):
                d[k] = np.zeros(len(Action))
            container += [d]
        return container

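    # Feature keys of `state` under every n-tuple pattern.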
    def _index(self, state):
        return [self._calculate_index(state,t) for t in self.tuple]

    def _phi(self, state=None):
        # Q-value estimate for each action: sum the weight vectors of all
        # n-tuple table entries matching the cached (or given) state.
        if state is None:
            return np.sum(np.array([w[i] for w, i in zip(self.W, self.index)]), axis=0)
        else:
            return np.sum(np.array([w[i] for w, i in zip(self.W, self._index(state))]), axis=0)

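    # Key for a single n-tuple: the log2 exponent of each covered tile
    # (empty tiles count as 0), packed into a hashable tuple.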
    def _calculate_index(self, state, t):
        comb = []
        for r,l in t:
            if state[r][l] != 0:
                comb += [int(np.log2(state[r][l]))]
            else:
                comb += [0]
        return tuple(comb)

    # def _size(self, mat):
    #     return len(mat)*len(mat)
        
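    # Epsilon-greedy action selection. On exploitation the eligibility traces
    # keep decaying with TD_lambda; on exploration they are reset (forget = 0).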
    def act(self):
        i = np.random.rand()
        if i > self.epsilon:
            #e-greedy
            #exploitation
            self.forget = self.TD_lambda
            action_index = np.argmax(self._phi())
            # print(self._phi())
        else:
            # exploration
            self.forget = 0.0
            action_index = np.random.randint(0, len(Action))
        self._action_index = action_index
        return Action(action_index+1)

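    # Decay all traces by gamma * forget, prune entries whose values have all
    # dropped below 0.01, then increment the trace of the current
    # state-action features.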
    def _update_trace(self):
        if self.forget != 0.0:
            for d in self.trace:
                l = list(d.items())
                for k,v in l:
                    upd = v*self.forget*self.gamma
                    if np.all(upd < 0.01):
                        d.pop(k)
                    else:
                        d[k]=upd

        else:
            self._reset_trace()
        for tr, ind in zip(self.trace, self.index):
            v = tr.get(ind,np.zeros(len(Action)))
            v[self._action_index] += 1 
            tr[ind] = v
        return

    def _target(self, next_state, reward):
        # TD error of the Q-learning update:
        #   reward + gamma * max_a Q(next_state, a) - Q(state, chosen action)
        return reward + self.gamma * np.max(self._phi(next_state)) - self._phi()[self._action_index]

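    # Apply the TD(lambda) update: scale the TD error by alpha and add it to
    # every weight with a nonzero trace, normalised by the visit counts in N.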
    def update(self, next_state, reward):
        # print(next_state)
        if self.first_step:
            # no previous transition to learn from yet
            self.first_step = False
            return
        self._update_trace()
        target = self.alpha * self._target(next_state, reward)  # alpha-scaled TD error
        for w,tr,n in zip(self.W,self.trace,self.N):
            for k in tr.keys():
                n[k] += tr[k]
                index = np.where(n[k] != 0)  # entries never visited would divide by zero
                w[k][index] += target * tr[k][index] / n[k][index]
                # alternative: unnormalised update  w[k] += target * tr[k]
        if self.verbose:
            print("reward: " + str(reward))
            print("target: " + str(target))
            # table / trace sizes of the last n-tuple, as a rough progress indicator
            print("weight: " + str(len(w.keys())))
            print("trace: " + str(len(tr.keys())))
        self.count += 1
        return

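    # Cache the current board and its n-tuple feature keys so that act() and
    # update() can reuse them without recomputation.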
    def set_state(self, state):
        self.state = state
        # print(self.state)
        self.index = self._index(self.state)
        # assert len(self.phi) ==4,"wrong calculation of phi"
        # print(self.index)
        return

# class qLearningAgent2:
#     def __init__(self, mat, TD_lambda = 0.0, alpha = 0.5, gamma = 0.8, epsilon = 0.01):
#         self.state_per_tile = 10
#         self.alpha = alpha
#         self.gamma = gamma
#         self.epsilon = epsilon # e-greedy
#         self.TD_lambda = TD_lambda # TD(lamdba)
#         self.tuple = [[(0,0),(1,0),(2,0),(3,0)],\
#                     [(0,1),(1,1),(2,1),(3,1)],\
#                     [(0,1),(1,1),(2,1),(0,2),(1,2),(2,2)],\
#                     [(0,2),(1,2),(2,2),(0,3),(1,3),(2,3)]]
#         self.feature_size = sum([self.state_per_tile**len(k) for k in self.tuple])
#         self.W = np.zeros((self.feature_size,len(Action))) #weight(Theta)
#         self.set_state(mat)
#         print(self.feature_size)
#         self.reset()
        
#     def reset(self):
#         self.trace = np.zeros((self.feature_size,len(Action))) #eligibility trace
#         self.first_step = True# used to avoid update the first time
#         pass

#     def _index(self, state):
#         #value function
#         sum = 0
#         list_index = []
#         for t in self.tuple:
#             index = self._calculate_index(state,t)
#             # assert sum+index < self.feature_size, "bad calculation of feature index"
#             list_index += [sum+index]
#             sum += self.state_per_tile**len(t)
#         return list_index

#     def _phi(self, state = None):
#         #value function
#         if state == None:
#             return np.sum(self.W[self.index,:],axis=0)
#         else:
#             # print(self.W[self._index(state),:])
#             return np.sum(self.W[self._index(state),:],axis=0)

#     def _phi_gradient(self):
#         #value function
#         res = np.zeros(self.feature_size)
#         res[self.index] = 1
#         return res

#     def _calculate_index(self, state, tuple):
#         sum = 0
#         for r,l in tuple:
#             if state[r][l] != 0:
#                 sum += int(np.log2(state[r][l]))
#             sum *= self.state_per_tile
#         sum /= self.state_per_tile
#         return int(sum)

#     def _size(self, mat):
#         return len(mat)*len(mat)
        
#     def act(self):
#         i = np.random.rand();
#         if i > self.epsilon:
#             #e-greedy
#             #exploitation
#             self.forget = self.TD_lambda
#             action_index = np.argmax(self._phi())
#             # print(self._phi())
#         else:
#             # exploration
#             self.forget = 0.0
#             action_index = np.random.randint(0, len(Action))
#         self._action_index = action_index
#         return Action(action_index+1)

#     def _update_trace(self):
#         self.trace *= self.forget*self.gamma
#         self.trace[:,self._action_index] += self._phi_gradient()
#         # print(np.sum(self.trace,axis=1))
#         pass

#     def _target(self,next_state,reward):
#         #q learning target function
#         return reward + self.gamma * np.max(self._phi(next_state))

#     def update(self, next_state, reward):
#         # print(next_state)
#         if self.first_step == True:
#             #don't update the first time
#             self.first_step = False
#             pass
#         self._update_trace()
#         self.W += self.alpha * (self._target(next_state,reward) \
#                         - self._phi()[self._action_index])\
#                         * self.trace
#         # print(self._target(next_state,reward) \
#         #                 - self._phi()[self._action_index])
#             # #game stops, reset the agent
#             # self._reset()
#         pass

#     def set_state(self, state):
#         self.state = state
#         # print(self.state)
#         self.index = self._index(self.state)
#         # assert len(self.phi) ==4,"wrong calculation of phi"
#         # print(self.index)
#         pass
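
# Minimal usage sketch (illustrative only): the game environment is not part of
# this file, so the board below is a hand-made 4x4 matrix and the reward is a
# placeholder zero. A real loop would let the game engine apply the chosen
# action and supply the successor board and the merge reward.
if __name__ == "__main__":
    board = [[2, 0, 0, 0],
             [0, 4, 0, 0],
             [0, 0, 2, 0],
             [0, 0, 0, 0]]
    agent = qLearningAgent(board, verbose=False)  # allocates ~3.5e5 entries per table set
    for _ in range(3):
        agent.set_state(board)
        action = agent.act()      # one of Action.UP / DOWN / LEFT / RIGHT
        agent.update(board, 0)    # placeholder next_state and reward
        print(action)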