# MCTS.py
import numpy as np
import logging
import config
class Node:
    """A vertex of the search tree that wraps a game state.

    Attributes:
        state: the wrapped state object.
        id: copy of ``state.id``; MCTS.add_node uses it as the tree key.
        edges: outgoing edges, empty until the node is expanded. MCTS
            iterates it as ``(action, edge)`` pairs.
    """

    def __init__(self, state):
        self.state = state
        self.id = state.id
        self.edges = []

    def is_leaf(self):
        """Return True while the node has no outgoing edges."""
        return not self.edges
class Edge:
    """A directed parent -> child edge of the search tree.

    ``id`` is the parent state id and the child state id joined by '|'.

    ``stats`` holds the per-edge search statistics:
        N: visit count.
        W: cumulative value backed up through this edge (noted as the
           "win" total in the original comments).
        Q: mean value, W / N (recomputed by MCTS.back_fill).
        P: prior probability given at construction.
    """

    def __init__(self, in_node, out_node, prior, action):
        parent_id = in_node.state.id
        child_id = out_node.state.id
        self.id = parent_id + '|' + child_id
        self.inNode = in_node
        self.outNode = out_node
        self.action = action
        # Counters start at zero; only the prior is known up front.
        self.stats = dict(N=0, W=0, Q=0, P=prior)
class MCTS:
    """Monte Carlo tree search over Node/Edge objects using a PUCT score.

    Nodes expose ``edges`` as a list of ``(action, edge)`` pairs. Each edge
    carries a ``stats`` dict with keys 'N', 'W', 'Q' and 'P' and an
    ``outNode`` child.

    Attributes:
        root: the node every simulation starts from.
        tree: dict mapping node id -> node for every node passed to add_node.
        cpuct: exploration constant scaling the U (prior / visit) term.
    """

    def __init__(self, root, cpuct):
        self.root = root
        self.tree = {}
        self.cpuct = cpuct
        self.add_node(root)

    def __len__(self):
        """Return the number of nodes registered in ``tree``."""
        return len(self.tree)

    def _select_edge(self, node):
        """Return the ``(action, edge)`` pair of ``node`` maximising Q + U.

        U = cpuct * P' * sqrt(sum of sibling visit counts) / (1 + N).
        At the root, P' mixes the prior with Dirichlet noise using weight
        ``config.EPSILON`` and concentration ``config.ALPHA``. Elsewhere
        P' is the plain prior. Ties keep the earliest edge. Returns None
        only if ``node`` has no edges.
        """
        edges = node.edges
        if node == self.root:
            epsilon = config.EPSILON
            nu = np.random.dirichlet([config.ALPHA] * len(edges))
        else:
            epsilon = 0
            nu = [0] * len(edges)

        # Parent visit count = total visits over its outgoing edges.
        sqrt_total = np.sqrt(sum(edge.stats['N'] for _, edge in edges))

        # Start from -inf and always accept the first edge, so a choice is
        # made even when every score is extremely negative. The previous
        # -99999 sentinel could leave the choice unset (or stale from the
        # parent level, re-entering the same child forever).
        best_score = float('-inf')
        best = None
        for idx, (action, edge) in enumerate(edges):
            prior = (1 - epsilon) * edge.stats['P'] + epsilon * nu[idx]
            u = self.cpuct * prior * sqrt_total / (1 + edge.stats['N'])
            score = edge.stats['Q'] + u
            if best is None or score > best_score:
                best_score = score
                best = (action, edge)
        return best

    def moveToLeaf(self):
        """Walk from the root to a leaf, always following the best Q + U edge.

        Returns:
            ``(leaf, value, done, breadcrumbs)``. ``value`` and ``done`` are
            the 2nd and 3rd items returned by ``state.takeAction`` for the
            last action taken; both are 0 if the root is already a leaf.
            ``breadcrumbs`` lists the traversed edges in root-to-leaf order.
        """
        breadcrumbs = []
        current_node = self.root
        done = 0
        value = 0
        while not current_node.is_leaf():
            action, edge = self._select_edge(current_node)
            # The new state itself is unused: the child is already stored as
            # edge.outNode. Only the value/done signals are kept.
            _, value, done = current_node.state.takeAction(action)
            current_node = edge.outNode
            breadcrumbs.append(edge)
        return current_node, value, done, breadcrumbs

    def back_fill(self, value, breadcrumbs):
        """Add one visit and ``value`` to each edge, then refresh Q = W / N.

        NOTE(review): the same ``value`` is added to every edge regardless of
        which player moved along it. This is correct only if values are not
        meant to flip sign per ply; confirm against the game implementation.
        """
        for edge in breadcrumbs:
            stats = edge.stats
            stats['N'] = stats['N'] + 1
            stats['W'] = stats['W'] + value
            stats['Q'] = stats['W'] / stats['N']

    def add_node(self, node):
        """Register ``node`` in ``tree`` under ``node.id``, replacing any previous entry."""
        self.tree[node.id] = node