Skip to content
Snippets Groups Projects
Commit 4ad995b4 authored by Wen Yao Jin's avatar Wen Yao Jin
Browse files

Merge branch 'master' of gitlab.my.ecp.fr:2014jinwy/2048

parents 1fc6e468 8bc74b99
Branches
No related tags found
No related merge requests found
No preview for this file type
...@@ -11,21 +11,23 @@ class Action(IntEnum): ...@@ -11,21 +11,23 @@ class Action(IntEnum):
RIGHT = 4 RIGHT = 4
class afterstateAgent: class afterstateAgent:
def __init__(self, mat, TD_lambda = 0.0, alpha = 0.0025, gamma = 0.95, epsilon = 0.01, verbose= True, symmetric=True): def __init__(self, mat, TD_lambda = 0.0, alpha = 0.0025, gamma = 0.95, epsilon = 0.01, verbose= True, symmetric=1, tuple = 2):
self.state_per_tile = 12 self.state_per_tile = 12
self.commands = { Action.UP: up, Action.DOWN: down, Action.LEFT: left, Action.RIGHT: right} self.commands = { Action.UP: up, Action.DOWN: down, Action.LEFT: left, Action.RIGHT: right}
self.alpha = alpha self.alpha = alpha
# self.gamma = gamma # self.gamma = gamma
self.epsilon_origin = epsilon
self.epsilon = epsilon # e-greedy self.epsilon = epsilon # e-greedy
# self.TD_lambda = 1-epsilon # TD(lamdba) # self.TD_lambda = 1-epsilon # TD(lamdba)
self.TD_lambda = TD_lambda self.TD_lambda = TD_lambda
self.forget = self.TD_lambda self.forget = self.TD_lambda
self.symmetric = symmetric self.symmetric = symmetric
if self.symmetric: if tuple==0:
# self.tuple = self._tuple_advance() self.tuple = self._tuple()
elif tuple == 1:
self.tuple = self._tuple_advance() self.tuple = self._tuple_advance()
else: else:
self.tuple = self._tuple() self.tuple = self._tuple_advance_plus()
if verbose: if verbose:
print(len(self.tuple)) print(len(self.tuple))
self.W = self._generate_dict() self.W = self._generate_dict()
...@@ -58,7 +60,14 @@ class afterstateAgent: ...@@ -58,7 +60,14 @@ class afterstateAgent:
print(list) print(list)
return list return list
def _tuple_advance(self): def _tuple_advance(self):
return [[(0,0),(1,0),(2,0),(3,0)],\
[(0,1),(1,1),(2,1),(3,1)],\
[(0,1),(1,1),(2,1),(0,2),(1,2),(2,2)],\
[(0,2),(1,2),(2,2),(0,3),(1,3),(2,3)]]
def _tuple_advance_plus(self):
return [[(0,0),(1,0),(2,0),(3,0),(3,1),(2,1)],\ return [[(0,0),(1,0),(2,0),(3,0),(3,1),(2,1)],\
[(0,1),(1,1),(2,1),(3,1),(3,2),(2,2)],\ [(0,1),(1,1),(2,1),(3,1),(3,2),(2,2)],\
[(0,1),(1,1),(2,1),(0,2),(1,2),(2,2)],\ [(0,1),(1,1),(2,1),(0,2),(1,2),(2,2)],\
...@@ -69,6 +78,7 @@ class afterstateAgent: ...@@ -69,6 +78,7 @@ class afterstateAgent:
self.count = 0 self.count = 0
self.first_step = True# used to avoid update the first time self.first_step = True# used to avoid update the first time
self.explore = 0 self.explore = 0
self.epsilon -= self.epsilon_origin/2000
return return
def _reset_trace(self): def _reset_trace(self):
...@@ -160,6 +170,7 @@ class afterstateAgent: ...@@ -160,6 +170,7 @@ class afterstateAgent:
i = np.random.rand(); i = np.random.rand();
if i < self.epsilon: #explore if i < self.epsilon: #explore
self.explore += 1 self.explore += 1
self.forget = 0.0
return sum(phi_array) + 10000 return sum(phi_array) + 10000
return sum(phi_array) + done return sum(phi_array) + done
...@@ -189,19 +200,17 @@ class afterstateAgent: ...@@ -189,19 +200,17 @@ class afterstateAgent:
return return
s,r = self.test_next(self._action_index,self.state) s,r = self.test_next(self._action_index,self.state)
n = next_state n = next_state
if self.symmetric is True: if self.symmetric>0:
for i in range(4): for i in range(4):
s = transpose(s) s = transpose(s)
self.set_state(s)
n = transpose(n) n = transpose(n)
self.one_side_update(n,reward,s) self.one_side_update(n,reward,s)
s = reverse(s) s = reverse(s)
self.set_state(s)
n = reverse(n) n = reverse(n)
self.one_side_update(n,reward,s) self.one_side_update(n,reward,s)
#one loop is one rotation #one loop is one rotation
else: else:
one_side_update(next_state,reward,s) self.one_side_update(next_state,reward,s)
assert s==self.state, str(s)+str(self.state) assert s==self.state, str(s)+str(self.state)
return return
......
...@@ -177,6 +177,7 @@ if __name__ == '__main__': ...@@ -177,6 +177,7 @@ if __name__ == '__main__':
parser.add_option("-t", "--train", dest="train", help ="training episodes") parser.add_option("-t", "--train", dest="train", help ="training episodes")
parser.add_option("-s", "--symmetric", dest="symmetric", help ="symmetric sampling") parser.add_option("-s", "--symmetric", dest="symmetric", help ="symmetric sampling")
parser.add_option("-e", "--epsilon", dest="epsilon", help ="epsilon the exploration") parser.add_option("-e", "--epsilon", dest="epsilon", help ="epsilon the exploration")
parser.add_option("-u", "--tuple", dest="tuple", help ="the tuple to use")
(options,args)= parser.parse_args() (options,args)= parser.parse_args()
print(vars(options)) print(vars(options))
start_time = time.time() start_time = time.time()
......
import matplotlib
matplotlib.use("TkAgg")
import matplotlib.pyplot as plt
from tkinter import * from tkinter import *
from logic import * from logic import *
from random import * from random import *
from agent import * from agent import *
from agent_afterstate import * from agent_afterstate import *
import numpy as np import numpy as np
import pickle import pickle
import time import time
import sys
from optparse import OptionParser
import os
TRAIN = 100000 TRAIN = 2000
SIZE = 500 SIZE = 500
GRID_LEN = 4 GRID_LEN = 4
GRID_PADDING = 10 GRID_PADDING = 10
...@@ -25,8 +31,18 @@ CELL_COLOR_DICT = { 2:"#776e65", 4:"#776e65", 8:"#f9f6f2", 16:"#f9f6f2", \ ...@@ -25,8 +31,18 @@ CELL_COLOR_DICT = { 2:"#776e65", 4:"#776e65", 8:"#f9f6f2", 16:"#f9f6f2", \
FONT = ("Verdana", 40, "bold") FONT = ("Verdana", 40, "bold")
class GameGrid(Frame): class GameGrid(Frame):
def __init__(self): def __init__(self,args=None):
for k in list(args.keys()):
if args[k] == None:
args.pop(k)
else :
args[k] = float(args[k])
if "train" in args.keys():
self.train = args["train"]
args.pop("train")
else:
self.train = TRAIN
self.DISPLAY = True self.DISPLAY = True
if self.DISPLAY: if self.DISPLAY:
Frame.__init__(self) Frame.__init__(self)
...@@ -40,16 +56,16 @@ class GameGrid(Frame): ...@@ -40,16 +56,16 @@ class GameGrid(Frame):
self.reset() self.reset()
self.history = [] self.history = []
self.count = 0 self.count = 0
self.agent = afterstateAgent(self.matrix) # self.agent = RandomAgent()
f = open("train_0.0025_0.0_result_after_2000.txt",'rb') self.agent = afterstateAgent(self.matrix,**args)
f = open("train_0.0025_0.5_0.0_result_after_2000.txt",'rb')
self.agent.W = pickle.load(f) self.agent.W = pickle.load(f)
f.close()
print(self.agent.W[0])
if self.DISPLAY: if self.DISPLAY:
self.key_down() self.key_down()
self.mainloop() self.mainloop()
else: else:
while self.count<=TRAIN: while self.count<=self.train:
self.key_down() self.key_down()
def reset(self): def reset(self):
...@@ -93,6 +109,10 @@ class GameGrid(Frame): ...@@ -93,6 +109,10 @@ class GameGrid(Frame):
def key_down(self): def key_down(self):
if self.count>=1:
self.agent.verbose = False
if self.agent.count >10000:
self.agent.verbose = True
self.agent.set_state(self.matrix) self.agent.set_state(self.matrix)
key = self.agent.act() key = self.agent.act()
self.matrix,done = self.commands[key](self.matrix) self.matrix,done = self.commands[key](self.matrix)
...@@ -102,27 +122,50 @@ class GameGrid(Frame): ...@@ -102,27 +122,50 @@ class GameGrid(Frame):
if self.DISPLAY: if self.DISPLAY:
self.update_grid_cells() self.update_grid_cells()
if done!=1: if done!=1:
reward = done reward += done
# print(reward)
# else: # else:
# reward = -10 # reward = -0.5
if game_state(self.matrix)=='win': if game_state(self.matrix)=='win':
print("win") print("win")
# self.grid_cells[1][1].configure(text="You",bg=BACKGROUND_COLOR_CELL_EMPTY)
# self.grid_cells[1][2].configure(text="Win!",bg=BACKGROUND_COLOR_CELL_EMPTY)
if game_state(self.matrix)=='lose': if game_state(self.matrix)=='lose':
print(np.max(self.matrix)) if self.agent.explore>0:
print("explore: "+ str(self.agent.explore))
# reward = -10
# reward = np.log(np.max(self.matrix))
# self.grid_cells[1][1].configure(text="You",bg=BACKGROUND_COLOR_CELL_EMPTY)
# self.grid_cells[1][2].configure(text="Lose!",bg=BACKGROUND_COLOR_CELL_EMPTY)
print(str(self.count) + " : " + str(np.max(self.matrix)))
# self.agent.update(self.matrix, reward) # self.agent.update(self.matrix, reward)
if (game_state(self.matrix)=='win' ) or (game_state(self.matrix)=='lose'): if (game_state(self.matrix)=='win' ) or (game_state(self.matrix)=='lose'):
# print(self.agent.W) # print(self.agent.W)
if (self.count == self.train):
f = open("train_" +str(self.agent.alpha) +"_"+str(self.agent.TD_lambda)+"_"+str(self.agent.symmetric)+"_result_after_"+str(self.count)+".txt",'wb')
pickle.dump(self.agent.W ,f)
f.close()
f = open("train_" +str(self.agent.alpha) +"_"+str(self.agent.TD_lambda)+"_"+str(self.agent.symmetric)+"_history_after_"+str(self.count)+".txt",'wb')
np.savetxt(f, self.history)
f.close()
self.history += [np.max(self.matrix)] self.history += [np.max(self.matrix)]
self.agent.reset() self.agent.reset()
self.count += 1 self.count += 1
self.reset() self.reset()
# plt.plot(self.history)
# plt.show()
# print(reward)
# self.matrix
if (self.DISPLAY): if (self.DISPLAY):
# Tell Tkinter to wait DELTA_TIME seconds before next iteration # Tell Tkinter to wait DELTA_TIME seconds before next iteration
self.after(100, self.key_down) self.after(50, self.key_down)
def generate_next(self): def generate_next(self):
index = (self.gen(), self.gen()) index = (self.gen(), self.gen())
...@@ -130,6 +173,21 @@ class GameGrid(Frame): ...@@ -130,6 +173,21 @@ class GameGrid(Frame):
index = (self.gen(), self.gen()) index = (self.gen(), self.gen())
self.matrix[index[0]][index[1]] = 2 self.matrix[index[0]][index[1]] = 2
if __name__ == '__main__':
parser = OptionParser()
parser.add_option("-g", "--TD", dest="TD_lambda", help ="TD_lambda the forget coefficient")
parser.add_option("-a", "--alpha", dest="alpha", help ="alpha the learning rate")
parser.add_option("-t", "--train", dest="train", help ="training episodes")
parser.add_option("-s", "--symmetric", dest="symmetric", help ="symmetric sampling")
parser.add_option("-e", "--epsilon", dest="epsilon", help ="epsilon the exploration")
parser.add_option("-u", "--tuple", dest="tuple", help ="the tuple to use")
(options,args)= parser.parse_args()
print(vars(options))
f = open("train_0.0025_0.5_0.0_history_after_2000.txt",'rb')
history = np.loadtxt(f)
f.close()
plt.plot(history)
plt.show()
start_time = time.time() start_time = time.time()
gamegrid = GameGrid() gamegrid = GameGrid(vars(options))
print("--- %s seconds ---" % (time.time() - start_time)) print("--- %s seconds ---" % (time.time() - start_time))
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment