diff --git a/__pycache__/agent_afterstate.cpython-35.pyc b/__pycache__/agent_afterstate.cpython-35.pyc index 868c7462c3ecc9b239e39999006428952923ed71..c98f07408ce6dd5b46f6601976789ee53da6ad7a 100644 Binary files a/__pycache__/agent_afterstate.cpython-35.pyc and b/__pycache__/agent_afterstate.cpython-35.pyc differ diff --git a/agent_afterstate.py b/agent_afterstate.py index 642b33c8ecfa25974cb1ea7d29b658f8fa2eeef5..0e98b4083d1049ebd9de9f0360d237afd94e7285 100644 --- a/agent_afterstate.py +++ b/agent_afterstate.py @@ -16,6 +16,7 @@ class afterstateAgent: self.commands = { Action.UP: up, Action.DOWN: down, Action.LEFT: left, Action.RIGHT: right} self.alpha = alpha # self.gamma = gamma + self.epsilon_origin = epsilon self.epsilon = epsilon # e-greedy # self.TD_lambda = 1-epsilon # TD(lamdba) self.TD_lambda = TD_lambda @@ -77,6 +78,7 @@ class afterstateAgent: self.count = 0 self.first_step = True# used to avoid update the first time self.explore = 0 + self.epsilon -= self.epsilon_origin/2000 return def _reset_trace(self): @@ -168,6 +170,7 @@ class afterstateAgent: i = np.random.rand(); if i < self.epsilon: #explore self.explore += 1 + self.forget = 0.0 return sum(phi_array) + 10000 return sum(phi_array) + done @@ -200,11 +203,9 @@ class afterstateAgent: if self.symmetric>0: for i in range(4): s = transpose(s) - self.set_state(s) n = transpose(n) self.one_side_update(n,reward,s) s = reverse(s) - self.set_state(s) n = reverse(n) self.one_side_update(n,reward,s) #one loop is one rotation diff --git a/test.py b/test.py index e4dc0fa87de3c2a3d7572230dcd4de4f2cb35fea..ee393dfe1b8cdee419614f2e61a6eb6ea73238a1 100644 --- a/test.py +++ b/test.py @@ -1,15 +1,21 @@ +import matplotlib +matplotlib.use("TkAgg") +import matplotlib.pyplot as plt + from tkinter import * from logic import * from random import * from agent import * from agent_afterstate import * - import numpy as np import pickle import time +import sys +from optparse import OptionParser +import os -TRAIN = 100000 +TRAIN = 2000 SIZE = 500 GRID_LEN = 4 GRID_PADDING = 10 @@ -25,8 +31,18 @@ CELL_COLOR_DICT = { 2:"#776e65", 4:"#776e65", 8:"#f9f6f2", 16:"#f9f6f2", \ FONT = ("Verdana", 40, "bold") class GameGrid(Frame): - def __init__(self): + def __init__(self,args=None): + for k in list(args.keys()): + if args[k] == None: + args.pop(k) + else : + args[k] = float(args[k]) + if "train" in args.keys(): + self.train = args["train"] + args.pop("train") + else: + self.train = TRAIN self.DISPLAY = True if self.DISPLAY: Frame.__init__(self) @@ -40,16 +56,16 @@ class GameGrid(Frame): self.reset() self.history = [] self.count = 0 - self.agent = afterstateAgent(self.matrix) - f = open("train_0.0025_0.0_result_after_2000.txt",'rb') + # self.agent = RandomAgent() + self.agent = afterstateAgent(self.matrix,**args) + f = open("train_0.0025_0.5_0.0_result_after_2000.txt",'rb') self.agent.W = pickle.load(f) - f.close() - print(self.agent.W[0]) + if self.DISPLAY: self.key_down() self.mainloop() else: - while self.count<=TRAIN: + while self.count<=self.train: self.key_down() def reset(self): @@ -93,6 +109,10 @@ class GameGrid(Frame): def key_down(self): + if self.count>=1: + self.agent.verbose = False + if self.agent.count >10000: + self.agent.verbose = True self.agent.set_state(self.matrix) key = self.agent.act() self.matrix,done = self.commands[key](self.matrix) @@ -102,27 +122,50 @@ class GameGrid(Frame): if self.DISPLAY: self.update_grid_cells() if done!=1: - reward = done + reward += done + # print(reward) # else: - # reward = -10 + # reward = -0.5 + + if game_state(self.matrix)=='win': print("win") + # self.grid_cells[1][1].configure(text="You",bg=BACKGROUND_COLOR_CELL_EMPTY) + # self.grid_cells[1][2].configure(text="Win!",bg=BACKGROUND_COLOR_CELL_EMPTY) if game_state(self.matrix)=='lose': - print(np.max(self.matrix)) + if self.agent.explore>0: + print("explore: "+ str(self.agent.explore)) + # reward = -10 + # reward = np.log(np.max(self.matrix)) + # self.grid_cells[1][1].configure(text="You",bg=BACKGROUND_COLOR_CELL_EMPTY) + # self.grid_cells[1][2].configure(text="Lose!",bg=BACKGROUND_COLOR_CELL_EMPTY) + print(str(self.count) + " : " + str(np.max(self.matrix))) # self.agent.update(self.matrix, reward) if (game_state(self.matrix)=='win' ) or (game_state(self.matrix)=='lose'): # print(self.agent.W) + if (self.count == self.train): + f = open("train_" +str(self.agent.alpha) +"_"+str(self.agent.TD_lambda)+"_"+str(self.agent.symmetric)+"_result_after_"+str(self.count)+".txt",'wb') + pickle.dump(self.agent.W ,f) + f.close() + f = open("train_" +str(self.agent.alpha) +"_"+str(self.agent.TD_lambda)+"_"+str(self.agent.symmetric)+"_history_after_"+str(self.count)+".txt",'wb') + np.savetxt(f, self.history) + f.close() self.history += [np.max(self.matrix)] self.agent.reset() self.count += 1 self.reset() + # plt.plot(self.history) + # plt.show() + # print(reward) + + # self.matrix if (self.DISPLAY): # Tell Tkinter to wait DELTA_TIME seconds before next iteration - self.after(100, self.key_down) + self.after(50, self.key_down) def generate_next(self): index = (self.gen(), self.gen()) @@ -130,6 +173,21 @@ class GameGrid(Frame): index = (self.gen(), self.gen()) self.matrix[index[0]][index[1]] = 2 -start_time = time.time() -gamegrid = GameGrid() -print("--- %s seconds ---" % (time.time() - start_time)) +if __name__ == '__main__': + parser = OptionParser() + parser.add_option("-g", "--TD", dest="TD_lambda", help ="TD_lambda the forget coefficient") + parser.add_option("-a", "--alpha", dest="alpha", help ="alpha the learning rate") + parser.add_option("-t", "--train", dest="train", help ="training episodes") + parser.add_option("-s", "--symmetric", dest="symmetric", help ="symmetric sampling") + parser.add_option("-e", "--epsilon", dest="epsilon", help ="epsilon the exploration") + parser.add_option("-u", "--tuple", dest="tuple", help ="the tuple to use") + (options,args)= parser.parse_args() + print(vars(options)) + f = open("train_0.0025_0.5_0.0_history_after_2000.txt",'rb') + history = np.loadtxt(f) + f.close() + plt.plot(history) + plt.show() + start_time = time.time() + gamegrid = GameGrid(vars(options)) + print("--- %s seconds ---" % (time.time() - start_time)) diff --git a/train_0.0025_0.0_True_result_after_2000.txt b/train_0.0025_0.0_True_result_after_2000.txt deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000