seongahjo

Add code

.gitignore

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
.idea

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
DNN/loss.py

import tensorflow as tf


def softmax_cross_entropy_with_logits(y_true, y_pred):
    """Cross-entropy over the policy head, ignoring illegal actions.

    Wherever the target policy y_true is exactly zero (an illegal action),
    the corresponding logit is replaced with -100 so that it contributes
    effectively nothing to the softmax.
    """
    p = y_pred
    pi = y_true

    zero = tf.zeros(shape=tf.shape(pi), dtype=tf.float32)
    where = tf.equal(pi, zero)

    negatives = tf.fill(tf.shape(pi), -100.0)
    p = tf.where(where, negatives, p)

    loss = tf.nn.softmax_cross_entropy_with_logits(labels=pi, logits=p)

    return loss
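For intuition, a minimal sketch of what the masking does (this assumes a TF 2.x eager session and running from the repository root; the code above was written against the TF 1.x-era API, where the same call would run inside a session):

```python
import tensorflow as tf
from DNN.loss import softmax_cross_entropy_with_logits

pi = tf.constant([[0.0, 0.7, 0.3]])  # target policy; action 0 is illegal (probability 0)
p = tf.constant([[5.0, 1.0, 2.0]])   # raw logits from the network

# Inside the loss, the illegal action's logit is replaced by -100 before the
# softmax, so the large spurious logit 5.0 barely affects the result.
loss = softmax_cross_entropy_with_logits(pi, p)
print(float(loss[0]))
```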
DNN/model.py

import config
import numpy as np

from keras.models import load_model, Model
from keras.layers import Input, Dense, Conv1D, Flatten, BatchNormalization, LeakyReLU, add
from keras.optimizers import SGD
from keras import regularizers

from DNN.loss import softmax_cross_entropy_with_logits


class GeneralModel(object):
    def __init__(self, reg_const, learning_rate, input_dim, output_dim):
        self.reg_const = reg_const
        self.learning_rate = learning_rate
        self.input_dim = input_dim
        self.output_dim = output_dim

    def predict(self, x):
        return self.model.predict(x)

    def fit(self, states, targets, epochs, verbose, validation_split, batch_size):
        return self.model.fit(states, targets, epochs=epochs, verbose=verbose,
                              validation_split=validation_split, batch_size=batch_size)

    def write(self, game, version):
        self.model.save('models/version' + "{0:0>4}".format(version) + '.h5')

    def read(self, game, run_number, version):
        return load_model(
            str(run_number).zfill(4) + '/models/version' + "{0:0>4}".format(version) + '.h5',
            custom_objects={'softmax_cross_entropy_with_logits': softmax_cross_entropy_with_logits})


class ResidualCNN(GeneralModel):
    def __init__(self, reg_const, learning_rate, input_dim, output_dim, hidden_layers):
        GeneralModel.__init__(self, reg_const, learning_rate, input_dim, output_dim)
        self.hidden_layers = hidden_layers
        self.num_layers = len(hidden_layers)
        self.model = self._build_model()

    def residual_layer(self, input_block, filters, kernel_size):
        x = self.conv_layer(input_block, filters, kernel_size)

        x = Conv1D(
            filters=filters
            , kernel_size=kernel_size
            , data_format="channels_last"
            , padding='same'
            , use_bias=False
            , activation='linear'
            , kernel_regularizer=regularizers.l2(self.reg_const)
        )(x)

        # with channels_last the features sit on the last axis
        x = BatchNormalization(axis=-1)(x)

        # skip connection
        x = add([input_block, x])

        x = LeakyReLU()(x)

        return x

    def conv_layer(self, x, filters, kernel_size):
        x = Conv1D(
            filters=filters
            , kernel_size=kernel_size
            , data_format="channels_last"
            , padding='same'
            , use_bias=False
            , activation='linear'
            , kernel_regularizer=regularizers.l2(self.reg_const)
        )(x)

        x = BatchNormalization(axis=-1)(x)
        x = LeakyReLU()(x)

        return x

    def value_head(self, x):
        x = Conv1D(
            filters=2
            , kernel_size=1
            , data_format="channels_last"
            , padding='same'
            , use_bias=False
            , activation='linear'
            , kernel_regularizer=regularizers.l2(self.reg_const)
        )(x)

        x = BatchNormalization(axis=-1)(x)
        x = LeakyReLU()(x)

        x = Flatten()(x)

        x = Dense(
            self.output_dim
            , use_bias=False
            , activation='linear'
            , kernel_regularizer=regularizers.l2(self.reg_const)
            , name='value_head'
        )(x)

        return x

    def policy_head(self, x):
        x = Conv1D(
            filters=2
            , kernel_size=1
            , data_format="channels_last"
            , padding='same'
            , use_bias=False
            , activation='linear'
            , kernel_regularizer=regularizers.l2(self.reg_const)
        )(x)

        x = BatchNormalization(axis=-1)(x)
        x = LeakyReLU()(x)

        x = Flatten()(x)

        x = Dense(
            self.output_dim
            , use_bias=False
            , activation='linear'
            , kernel_regularizer=regularizers.l2(self.reg_const)
            , name='policy_head'
        )(x)

        return x

    def _build_model(self):
        # the input is a plain (20, NUM_ACTIONS) array, not an image
        main_input = Input(shape=self.input_dim, name='main_input')

        x = self.conv_layer(main_input, self.hidden_layers[0]['filters'], self.hidden_layers[0]['kernel_size'])

        if len(self.hidden_layers) > 1:
            for h in self.hidden_layers[1:]:
                x = self.residual_layer(x, h['filters'], h['kernel_size'])

        vh = self.value_head(x)  # built but unused: only the policy head is trained
        ph = self.policy_head(x)

        model = Model(inputs=[main_input], outputs=[ph])
        model.compile(loss=softmax_cross_entropy_with_logits,
                      optimizer=SGD(lr=self.learning_rate, momentum=config.MOMENTUM),
                      loss_weights=[0.5]
                      )

        return model

    def convertToModelInput(self, state):
        # accept either a State object or its raw state_check array; Agent.replay
        # feeds the raw arrays stored in Memory back through this method
        board = state.state_check if hasattr(state, 'state_check') else state
        return np.reshape(board, self.input_dim)
LICENSE

MIT License

Copyright (c) 2019 seongahjo

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
MCTS.py

import numpy as np
import config


class Node(object):

    def __init__(self, state):
        self.state = state
        self.id = state.id
        self.edges = []

    def is_leaf(self):
        return len(self.edges) == 0


class Edge(object):

    def __init__(self, in_node, out_node, prior, action):
        self.id = in_node.state.id + '|' + out_node.state.id
        self.inNode = in_node
        self.outNode = out_node
        self.action = action

        # N: visit count, W: total value, Q: mean value (W / N),
        # P: prior probability from the policy network
        self.stats = {
            'N': 0,
            'W': 0,
            'Q': 0,
            'P': prior,
        }


class MCTS(object):

    def __init__(self, root, cpuct):
        self.root = root
        self.tree = {}
        self.cpuct = cpuct
        self.add_node(root)

    def __len__(self):
        return len(self.tree)

    def moveToLeaf(self):
        """Walk from the root to a leaf, at each node picking the edge that
        maximizes Q + U (the PUCT rule), with Dirichlet noise mixed into the
        priors at the root for exploration."""
        breadcrumbs = []
        current_node = self.root

        done = 0
        value = 0

        while not current_node.is_leaf():

            max_qu = -99999

            if current_node == self.root:
                epsilon = config.EPSILON
                nu = np.random.dirichlet([config.ALPHA] * len(current_node.edges))
            else:
                epsilon = 0
                nu = [0] * len(current_node.edges)

            # total visit count over all edges of this node
            Nb = 0
            for action, edge in current_node.edges:
                Nb = Nb + edge.stats['N']

            for idx, (action, edge) in enumerate(current_node.edges):

                # U = cpuct * ((1 - eps) * P + eps * nu) * sqrt(sum_b N_b) / (1 + N)
                U = self.cpuct * \
                    ((1 - epsilon) * edge.stats['P'] + epsilon * nu[idx]) * \
                    np.sqrt(Nb) / (1 + edge.stats['N'])

                Q = edge.stats['Q']

                if Q + U > max_qu:
                    max_qu = Q + U
                    simulation_action = action
                    simulation_edge = edge

            new_state, value, done = current_node.state.takeAction(simulation_action)
            current_node = simulation_edge.outNode
            breadcrumbs.append(simulation_edge)

        return current_node, value, done, breadcrumbs

    def back_fill(self, value, breadcrumbs):
        """Propagate the evaluated value back up the visited edges."""
        for edge in breadcrumbs:
            edge.stats['N'] = edge.stats['N'] + 1
            edge.stats['W'] = edge.stats['W'] + value
            edge.stats['Q'] = edge.stats['W'] / edge.stats['N']

    def add_node(self, node):
        self.tree[node.id] = node
README.md

## D.sAIgn
Web design automation from a screenshot by reinforcement learning.

**Currently only training is available.**
## Features
* Monte Carlo Tree Search
* DQN

## Installation
* `Python 3`
* `tensorflow`, `keras`, `numpy`, and `selenium` (with a matching ChromeDriver), as imported by the sources


## Usage
* configure `config.py` (see the sketch below)
* `python main.py`
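The main self-play and training knobs sit at the top of `config.py`. A minimal sketch of a reduced setup for a quick local smoke run (the names come from the shipped `config.py`; the values are illustrative, not tuned recommendations):

```python
# config.py -- illustrative overrides for a short local run
EPISODES = 5        # self-play episodes per iteration (shipped default: 30)
MCTS_SIMS = 10      # MCTS simulations per move (shipped default: 50)
MEMORY_SIZE = 1000  # replay-memory size that triggers retraining (shipped default: 30000)
```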

## Roadmap
- [ ] run D.sAIgn using a trained model
- [ ] produce a notable result
- [ ] speed up training with multiprocessing
- [ ] identify the most commonly used CSS properties
- [ ] offer a pretrained model

## License
D.sAIgn is released under the [MIT License].

[MIT License]: https://github.com/seongahjo/D.sAIgn/blob/master/LICENSE
__init__.py

__all__ = ['config']
agent.py

import numpy as np
import random
import config
import MCTS as mc

import time

from browser import Chrome
from similarity import Similarity
from state import State

# the 20 CSS class names of the target page (kakao.html)
CLASS_NAMES = ['screen_out', 'tit_white', 'desc_newsfeed', 'hide', 'mArticle', 'txt_date',
               'main_summary', 'service_info', 'ico_corp', 'tit_black',
               'summary_info', 'link_newsfeed', 'img_thumb', 'now_stock', 'kakaoWrap', 'kakaoContent',
               'main_service', 'num_stock', 'main_stock', 'tit_newsfeed']


class Agent(object):
    def __init__(self, name, state_size, action_size, mcts_simulations, cpuct, model=None):
        self.name = name
        self.action_size = action_size

        self.cpuct = cpuct
        self.similarity = Similarity('result.png')
        self.MCTSsimulations = mcts_simulations
        self.model = model
        self.browser = Chrome('kakao.html')
        self.mcts = None
        self.gameState = self.reset()

        self.state_size = state_size
        self.train_overall_loss = []
        self.train_value_loss = []
        self.train_policy_loss = []
        self.val_overall_loss = []
        self.val_value_loss = []
        self.val_policy_loss = []

    def simulate(self):
        # move to a leaf node
        leaf, value, done, breadcrumbs = self.mcts.moveToLeaf()
        # evaluate the leaf node
        value, breadcrumbs = self.evaluateLeaf(leaf, value, done, breadcrumbs)
        # back-fill the value through the tree
        self.mcts.back_fill(value, breadcrumbs)

    def act(self, state, tau):
        if self.mcts is None or state.id not in self.mcts.tree:
            self.buildMCTS(state)
        else:
            self.changeRootMCTS(state)

        # run the simulations
        for sim in range(self.MCTSsimulations):
            self.simulate()

        # get action values
        pi, values = self.getAV(1)

        # pick the action
        action, value = self.chooseAction(pi, values, tau)
        print('action {} '.format(action))
        nextState, _, _ = state.takeAction(action)

        return action, pi, value

    def get_preds(self, state):
        # query the policy network for the leaf state
        inputToModel = np.array([self.model.convertToModelInput(state)])
        preds = self.model.predict(inputToModel)
        logits = preds[0]
        allowedActions = state.allowedActions

        # mask out actions that are not allowed in this state
        mask = np.ones(config.NUM_ACTIONS, dtype=bool)
        mask[allowedActions] = False

        # alternative scoring that was tried: rate each action by the image
        # similarity of the state it leads to
        # sim = np.zeros(config.NUM_ACTIONS, dtype=np.float)
        # for act in allowedActions:
        #     new_state, _, _ = state.takeAction(act)
        #     sim[act] += new_state.similarity * 100
        # logits = sim

        logits[mask] = -100

        # softmax over the masked logits
        odds = np.exp(logits)
        probs = odds / np.sum(odds)

        return probs, allowedActions

    def evaluateLeaf(self, leaf, value, done, breadcrumbs):
        if done == 0:
            # expand the leaf: add an edge (with its prior) for every allowed action
            probs, allowedActions = self.get_preds(leaf.state)
            probs = probs[allowedActions]

            for idx, act in enumerate(allowedActions):
                newState, _, _ = leaf.state.takeAction(act)
                if newState.id not in self.mcts.tree:
                    node = mc.Node(newState)
                    self.mcts.add_node(node)
                else:
                    node = self.mcts.tree[newState.id]

                newEdge = mc.Edge(leaf, node, probs[idx], act)
                leaf.edges.append((act, newEdge))

        return value, breadcrumbs

    def getAV(self, tau):
        edges = self.mcts.root.edges
        # float, not integer: visit counts are raised to the power 1 / tau
        pi = np.zeros(self.action_size, dtype=np.float32)
        values = np.zeros(self.action_size, dtype=np.float32)

        for action, edge in edges:
            pi[action] = pow(edge.stats['N'], 1 / tau)
            values[action] = edge.stats['Q']

        pi = pi / (np.sum(pi) * 1.0)
        return pi, values

    def chooseAction(self, pi, values, tau):
        if tau == 0:
            # deterministic: pick (randomly among) the most visited action(s)
            actions = np.argwhere(pi == max(pi))
            action = random.choice(actions)[0]
        else:
            # stochastic: sample an action from pi
            action_idx = np.random.multinomial(1, pi)
            action = np.where(action_idx == 1)[0][0]
        value = values[action]
        return action, value

    def replay(self, ltmemory):
        for i in range(config.TRAINING_LOOPS):
            minibatch = random.sample(ltmemory, min(config.BATCH_SIZE, len(ltmemory)))

            training_states = np.array([self.model.convertToModelInput(row['state']) for row in minibatch])
            training_targets = np.array([row['AV'] for row in minibatch])

            fit = self.model.fit(training_states, training_targets, epochs=config.EPOCHS, verbose=1,
                                 validation_split=0, batch_size=32)
            time.sleep(1.0)
            print('\n')

    def predict(self, inputToModel):
        preds = self.model.predict(inputToModel)
        return preds

    def buildMCTS(self, state):
        self.root = mc.Node(state)
        self.mcts = mc.MCTS(self.root, self.cpuct)

    def changeRootMCTS(self, state):
        self.mcts.root = self.mcts.tree[state.id]

    def reset(self):
        self.gameState = State([set() for i in range(20)], 0,
                               np.ones((20, config.NUM_ACTIONS), dtype=int),
                               CLASS_NAMES, self.browser, self.similarity)
        return self.gameState

    def step(self, action):
        next_state, value, done = self.gameState.takeAction(action)
        self.gameState = next_state
        info = None
        return next_state, value, done, info
browser.py

import datetime
import os
from selenium import webdriver
from selenium.common.exceptions import WebDriverException
import config


class Browser(object):
    def __init__(self):
        pass

    def capture(self):
        pass


class Chrome(Browser):
    def __init__(self, html):
        self.html = html
        self.browser = self._runChrome()

    def _runChrome(self):
        options = webdriver.ChromeOptions()
        options.add_argument('headless')
        options.add_argument('no-sandbox')
        driver = webdriver.Chrome(options=options)
        driver.set_window_size(config.WIDTH, config.HEIGHT)
        return driver

    def capture(self):
        """
        Render the current page in headless Chrome and save a screenshot.
        :return: image_path of the captured screenshot
        """
        date_time = datetime.datetime.now()
        date_time = date_time.strftime('%Y-%m-%d-%H-%M-%S')
        file_path = 'image/' + str(date_time) + '.png'
        now_path = 'file://' + os.path.join(os.getcwd(), self.html)
        try:
            self.browser.get(now_path)
        except WebDriverException:
            # the headless browser can die between captures; restart and retry once
            self.browser = self._runChrome()
            self.browser.get(now_path)
        self.browser.save_screenshot(file_path)
        return file_path
config.py

#### SELF PLAY
EPISODES = 30
MCTS_SIMS = 50
MEMORY_SIZE = 30000
TURNS_UNTIL_TAU0 = 10  # turn on which it starts playing deterministically
CPUCT = 1
EPSILON = 0.2
ALPHA = 0.8

INPUT_SHAPE = (20, 21)

#### RETRAINING
BATCH_SIZE = 256
EPOCHS = 50
REG_CONST = 0.0001
LEARNING_RATE = 0.1
MOMENTUM = 0.9
TRAINING_LOOPS = 10

HIDDEN_CNN_LAYERS = [
    {'filters': 75, 'kernel_size': 4}
    , {'filters': 75, 'kernel_size': 4}
    , {'filters': 75, 'kernel_size': 4}
    , {'filters': 75, 'kernel_size': 4}
    , {'filters': 75, 'kernel_size': 4}
    , {'filters': 75, 'kernel_size': 4}
]

HIDDEN_2D_CNN_LAYERS = [
    {'filters': 75, 'kernel_size': (4, 4)}
    , {'filters': 75, 'kernel_size': (4, 4)}
    , {'filters': 75, 'kernel_size': (4, 4)}
    , {'filters': 75, 'kernel_size': (4, 4)}
    , {'filters': 75, 'kernel_size': (4, 4)}
    , {'filters': 75, 'kernel_size': (4, 4)}
]

#### EVALUATION
EVAL_EPISODES = 20
SCORING_THRESHOLD = 1.3

# each action applies a bundle of CSS declarations to the current class;
# the extra action with index LAST_ACTION moves on to the next class instead
ACTION_MEANING = {
    0: 'overflow: hidden; position: absolute; width: 0; height: 0; line-height: 0; text-indent: -9999px',
    1: 'position: absolute; top: 15px; left: 20px; color: #fff; font-weight: normal; font-size: 14px; font-family: "NotoSans Regular", "Malgun Gothic", "맑은 고딕", "Apple SD Gothic Neo", "돋움", dotum, sans-serif',
    2: 'padding-top: 8px; font-size: 14px; line-height: 22px; color: #666; display: -webkit-box; overflow: hidden; max-height: 66px; -webkit-box-orient: vertical; -webkit-line-clamp: 3',
    3: 'display: none',
    4: 'position: relative; width: 1160px; margin: 0 auto; padding-bottom: 180px; width: 100%',
    5: 'display: block; padding-top: 28px; font-size: 12px; clear: both',
    6: 'overflow: hidden; max-width: 1800px; margin: 0 auto; padding: 35px 60px 0; padding-left: 20px; padding-right: 20px',
    7: 'float: left; width: 33.33%; padding: 0 20px; box-sizing: border-box',
    8: 'margin: 22px 0 0 21px; width: 18px; height: 14px; margin-top: 21px; background-position: -40px 0; display: inline-block; background-position: -70px 0; width: 17px; height: 3px; margin: 20px 0 0 4px; background-position: -90px 0',
    9: 'position: absolute; top: 15px; left: 20px; color: #1e1e1e; font-weight: normal; font-size: 14px; font-family: "NotoSans Regular", "Malgun Gothic", "맑은 고딕", "Apple SD Gothic Neo", "돋움", dotum, sans-serif',
    10: 'float: left; width: 33.33%; padding: 0 20px; box-sizing: border-box',
    11: 'display: block; position: relative',
    12: 'width: 100%; margin-bottom: 14px',
    13: 'float: left',
    14: 'position: relative; min-width: 1320px; background-color: #fff',
    15: 'padding-top: 122px; padding-top: 0',
    16: 'overflow: hidden; max-width: 1800px; margin: 0 auto; padding: 76px 60px 0; padding-left: 20px; padding-right: 20px;',
    17: 'font-weight: bold; font-size: 28px; line-height: 45px; padding-left: 2px; color: #003cff; font-size: 32px; font-weight: bold; color: #1e1e1e; display: inline-block; font-size: 20px; vertical-align: top',
    18: 'padding-top: 33px; color: #1e1e1e',
    19: 'text-decoration: underline; display: block; padding-top: 18px; font-size: 20px; line-height: 32px; color: #1e1e1e',
}
WIDTH = 800
HEIGHT = 600

IMAGE_SIZE = (224, 224)

LAST_ACTION = len(ACTION_MEANING)
NUM_ACTIONS = len(ACTION_MEANING) + 1
main.py

import pickle
import random

from DNN.model import ResidualCNN
from agent import Agent
from memory import Memory
import config
import os


def playMatches(web, EPISODES, turns_until_tau0, memory=None):
    # scores = {player1.name: 0, "drawn": 0, player2.name: 0}
    # sp_scores = {'sp': 0, "drawn": 0, 'nsp': 0}
    # points = {player1.name: [], player2.name: []}

    for e in range(EPISODES):

        state = web.reset()

        done = 0
        web.mcts = None
        turn = 0
        while done == 0:
            turn = turn + 1

            # run the MCTS algorithm and return an action
            action, pi, value = web.act(state, 1)

            if memory is not None:
                # commit the move to short-term memory
                memory.commit_stmemory(state, pi)

            # take the action
            state, value, done, _ = web.step(action)

            if done == 1:
                print('done')
                if memory is not None:
                    # the game is finished: assign the final value to its moves
                    # and flush them into long-term memory
                    for move in memory.stmemory:
                        move['value'] = value
                    memory.commit_ltmemory()

    return memory


iteration = 0
cnn = ResidualCNN(config.REG_CONST, config.LEARNING_RATE, config.INPUT_SHAPE, config.NUM_ACTIONS,
                  config.HIDDEN_CNN_LAYERS)
current_player = Agent('current_player', 20, config.NUM_ACTIONS, config.MCTS_SIMS, config.CPUCT, cnn)
# the original assigned memory = None here, which crashes in the loop below;
# assuming the Memory class from memory.py is the intended replay buffer
memory = Memory(config.MEMORY_SIZE)
run_folder = './run/'
if not os.path.exists(run_folder):
    os.mkdir(run_folder)
os.makedirs(run_folder + 'memory', exist_ok=True)  # target directory of the pickle dump below

while True:
    iteration += 1
    print('ITERATION NUMBER ' + str(iteration))
    memory = playMatches(current_player, config.EPISODES, turns_until_tau0=config.TURNS_UNTIL_TAU0, memory=memory)
    memory.clear_stmemory()

    if len(memory.ltmemory) >= config.MEMORY_SIZE:

        ######## RETRAINING ########
        print('RETRAINING...')
        current_player.replay(memory.ltmemory)
        print('')

        if iteration % 5 == 0:
            pickle.dump(memory, open(run_folder + "memory/memory" + str(iteration).zfill(4) + ".p", "wb"))

        memory_samp = random.sample(memory.ltmemory, min(1000, len(memory.ltmemory)))

    else:
        print('MEMORY SIZE: ' + str(len(memory.ltmemory)))
memory.py

from collections import deque

import config


class Memory:
    def __init__(self, memory_size):
        # config.MEMORY_SIZE takes precedence over the constructor argument
        self.MEMORY_SIZE = config.MEMORY_SIZE or memory_size
        self.ltmemory = deque(maxlen=self.MEMORY_SIZE)
        self.stmemory = deque(maxlen=self.MEMORY_SIZE)

    # For training: the stored state_check array is the model input and
    # 'AV' (the action values, i.e. the MCTS policy pi) is the label.

    def commit_stmemory(self, state, action_value):
        self.stmemory.append({
            'state': state.state_check
            , 'AV': action_value
        })

    def commit_ltmemory(self):
        for i in self.stmemory:
            self.ltmemory.append(i)
        self.clear_stmemory()

    def clear_stmemory(self):
        self.stmemory = deque(maxlen=config.MEMORY_SIZE)
similarity.py

import keras
import numpy as np
from keras.applications.vgg16 import preprocess_input, VGG16
from keras.preprocessing import image
from numpy.linalg import norm

keras.backend.set_image_data_format('channels_first')


def preprocess_image(img_path):
    img = image.load_img(img_path, target_size=(224, 224))
    img_data = image.img_to_array(img, data_format='channels_first')
    img_data = np.expand_dims(img_data, axis=0)
    img_data = preprocess_input(img_data)
    return img_data


class Similarity(object):
    def __init__(self, dest_file_path):
        self.model = VGG16(weights='imagenet', include_top=False)
        self.dest = preprocess_image(dest_file_path)

    def _feature(self, img):
        # flatten the VGG16 conv features of a 224x224 input: (512, 7, 7) -> 25088
        image_features = self.model.predict(img)
        image_features = image_features.reshape(7 * 7 * 512)
        return image_features

    def similarity(self, before_image):
        """
        Cosine similarity between a captured screenshot and the target image.
        :return: similarity between the current image and the target image
        """
        before = self._feature(before_image)
        after = self._feature(self.dest)
        return np.dot(before, after) / (norm(before) * norm(after))
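A minimal usage sketch of the similarity check (it assumes `result.png` exists as the target screenshot; the capture path below is hypothetical, standing in for whatever `Chrome.capture` returned):

```python
from similarity import Similarity, preprocess_image

sim = Similarity('result.png')                   # target screenshot
current = preprocess_image('image/capture.png')  # hypothetical capture path
print(sim.similarity(current))                   # closer to 1.0 means more similar
```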
state.py

import base64
import config
import numpy as np
import copy

from similarity import preprocess_image


class State(object):
    def __init__(self, board, index, state_check, class_names, browser, sim_func):
        self.class_names = class_names  # CSS class names
        self.board = board  # current situation: one set of applied actions per class
        self.index = index  # index of the class currently being styled
        self.id = self._convertStateToId()  # id identifying this state
        self.state_check = state_check  # (len(board), NUM_ACTIONS) array; 1 marks an action still available
        self.allowedActions = self._available_action()
        self.browser = browser
        self._make_css(state_check)
        self.sim_func = sim_func
        self.now_image = preprocess_image(self.browser.capture())
        self.similarity = sim_func.similarity(self.now_image)

    def _available_action(self):
        return np.nonzero(self.state_check[self.index])[0]

    def _convertStateToId(self):
        id = '{}'.format(self.index).join(map(lambda f: str(f) if len(f) != 0 else '{}', self.board))
        b64 = base64.b64encode(id.encode('utf-8'))
        return str(b64)

    def _checkForEndGame(self):
        if self.index >= len(self.board) - 2:
            return True
        return False

    def takeAction(self, act):
        """
        Apply the action and produce the next state.
        :param act: action index
        :returns: newState, value, done
        """
        if act >= config.NUM_ACTIONS:
            raise IndexError

        newBoard = copy.deepcopy(self.board)
        newStateCheck = copy.deepcopy(self.state_check)
        value = 0
        done = 0
        index = self.index
        if act == config.LAST_ACTION:
            # the last action moves on to the next CSS class
            index += 1
        else:
            newBoard[index].add(act)
            newStateCheck[index][act] = 0
        newState = State(newBoard, index, newStateCheck, self.class_names, self.browser, self.sim_func)
        if newState._checkForEndGame():
            print('done')
            done = 1
        if newState.similarity >= 0.8:
            done = 1
            value = 1
            print('good')

        return newState, value, done

    def _make_css(self, state_check):
        """
        Write main.css from state_check.
        :return: file_name
        """
        file_name = 'main.css'
        with open(file_name, 'w') as f:
            for class_index, actions in enumerate(state_check):
                class_index = int(class_index)
                f.write(".{}".format(self.class_names[class_index]))
                f.write("{")
                for action_index, action in enumerate(actions):
                    if action == 0:
                        # an entry of 0 means the action has been applied to this class
                        f.write(" {}; ".format(config.ACTION_MEANING[action_index]))
                f.write("}\n")

        return file_name

    def reset(self):
        """
        Reset the state.
        :return:
        """
        self.index = 0
        self.state_check = np.ones((len(self.board), config.NUM_ACTIONS))