# main.py (2.18 KB)
import pickle
import random

from DNN.model import ResidualCNN
from agent import Agent
import config
import os


def playMatches(web, EPISODES, turns_until_tau0, memory=None):
    """Play ``EPISODES`` games with ``web`` and optionally record them.

    Each episode resets ``web``, then repeatedly asks it for an action via
    ``web.act`` and applies that action via ``web.step`` until ``done`` is no
    longer 0. When ``memory`` is given, every ``(state, pi)`` pair is
    committed to short-term memory before the move is played; when a game
    finishes with ``done == 1`` each short-term entry receives the final
    ``value`` and the game is committed to long-term memory.

    Args:
        web: Object exposing ``reset()``, ``act(state, x)``, ``step(action)``
            and a writable ``mcts`` attribute.
        EPISODES: Number of games to play.
        turns_until_tau0: Accepted for interface compatibility but not used;
            ``web.act`` is always called with ``1`` as its second argument.
            NOTE(review): presumably this was meant to switch that argument
            after a number of turns -- confirm the intended behaviour.
        memory: Optional buffer exposing ``commit_stmemory(state, pi)``,
            ``commit_ltmemory()`` and an iterable ``stmemory`` of mappings.
            When ``None`` the games are played without being recorded.

    Returns:
        The ``memory`` object that was passed in (``None`` if none was given).
    """
    for _ in range(EPISODES):
        state = web.reset()

        done = 0
        # Clear the ``mcts`` attribute before every game.
        web.mcts = None

        while done == 0:
            # Pick an action; the value returned by ``act`` is not used
            # because ``step`` supplies the value that gets recorded.
            action, pi, _ = web.act(state, 1)

            if memory is not None:
                # Commit the move to short-term memory.
                memory.commit_stmemory(state, pi)

            # Do the action.
            state, value, done, _ = web.step(action)

            if done == 1:
                print('done')
                if memory is not None:
                    # The game is finished: assign its final value to every
                    # recorded move, then move the game to long-term memory.
                    for move in memory.stmemory:
                        move['value'] = value
                    memory.commit_ltmemory()

    # Return only after *all* episodes have been played.
    return memory


# ---- Self-play / retraining driver ----------------------------------------
# Builds the network and the agent, then loops forever: play games, and once
# long-term memory holds at least config.MEMORY_SIZE entries, retrain the
# agent on it and periodically snapshot the memory to disk.
iteration = 0
cnn = ResidualCNN(config.REG_CONST, config.LEARNING_RATE, config.INPUT_SHAPE, config.NUM_ACTIONS,
                  config.HIDDEN_CNN_LAYERS)
current_player = Agent('current_player', 20, config.NUM_ACTIONS, config.MCTS_SIMS, config.CPUCT, cnn)

# NOTE(review): ``memory`` starts as None and nothing in this file constructs
# a memory buffer, so ``memory.clear_stmemory()`` below will raise
# AttributeError on the first iteration -- a real buffer object (providing
# ``stmemory``, ``ltmemory``, ``commit_*`` and ``clear_stmemory``) needs to be
# created here.
memory = None

run_folder = './run/'
memory_folder = run_folder + 'memory/'
# Snapshots are written below ``memory_folder``; create it (and ``run_folder``)
# up front so the first pickle dump does not fail on a missing directory.
os.makedirs(memory_folder, exist_ok=True)

while True:
    iteration += 1
    print('ITERATION NUMBER ' + str(iteration))
    memory = playMatches(current_player, config.EPISODES, turns_until_tau0=config.TURNS_UNTIL_TAU0, memory=memory)
    memory.clear_stmemory()

    if len(memory.ltmemory) >= config.MEMORY_SIZE:

        ######## RETRAINING ########
        print('RETRAINING...')
        current_player.replay(memory.ltmemory)
        print('')

        if iteration % 5 == 0:
            # Snapshot the memory every fifth iteration; ``with`` guarantees
            # the file handle is flushed and closed.
            snapshot_path = memory_folder + "memory" + str(iteration).zfill(4) + ".p"
            with open(snapshot_path, "wb") as snapshot_file:
                pickle.dump(memory, snapshot_file)

        # Random sample of up to 1000 long-term entries; not read anywhere
        # else in this file.
        memory_samp = random.sample(memory.ltmemory, min(1000, len(memory.ltmemory)))

    else:
        print('MEMORY SIZE: ' + str(len(memory.ltmemory)))