Skip to content
Snippets Groups Projects
Commit 31f47717 authored by Wen Yao Jin's avatar Wen Yao Jin
Browse files

go

parents a0e59046 8268ec83
No related branches found
No related tags found
No related merge requests found
No preview for this file type
No preview for this file type
No preview for this file type
...@@ -11,11 +11,12 @@ class Action(IntEnum): ...@@ -11,11 +11,12 @@ class Action(IntEnum):
RIGHT = 4 RIGHT = 4
class afterstateAgent: class afterstateAgent:
def __init__(self, mat, TD_lambda = 0.0, alpha = 0.0025, gamma = 0.95, epsilon = 0.01, verbose= True, symmetric=True, tuple = 2): def __init__(self, mat, TD_lambda = 0.0, alpha = 0.0025, gamma = 0.95, epsilon = 0.01, verbose= True, symmetric=1, tuple = 2):
self.state_per_tile = 12 self.state_per_tile = 12
self.commands = { Action.UP: up, Action.DOWN: down, Action.LEFT: left, Action.RIGHT: right} self.commands = { Action.UP: up, Action.DOWN: down, Action.LEFT: left, Action.RIGHT: right}
self.alpha = alpha self.alpha = alpha
# self.gamma = gamma # self.gamma = gamma
self.epsilon_origin = epsilon
self.epsilon = epsilon # e-greedy self.epsilon = epsilon # e-greedy
# self.TD_lambda = 1-epsilon # TD(lamdba) # self.TD_lambda = 1-epsilon # TD(lamdba)
self.TD_lambda = TD_lambda self.TD_lambda = TD_lambda
...@@ -77,6 +78,7 @@ class afterstateAgent: ...@@ -77,6 +78,7 @@ class afterstateAgent:
self.count = 0 self.count = 0
self.first_step = True# used to avoid update the first time self.first_step = True# used to avoid update the first time
self.explore = 0 self.explore = 0
self.epsilon -= self.epsilon_origin/2000
return return
def _reset_trace(self): def _reset_trace(self):
...@@ -168,6 +170,7 @@ class afterstateAgent: ...@@ -168,6 +170,7 @@ class afterstateAgent:
i = np.random.rand(); i = np.random.rand();
if i < self.epsilon: #explore if i < self.epsilon: #explore
self.explore += 1 self.explore += 1
self.forget = 0.0
return sum(phi_array) + 10000 return sum(phi_array) + 10000
return sum(phi_array) + done return sum(phi_array) + done
...@@ -197,19 +200,17 @@ class afterstateAgent: ...@@ -197,19 +200,17 @@ class afterstateAgent:
return return
s,r = self.test_next(self._action_index,self.state) s,r = self.test_next(self._action_index,self.state)
n = next_state n = next_state
if self.symmetric is True: if self.symmetric>0:
for i in range(4): for i in range(4):
s = transpose(s) s = transpose(s)
self.set_state(s)
n = transpose(n) n = transpose(n)
self.one_side_update(n,reward,s) self.one_side_update(n,reward,s)
s = reverse(s) s = reverse(s)
self.set_state(s)
n = reverse(n) n = reverse(n)
self.one_side_update(n,reward,s) self.one_side_update(n,reward,s)
#one loop is one rotation #one loop is one rotation
else: else:
one_side_update(next_state,reward,s) self.one_side_update(next_state,reward,s)
assert s==self.state, str(s)+str(self.state) assert s==self.state, str(s)+str(self.state)
return return
......
import matplotlib
matplotlib.use("TkAgg")
import matplotlib.pyplot as plt
from tkinter import * from tkinter import *
from logic import * from logic import *
from random import * from random import *
from agent import * from agent import *
from agent_afterstate import * from agent_afterstate import *
import numpy as np import numpy as np
import pickle import pickle
import time import time
import sys
from optparse import OptionParser
import os
TRAIN = 100000 TRAIN = 2000
SIZE = 500 SIZE = 500
GRID_LEN = 4 GRID_LEN = 4
GRID_PADDING = 10 GRID_PADDING = 10
...@@ -25,8 +31,18 @@ CELL_COLOR_DICT = { 2:"#776e65", 4:"#776e65", 8:"#f9f6f2", 16:"#f9f6f2", \ ...@@ -25,8 +31,18 @@ CELL_COLOR_DICT = { 2:"#776e65", 4:"#776e65", 8:"#f9f6f2", 16:"#f9f6f2", \
FONT = ("Verdana", 40, "bold") FONT = ("Verdana", 40, "bold")
class GameGrid(Frame): class GameGrid(Frame):
def __init__(self): def __init__(self,args=None):
for k in list(args.keys()):
if args[k] == None:
args.pop(k)
else :
args[k] = float(args[k])
if "train" in args.keys():
self.train = args["train"]
args.pop("train")
else:
self.train = TRAIN
self.DISPLAY = True self.DISPLAY = True
if self.DISPLAY: if self.DISPLAY:
Frame.__init__(self) Frame.__init__(self)
...@@ -40,16 +56,16 @@ class GameGrid(Frame): ...@@ -40,16 +56,16 @@ class GameGrid(Frame):
self.reset() self.reset()
self.history = [] self.history = []
self.count = 0 self.count = 0
self.agent = afterstateAgent(self.matrix) # self.agent = RandomAgent()
f = open("train_0.0025_0.0_result_after_2000.txt",'rb') self.agent = afterstateAgent(self.matrix,**args)
f = open("train_0.0025_0.5_0.0_result_after_2000.txt",'rb')
self.agent.W = pickle.load(f) self.agent.W = pickle.load(f)
f.close()
print(self.agent.W[0])
if self.DISPLAY: if self.DISPLAY:
self.key_down() self.key_down()
self.mainloop() self.mainloop()
else: else:
while self.count<=TRAIN: while self.count<=self.train:
self.key_down() self.key_down()
def reset(self): def reset(self):
...@@ -93,6 +109,10 @@ class GameGrid(Frame): ...@@ -93,6 +109,10 @@ class GameGrid(Frame):
def key_down(self): def key_down(self):
if self.count>=1:
self.agent.verbose = False
if self.agent.count >10000:
self.agent.verbose = True
self.agent.set_state(self.matrix) self.agent.set_state(self.matrix)
key = self.agent.act() key = self.agent.act()
self.matrix,done = self.commands[key](self.matrix) self.matrix,done = self.commands[key](self.matrix)
...@@ -102,27 +122,50 @@ class GameGrid(Frame): ...@@ -102,27 +122,50 @@ class GameGrid(Frame):
if self.DISPLAY: if self.DISPLAY:
self.update_grid_cells() self.update_grid_cells()
if done!=1: if done!=1:
reward = done reward += done
# print(reward)
# else: # else:
# reward = -10 # reward = -0.5
if game_state(self.matrix)=='win': if game_state(self.matrix)=='win':
print("win") print("win")
# self.grid_cells[1][1].configure(text="You",bg=BACKGROUND_COLOR_CELL_EMPTY)
# self.grid_cells[1][2].configure(text="Win!",bg=BACKGROUND_COLOR_CELL_EMPTY)
if game_state(self.matrix)=='lose': if game_state(self.matrix)=='lose':
print(np.max(self.matrix)) if self.agent.explore>0:
print("explore: "+ str(self.agent.explore))
# reward = -10
# reward = np.log(np.max(self.matrix))
# self.grid_cells[1][1].configure(text="You",bg=BACKGROUND_COLOR_CELL_EMPTY)
# self.grid_cells[1][2].configure(text="Lose!",bg=BACKGROUND_COLOR_CELL_EMPTY)
print(str(self.count) + " : " + str(np.max(self.matrix)))
# self.agent.update(self.matrix, reward) # self.agent.update(self.matrix, reward)
if (game_state(self.matrix)=='win' ) or (game_state(self.matrix)=='lose'): if (game_state(self.matrix)=='win' ) or (game_state(self.matrix)=='lose'):
# print(self.agent.W) # print(self.agent.W)
if (self.count == self.train):
f = open("train_" +str(self.agent.alpha) +"_"+str(self.agent.TD_lambda)+"_"+str(self.agent.symmetric)+"_result_after_"+str(self.count)+".txt",'wb')
pickle.dump(self.agent.W ,f)
f.close()
f = open("train_" +str(self.agent.alpha) +"_"+str(self.agent.TD_lambda)+"_"+str(self.agent.symmetric)+"_history_after_"+str(self.count)+".txt",'wb')
np.savetxt(f, self.history)
f.close()
self.history += [np.max(self.matrix)] self.history += [np.max(self.matrix)]
self.agent.reset() self.agent.reset()
self.count += 1 self.count += 1
self.reset() self.reset()
# plt.plot(self.history)
# plt.show()
# print(reward)
# self.matrix
if (self.DISPLAY): if (self.DISPLAY):
# Tell Tkinter to wait DELTA_TIME seconds before next iteration # Tell Tkinter to wait DELTA_TIME seconds before next iteration
self.after(100, self.key_down) self.after(50, self.key_down)
def generate_next(self): def generate_next(self):
index = (self.gen(), self.gen()) index = (self.gen(), self.gen())
...@@ -130,6 +173,21 @@ class GameGrid(Frame): ...@@ -130,6 +173,21 @@ class GameGrid(Frame):
index = (self.gen(), self.gen()) index = (self.gen(), self.gen())
self.matrix[index[0]][index[1]] = 2 self.matrix[index[0]][index[1]] = 2
if __name__ == '__main__':
parser = OptionParser()
parser.add_option("-g", "--TD", dest="TD_lambda", help ="TD_lambda the forget coefficient")
parser.add_option("-a", "--alpha", dest="alpha", help ="alpha the learning rate")
parser.add_option("-t", "--train", dest="train", help ="training episodes")
parser.add_option("-s", "--symmetric", dest="symmetric", help ="symmetric sampling")
parser.add_option("-e", "--epsilon", dest="epsilon", help ="epsilon the exploration")
parser.add_option("-u", "--tuple", dest="tuple", help ="the tuple to use")
(options,args)= parser.parse_args()
print(vars(options))
f = open("train_0.0025_0.5_0.0_history_after_2000.txt",'rb')
history = np.loadtxt(f)
f.close()
plt.plot(history)
plt.show()
start_time = time.time() start_time = time.time()
gamegrid = GameGrid() gamegrid = GameGrid(vars(options))
print("--- %s seconds ---" % (time.time() - start_time)) print("--- %s seconds ---" % (time.time() - start_time))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment