main.py
2.18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import pickle
import random
from DNN.model import ResidualCNN
from agent import Agent
import config
import os
def playMatches(web, EPISODES, turns_until_tau0, memory=None):
    """Play ``EPISODES`` games with agent ``web``, optionally recording them.

    Per episode: ``web.reset()`` gives the initial state and ``web.mcts`` is
    reset to ``None``; then ``web.act`` picks a move and ``web.step`` applies
    it, until ``step`` reports that the game is done.

    Args:
        web: agent/environment wrapper providing ``reset()``,
            ``act(state, tau)`` -> ``(action, pi, value)``,
            ``step(action)`` -> ``(state, value, done, info)`` and a writable
            ``mcts`` attribute.
        EPISODES: number of games to play.
        turns_until_tau0: moves are numbered from 1; while the move number is
            below this value ``act`` receives ``1`` as its second argument,
            afterwards ``0``. NOTE(review): assumes that argument is the MCTS
            temperature tau, as the parameter name implies — confirm.
        memory: optional replay memory. When given, each ``(state, pi)`` pair
            is passed to ``commit_stmemory``; when a game ends, every move in
            ``memory.stmemory`` gets the final ``value`` and
            ``commit_ltmemory()`` is called.

    Returns:
        The ``memory`` argument itself (``None`` if it was omitted).
    """
    for _ in range(EPISODES):
        state = web.reset()
        # Reset the agent's `mcts` attribute at the start of every game.
        web.mcts = None
        done = 0
        turn = 0
        while not done:
            turn += 1
            # Opening moves use tau=1, later moves tau=0. Previously the
            # turns_until_tau0 argument was ignored and 1 was always used.
            tau = 1 if turn < turns_until_tau0 else 0
            action, pi, _ = web.act(state, tau)
            if memory is not None:
                # Store (state, pi) in short-term memory.
                memory.commit_stmemory(state, pi)
            # Only the value reported by the final (done) step is kept.
            state, value, done, _ = web.step(action)
        if memory is not None:
            # Game finished: assign the final value to every recorded move.
            for move in memory.stmemory:
                move['value'] = value
            print('done')
            memory.commit_ltmemory()
    return memory
# Training driver: alternate self-play (playMatches) and retraining
# (current_player.replay) forever, pickling the replay memory every 5th
# iteration once the long-term memory has reached config.MEMORY_SIZE.
iteration = 0
cnn = ResidualCNN(config.REG_CONST, config.LEARNING_RATE, config.INPUT_SHAPE, config.NUM_ACTIONS,
                  config.HIDDEN_CNN_LAYERS)
# NOTE(review): the positional literal 20 is passed straight to Agent; its
# meaning is defined by Agent's signature, which is not visible here — confirm.
current_player = Agent('current_player', 20, config.NUM_ACTIONS, config.MCTS_SIMS, config.CPUCT, cnn)

# Replay memory shared across iterations. playMatches() returns the object it
# is given, and the loop below calls memory.clear_stmemory() and reads
# memory.ltmemory, so it must be a real memory object, not None.
# TODO: instantiate the project's replay-memory class here (it is not
# imported in this file).
memory = None

run_folder = './run/'
# Checkpoints are written to run_folder/memory/, so that subfolder must exist
# too; makedirs(exist_ok=True) also creates run_folder itself if missing.
memory_folder = os.path.join(run_folder, 'memory')
os.makedirs(memory_folder, exist_ok=True)

if memory is None:
    # Without a memory object the first memory.clear_stmemory() call raises
    # AttributeError, but only after a full round of self-play; fail before
    # spending that compute.
    raise RuntimeError('memory must be a replay-memory object before training starts')

while True:
    iteration += 1
    print('ITERATION NUMBER ' + str(iteration))
    # Self-play: plays config.EPISODES games, recording them into memory.
    memory = playMatches(current_player, config.EPISODES, turns_until_tau0=config.TURNS_UNTIL_TAU0, memory=memory)
    memory.clear_stmemory()
    if len(memory.ltmemory) >= config.MEMORY_SIZE:
        ######## RETRAINING ########
        print('RETRAINING...')
        current_player.replay(memory.ltmemory)
        print('')
        if iteration % 5 == 0:
            checkpoint_path = os.path.join(memory_folder, 'memory' + str(iteration).zfill(4) + '.p')
            # `with` guarantees the checkpoint file is flushed and closed.
            with open(checkpoint_path, 'wb') as checkpoint_file:
                pickle.dump(memory, checkpoint_file)
        # NOTE(review): memory_samp is not used anywhere in this file.
        memory_samp = random.sample(memory.ltmemory, min(1000, len(memory.ltmemory)))
    else:
        print('MEMORY SIZE: ' + str(len(memory.ltmemory)))