# MCTS.py 2.42 KB
import numpy as np
import logging
import config


class Node(object):
    """A vertex of the search tree that wraps a single state object.

    Attributes:
        state: The state object handed to the constructor.
        id: Copied from ``state.id``.
        edges: Outgoing edge entries; starts out as an empty list.
    """

    def __init__(self, state):
        """Wrap ``state`` and start with no outgoing edges."""
        self.state = state
        self.id = state.id
        self.edges = []

    def is_leaf(self):
        """Return ``True`` when the node has no outgoing edges."""
        return len(self.edges) == 0


class Edge(object):
    """A directed link from one node to another, labelled with an action.

    Attributes:
        id: ``in_node.state.id`` and ``out_node.state.id`` joined by ``'|'``.
        inNode: The node the edge leaves from.
        outNode: The node the edge leads to.
        action: The action associated with this edge.
        stats: Dict with keys ``'N'``, ``'W'``, ``'Q'`` (all starting at 0)
            and ``'P'`` (the ``prior`` passed to the constructor).
    """

    def __init__(self, in_node, out_node, prior, action):
        """Build the edge id and initialise the statistics."""
        id_prefix = in_node.state.id + '|'
        self.id = id_prefix + out_node.state.id

        self.inNode, self.outNode = in_node, out_node
        self.action = action

        self.stats = dict(N=0, W=0, Q=0, P=prior)


# N: number of visits
# W: wins (accumulated value)
# Q: win rate (W / N)
# P: prior

class MCTS(object):
    """Monte Carlo Tree Search over node/edge objects.

    Attributes:
        root: The node every descent starts from.
        tree: Dict of registered nodes keyed by ``node.id``.
        cpuct: Multiplier applied to the exploration term ``U`` when
            scoring edges in ``moveToLeaf``.
    """

    def __init__(self, root, cpuct):
        """Create a tree rooted at ``root`` and register the root node.

        Args:
            root: Root node of the search.
            cpuct: Exploration multiplier used in ``moveToLeaf``.
        """
        self.root = root
        self.tree = {}
        self.cpuct = cpuct
        self.add_node(root)

    def __len__(self):
        """Return the number of nodes registered in the tree."""
        return len(self.tree)

    def moveToLeaf(self):
        """Walk from the root to a leaf, following the best ``Q + U`` edge.

        At every non-leaf node each ``(action, edge)`` entry is scored as
        ``Q + U`` with::

            U = cpuct * ((1 - epsilon) * P + epsilon * nu) * sqrt(Nb) / (1 + N)

        where ``Nb`` is the sum of ``N`` over the node's edges. At the root,
        ``epsilon`` is ``config.EPSILON`` and ``nu`` is drawn from a Dirichlet
        distribution parameterised by ``config.ALPHA``; everywhere else both
        are zero, so only the prior ``P`` contributes. The first edge with
        the highest score is taken.

        Returns:
            tuple: ``(leaf, value, done, breadcrumbs)``. ``value`` and
            ``done`` come from the last ``state.takeAction`` call (both are
            0 when the root is already a leaf); ``breadcrumbs`` lists the
            traversed edges in root-to-leaf order.

        Raises:
            ValueError: If no edge of a node yields a comparable score
                (for example when every ``Q + U`` is NaN).
        """
        breadcrumbs = []
        current_node = self.root

        done = 0
        value = 0

        while not current_node.is_leaf():
            edges = current_node.edges

            if current_node == self.root:
                # Exploration noise is mixed into the priors at the root only.
                epsilon = config.EPSILON
                nu = np.random.dirichlet([config.ALPHA] * len(edges))
            else:
                epsilon = 0
                nu = [0] * len(edges)

            Nb = sum(edge.stats['N'] for _, edge in edges)
            # Loop-invariant; when Nb is 0 every U is 0 and Q alone decides.
            sqrt_nb = np.sqrt(Nb)

            # Reset per node so a choice made at the previous node can never
            # leak into this one.
            max_qu = float('-inf')
            simulation_action = None
            simulation_edge = None

            for idx, (action, edge) in enumerate(edges):
                U = self.cpuct * \
                    ((1 - epsilon) * edge.stats['P'] + epsilon * nu[idx]) * \
                    sqrt_nb / (1 + edge.stats['N'])

                Q = edge.stats['Q']

                # Strict '>' keeps the first edge on ties.
                if Q + U > max_qu:
                    max_qu = Q + U
                    simulation_action = action
                    simulation_edge = edge

            if simulation_edge is None:
                raise ValueError(
                    'no selectable edge at node %r: every Q + U score is '
                    'NaN or -inf' % (current_node.id,)
                )

            # Only the value and the terminal flag are needed here; the
            # successor is already reachable through the edge's outNode.
            _, value, done = current_node.state.takeAction(simulation_action)
            current_node = simulation_edge.outNode
            breadcrumbs.append(simulation_edge)

        return current_node, value, done, breadcrumbs

    def back_fill(self, value, breadcrumbs):
        """Update the statistics of every edge in ``breadcrumbs``.

        For each edge, ``N`` is incremented by one, ``value`` is added to
        ``W`` and ``Q`` is recomputed as ``W / N``.

        Args:
            value: Amount added to each edge's ``W``.
            breadcrumbs: Iterable of edges to update.
        """
        for edge in breadcrumbs:
            stats = edge.stats
            stats['N'] = stats['N'] + 1
            stats['W'] = stats['W'] + value
            stats['Q'] = stats['W'] / stats['N']

    def add_node(self, node):
        """Register ``node`` in ``self.tree`` under ``node.id``."""
        self.tree[node.id] = node