Showing
19 changed files
with
968 additions
and
0 deletions
document/2014104150_조성아_캡스톤디자인2_중간보고서.pdf
0 → 100644
No preview for this file type
document/4월 면담서.pdf
0 → 100644
No preview for this file type
document/면담확인서_5월.pdf
0 → 100644
No preview for this file type
document/캡스톤디자인2_주제보고서.docx
0 → 100644
This file is too large to display.
project/.gitignore
0 → 100644
1 | +# Byte-compiled / optimized / DLL files | ||
2 | +__pycache__/ | ||
3 | +*.py[cod] | ||
4 | +*$py.class | ||
5 | +.idea | ||
6 | + | ||
7 | +# C extensions | ||
8 | +*.so | ||
9 | + | ||
10 | +# Distribution / packaging | ||
11 | +.Python | ||
12 | +build/ | ||
13 | +develop-eggs/ | ||
14 | +dist/ | ||
15 | +downloads/ | ||
16 | +eggs/ | ||
17 | +.eggs/ | ||
18 | +lib/ | ||
19 | +lib64/ | ||
20 | +parts/ | ||
21 | +sdist/ | ||
22 | +var/ | ||
23 | +wheels/ | ||
24 | +*.egg-info/ | ||
25 | +.installed.cfg | ||
26 | +*.egg | ||
27 | +MANIFEST | ||
28 | + | ||
29 | +# PyInstaller | ||
30 | +# Usually these files are written by a python script from a template | ||
31 | +# before PyInstaller builds the exe, so as to inject date/other infos into it. | ||
32 | +*.manifest | ||
33 | +*.spec | ||
34 | + | ||
35 | +# Installer logs | ||
36 | +pip-log.txt | ||
37 | +pip-delete-this-directory.txt | ||
38 | + | ||
39 | +# Unit test / coverage reports | ||
40 | +htmlcov/ | ||
41 | +.tox/ | ||
42 | +.coverage | ||
43 | +.coverage.* | ||
44 | +.cache | ||
45 | +nosetests.xml | ||
46 | +coverage.xml | ||
47 | +*.cover | ||
48 | +.hypothesis/ | ||
49 | +.pytest_cache/ | ||
50 | + | ||
51 | +# Translations | ||
52 | +*.mo | ||
53 | +*.pot | ||
54 | + | ||
55 | +# Django stuff: | ||
56 | +*.log | ||
57 | +local_settings.py | ||
58 | +db.sqlite3 | ||
59 | + | ||
60 | +# Flask stuff: | ||
61 | +instance/ | ||
62 | +.webassets-cache | ||
63 | + | ||
64 | +# Scrapy stuff: | ||
65 | +.scrapy | ||
66 | + | ||
67 | +# Sphinx documentation | ||
68 | +docs/_build/ | ||
69 | + | ||
70 | +# PyBuilder | ||
71 | +target/ | ||
72 | + | ||
73 | +# Jupyter Notebook | ||
74 | +.ipynb_checkpoints | ||
75 | + | ||
76 | +# pyenv | ||
77 | +.python-version | ||
78 | + | ||
79 | +# celery beat schedule file | ||
80 | +celerybeat-schedule | ||
81 | + | ||
82 | +# SageMath parsed files | ||
83 | +*.sage.py | ||
84 | + | ||
85 | +# Environments | ||
86 | +.env | ||
87 | +.venv | ||
88 | +env/ | ||
89 | +venv/ | ||
90 | +ENV/ | ||
91 | +env.bak/ | ||
92 | +venv.bak/ | ||
93 | + | ||
94 | +# Spyder project settings | ||
95 | +.spyderproject | ||
96 | +.spyproject | ||
97 | + | ||
98 | +# Rope project settings | ||
99 | +.ropeproject | ||
100 | + | ||
101 | +# mkdocs documentation | ||
102 | +/site | ||
103 | + | ||
104 | +# mypy | ||
105 | +.mypy_cache/ |
project/DNN/__init__.py
0 → 100644
File mode changed
project/DNN/loss.py
0 → 100644
1 | + | ||
2 | +import tensorflow as tf | ||
3 | + | ||
4 | +def softmax_cross_entropy_with_logits(y_true, y_pred): | ||
5 | + | ||
6 | + p = y_pred | ||
7 | + pi = y_true | ||
8 | + | ||
9 | + zero = tf.zeros(shape = tf.shape(pi), dtype=tf.float32) | ||
10 | + where = tf.equal(pi, zero) | ||
11 | + | ||
12 | + negatives = tf.fill(tf.shape(pi), -100.0) | ||
13 | + p = tf.where(where, negatives, p) | ||
14 | + | ||
15 | + loss = tf.nn.softmax_cross_entropy_with_logits(labels = pi, logits = p) | ||
16 | + | ||
17 | + return loss | ||
18 | + | ||
19 | + | ||
20 | + | ||
21 | + | ||
22 | + | ||
23 | + |
project/DNN/model.py
0 → 100644
1 | +import config | ||
2 | +import numpy as np | ||
3 | + | ||
4 | +from keras.models import load_model, Model | ||
5 | +from keras.layers import Input, Dense, Conv1D, Flatten, BatchNormalization, LeakyReLU, add | ||
6 | +from keras.optimizers import SGD | ||
7 | +from keras import regularizers | ||
8 | + | ||
9 | +from DNN.loss import softmax_cross_entropy_with_logits | ||
10 | + | ||
11 | + | ||
12 | +class GeneralModel(object): | ||
13 | + def __init__(self, reg_const, learning_rate, input_dim, output_dim): | ||
14 | + self.reg_const = reg_const | ||
15 | + self.learning_rate = learning_rate | ||
16 | + self.input_dim = input_dim | ||
17 | + self.output_dim = output_dim | ||
18 | + | ||
19 | + def predict(self, x): | ||
20 | + return self.model.predict(x) | ||
21 | + | ||
22 | + def fit(self, states, targets, epochs, verbose, validation_split, batch_size): | ||
23 | + return self.model.fit(states, targets, epochs=epochs, verbose=verbose, validation_split=validation_split, | ||
24 | + batch_size=batch_size) | ||
25 | + | ||
26 | + def write(self, game, version): | ||
27 | + self.model.save('models/version' + "{0:0>4}".format(version) + '.h5') | ||
28 | + | ||
29 | + def read(self, game, run_number, version): | ||
30 | + return load_model(str(run_number).zfill(4) + "/models/version" + "{0:0>4}".format | ||
31 | + (version) + '.h5', custom_objects={'softmax_cross_entropy_with_logits': softmax_cross_entropy_with_logits}) | ||
32 | + | ||
33 | + | ||
34 | +class ResidualCNN(GeneralModel): | ||
35 | + def __init__(self, reg_const, learning_rate, input_dim, output_dim, hidden_layers): | ||
36 | + GeneralModel.__init__(self, reg_const, learning_rate, input_dim, output_dim) | ||
37 | + self.hidden_layers = hidden_layers | ||
38 | + self.num_layers = len(hidden_layers) | ||
39 | + self.model = self._build_model() | ||
40 | + | ||
41 | + def residual_layer(self, input_block, filters, kernel_size): | ||
42 | + | ||
43 | + x = self.conv_layer(input_block, filters, kernel_size) | ||
44 | + | ||
45 | + x = Conv1D( | ||
46 | + filters=filters | ||
47 | + , kernel_size=kernel_size | ||
48 | + , data_format="channels_last" | ||
49 | + , padding='same' | ||
50 | + , use_bias=False | ||
51 | + , activation='linear' | ||
52 | + , kernel_regularizer=regularizers.l2(self.reg_const) | ||
53 | + )(x) | ||
54 | + | ||
55 | + x = BatchNormalization(axis=1)(x) | ||
56 | + | ||
57 | + x = add([input_block, x]) | ||
58 | + | ||
59 | + x = LeakyReLU()(x) | ||
60 | + | ||
61 | + return x | ||
62 | + | ||
63 | + def conv_layer(self, x, filters, kernel_size): | ||
64 | + | ||
65 | + x = Conv1D( | ||
66 | + filters=filters | ||
67 | + , kernel_size=kernel_size | ||
68 | + , data_format="channels_last" | ||
69 | + , padding='same' | ||
70 | + , use_bias=False | ||
71 | + , activation='linear' | ||
72 | + , kernel_regularizer=regularizers.l2(self.reg_const) | ||
73 | + )(x) | ||
74 | + | ||
75 | + x = BatchNormalization(axis=1)(x) | ||
76 | + x = LeakyReLU()(x) | ||
77 | + | ||
78 | + return x | ||
79 | + | ||
80 | + def value_head(self, x): | ||
81 | + x = Conv1D( | ||
82 | + filters=2 | ||
83 | + , kernel_size=1 | ||
84 | + , data_format="channels_last" | ||
85 | + , padding='same' | ||
86 | + , use_bias=False | ||
87 | + , activation='linear' | ||
88 | + , kernel_regularizer=regularizers.l2(self.reg_const) | ||
89 | + )(x) | ||
90 | + | ||
91 | + x = BatchNormalization(axis=1)(x) | ||
92 | + x = LeakyReLU()(x) | ||
93 | + | ||
94 | + x = Flatten()(x) | ||
95 | + | ||
96 | + x = Dense( | ||
97 | + self.output_dim | ||
98 | + , use_bias=False | ||
99 | + , activation='linear' | ||
100 | + , kernel_regularizer=regularizers.l2(self.reg_const) | ||
101 | + , name='value_head' | ||
102 | + )(x) | ||
103 | + | ||
104 | + return x | ||
105 | + | ||
106 | + def policy_head(self, x): | ||
107 | + x = Conv1D( | ||
108 | + filters=2 | ||
109 | + , kernel_size=1 | ||
110 | + , data_format="channels_last" | ||
111 | + , padding='same' | ||
112 | + , use_bias=False | ||
113 | + , activation='linear' | ||
114 | + , kernel_regularizer=regularizers.l2(self.reg_const) | ||
115 | + )(x) | ||
116 | + | ||
117 | + x = BatchNormalization(axis=1)(x) | ||
118 | + x = LeakyReLU()(x) | ||
119 | + | ||
120 | + x = Flatten()(x) | ||
121 | + | ||
122 | + x = Dense( | ||
123 | + self.output_dim | ||
124 | + , use_bias=False | ||
125 | + , activation='linear' | ||
126 | + , kernel_regularizer=regularizers.l2(self.reg_const) | ||
127 | + , name='policy_head' | ||
128 | + )(x) | ||
129 | + | ||
130 | + return x | ||
131 | + | ||
132 | + def _build_model(self): | ||
133 | + # image shape | ||
134 | + # 그냥 배열 shape | ||
135 | + | ||
136 | + main_input = Input(shape=self.input_dim, name='main_input') | ||
137 | + | ||
138 | + x = self.conv_layer(main_input, self.hidden_layers[0]['filters'], self.hidden_layers[0]['kernel_size']) | ||
139 | + | ||
140 | + if len(self.hidden_layers) > 1: | ||
141 | + for h in self.hidden_layers[1:]: | ||
142 | + x = self.residual_layer(x, h['filters'], h['kernel_size']) | ||
143 | + | ||
144 | + vh = self.value_head(x) | ||
145 | + ph = self.policy_head(x) | ||
146 | + | ||
147 | + model = Model(inputs=[main_input], outputs=[ph]) | ||
148 | + model.compile(loss=softmax_cross_entropy_with_logits, | ||
149 | + optimizer=SGD(lr=self.learning_rate, momentum=config.MOMENTUM), | ||
150 | + loss_weights=[0.5] | ||
151 | + ) | ||
152 | + | ||
153 | + return model | ||
154 | + | ||
155 | + def convertToModelInput(self, state): | ||
156 | + # [20,20] | ||
157 | + inputToModel = state.state_check # np.append(state.binary, [(state.playerTurn + 1)/2] * self.input_dim[1] * self.input_dim[2]) | ||
158 | + inputToModel = np.reshape(inputToModel, self.input_dim) | ||
159 | + return inputToModel |
project/LICENSE
0 → 100644
1 | +MIT License | ||
2 | + | ||
3 | +Copyright (c) 2019 seongahjo | ||
4 | + | ||
5 | +Permission is hereby granted, free of charge, to any person obtaining a copy | ||
6 | +of this software and associated documentation files (the "Software"), to deal | ||
7 | +in the Software without restriction, including without limitation the rights | ||
8 | +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
9 | +copies of the Software, and to permit persons to whom the Software is | ||
10 | +furnished to do so, subject to the following conditions: | ||
11 | + | ||
12 | +The above copyright notice and this permission notice shall be included in all | ||
13 | +copies or substantial portions of the Software. | ||
14 | + | ||
15 | +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
16 | +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
17 | +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
18 | +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
19 | +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
20 | +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
21 | +SOFTWARE. |
project/MCTS.py
0 → 100644
1 | +import numpy as np | ||
2 | +import logging | ||
3 | +import config | ||
4 | + | ||
5 | + | ||
6 | +class Node(object): | ||
7 | + | ||
8 | + def __init__(self, state): | ||
9 | + self.state = state | ||
10 | + self.id = state.id | ||
11 | + self.edges = [] | ||
12 | + | ||
13 | + def is_leaf(self): | ||
14 | + if len(self.edges) > 0: | ||
15 | + return False | ||
16 | + else: | ||
17 | + return True | ||
18 | + | ||
19 | + | ||
20 | +class Edge(object): | ||
21 | + | ||
22 | + def __init__(self, in_node, out_node, prior, action): | ||
23 | + self.id = in_node.state.id + '|' + out_node.state.id | ||
24 | + self.inNode = in_node | ||
25 | + self.outNode = out_node | ||
26 | + self.action = action | ||
27 | + | ||
28 | + self.stats = { | ||
29 | + 'N': 0, | ||
30 | + 'W': 0, | ||
31 | + 'Q': 0, | ||
32 | + 'P': prior, | ||
33 | + } | ||
34 | + | ||
35 | + | ||
36 | +# N 이 횟수 | ||
37 | +# W 가 win | ||
38 | +# Q 가 승률 | ||
39 | +# P 가 | ||
40 | + | ||
41 | +class MCTS(object): | ||
42 | + | ||
43 | + def __init__(self, root, cpuct): | ||
44 | + self.root = root | ||
45 | + self.tree = {} | ||
46 | + self.cpuct = cpuct | ||
47 | + self.add_node(root) | ||
48 | + | ||
49 | + def __len__(self): | ||
50 | + return len(self.tree) | ||
51 | + | ||
52 | + def moveToLeaf(self): | ||
53 | + | ||
54 | + breadcrumbs = [] | ||
55 | + current_node = self.root | ||
56 | + | ||
57 | + done = 0 | ||
58 | + value = 0 | ||
59 | + | ||
60 | + while not current_node.is_leaf(): | ||
61 | + | ||
62 | + max_qu = -99999 | ||
63 | + | ||
64 | + if current_node == self.root: | ||
65 | + epsilon = config.EPSILON | ||
66 | + nu = np.random.dirichlet([config.ALPHA] * len(current_node.edges)) | ||
67 | + else: | ||
68 | + epsilon = 0 | ||
69 | + nu = [0] * len(current_node.edges) | ||
70 | + | ||
71 | + Nb = 0 | ||
72 | + for action, edge in current_node.edges: | ||
73 | + Nb = Nb + edge.stats['N'] | ||
74 | + | ||
75 | + for idx, (action, edge) in enumerate(current_node.edges): | ||
76 | + | ||
77 | + U = self.cpuct * \ | ||
78 | + ((1 - epsilon) * edge.stats['P'] + epsilon * nu[idx]) * \ | ||
79 | + np.sqrt(Nb) / (1 + edge.stats['N']) | ||
80 | + | ||
81 | + Q = edge.stats['Q'] | ||
82 | + | ||
83 | + if Q + U > max_qu: | ||
84 | + max_qu = Q + U | ||
85 | + simulation_action = action | ||
86 | + simulation_edge = edge | ||
87 | + | ||
88 | + new_state, value, done = current_node.state.takeAction(simulation_action) | ||
89 | + current_node = simulation_edge.outNode | ||
90 | + breadcrumbs.append(simulation_edge) | ||
91 | + | ||
92 | + return current_node, value, done, breadcrumbs | ||
93 | + | ||
94 | + def back_fill(self, value, breadcrumbs): | ||
95 | + for edge in breadcrumbs: | ||
96 | + edge.stats['N'] = edge.stats['N'] + 1 | ||
97 | + edge.stats['W'] = edge.stats['W'] + value | ||
98 | + edge.stats['Q'] = edge.stats['W'] / edge.stats['N'] | ||
99 | + | ||
100 | + def add_node(self, node): | ||
101 | + self.tree[node.id] = node |
project/README.md
0 → 100644
1 | +## D.sAIgn | ||
2 | +Web Design Automation from screenshot by Reinforcement Learning | ||
3 | + | ||
4 | +**Now only train is availble** | ||
5 | +## Feature | ||
6 | +* Monte Carlo Tree Search | ||
7 | +* DQN | ||
8 | + | ||
9 | +## Installation | ||
10 | +* `Python 3` | ||
11 | + | ||
12 | + | ||
13 | +## Usage | ||
14 | +* configure `config.py` | ||
15 | +* `python main.py` | ||
16 | + | ||
17 | +## Roadmap | ||
18 | +- [ ] run D.sAIgn using trained model | ||
19 | +- [ ] see remarkable result | ||
20 | +- [ ] fast train by multiprocessing | ||
21 | +- [ ] find majorly used css properties | ||
22 | +- [ ] offer pretrained model | ||
23 | + | ||
24 | +## License | ||
25 | +D.sAIgn is released under [MIT License] | ||
26 | + | ||
27 | +[MIT License]:https://github.com/seongahjo/D.sAIgn/blob/master/LICENSE | ||
28 | + | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
project/__init__.py
0 → 100644
1 | +__all__ = ['config'] | ||
... | \ No newline at end of file | ... | \ No newline at end of file |
project/agent.py
0 → 100644
1 | +# %matplotlib inline | ||
2 | + | ||
3 | +import numpy as np | ||
4 | +import random | ||
5 | +import config | ||
6 | +import MCTS as mc | ||
7 | + | ||
8 | +import time | ||
9 | + | ||
10 | +from browser import Chrome | ||
11 | +from similarity import Similarity | ||
12 | +from state import State | ||
13 | + | ||
14 | + | ||
15 | +class Agent(): | ||
16 | + def __init__(self, name, state_size, action_size, mcts_simulations, cpuct, model=None): | ||
17 | + self.name = name | ||
18 | + self.action_size = action_size | ||
19 | + | ||
20 | + self.cpuct = cpuct | ||
21 | + self.similarity = Similarity('result.png') | ||
22 | + self.MCTSsimulations = mcts_simulations | ||
23 | + self.model = model | ||
24 | + self.browser = Chrome('kakao.html') | ||
25 | + self.mcts = None | ||
26 | + self.gameState = State([set() for i in range(20)], | ||
27 | + 0, | ||
28 | + np.ones((20, config.NUM_ACTIONS), dtype=np.int), | ||
29 | + ['screen_out', 'tit_white', 'desc_newsfeed', 'hide', 'mArticle', 'txt_date', | ||
30 | + 'main_summary', 'service_info', 'ico_corp', 'tit_black', | ||
31 | + 'summary_info', 'link_newsfeed', 'img_thumb', 'now_stock', 'kakaoWrap', 'kakaoContent', | ||
32 | + 'main_service', 'num_stock', 'main_stock', 'tit_newsfeed'], self.browser, | ||
33 | + self.similarity) | ||
34 | + | ||
35 | + self.state_size = 20 | ||
36 | + self.train_overall_loss = [] | ||
37 | + self.train_value_loss = [] | ||
38 | + self.train_policy_loss = [] | ||
39 | + self.val_overall_loss = [] | ||
40 | + self.val_value_loss = [] | ||
41 | + self.val_policy_loss = [] | ||
42 | + | ||
43 | + def simulate(self): | ||
44 | + | ||
45 | + ##### MOVE THE LEAF NODE | ||
46 | + leaf, value, done, breadcrumbs = self.mcts.moveToLeaf() | ||
47 | + ##### EVALUATE THE LEAF NODE | ||
48 | + value, breadcrumbs = self.evaluateLeaf(leaf, value, done, breadcrumbs) | ||
49 | + ##### BACKFILL THE VALUE THROUGH THE TREE | ||
50 | + self.mcts.back_fill(value, breadcrumbs) | ||
51 | + | ||
52 | + def act(self, state, tau): | ||
53 | + | ||
54 | + if self.mcts is None or state.id not in self.mcts.tree: | ||
55 | + self.buildMCTS(state) | ||
56 | + else: | ||
57 | + self.changeRootMCTS(state) | ||
58 | + | ||
59 | + #### run the simulation | ||
60 | + for sim in range(self.MCTSsimulations): | ||
61 | + self.simulate() | ||
62 | + | ||
63 | + #### get action values | ||
64 | + pi, values = self.getAV(1) | ||
65 | + | ||
66 | + ####pick the action | ||
67 | + action, value = self.chooseAction(pi, values, tau) | ||
68 | + print('action {} '.format(action)) | ||
69 | + nextState, _, _ = state.takeAction(action) | ||
70 | + | ||
71 | + return action, pi, value | ||
72 | + | ||
73 | + def get_preds(self, state): | ||
74 | + # predict the leaf | ||
75 | + inputToModel = np.array([self.model.convertToModelInput(state)]) | ||
76 | + preds = self.model.predict(inputToModel) | ||
77 | + logits_array = preds | ||
78 | + logits = logits_array[0] | ||
79 | + allowedActions = state.allowedActions | ||
80 | + | ||
81 | + mask = np.ones(config.NUM_ACTIONS, dtype=bool) | ||
82 | + mask[allowedActions] = False | ||
83 | + | ||
84 | + # sim = np.zeros(config.NUM_ACTIONS, dtype=np.float) | ||
85 | + # for act in allowedActions: | ||
86 | + # new_state, _, _ = state.takeAction(act) | ||
87 | + # sim[act] += new_state.similarity * 100 | ||
88 | + # logits = sim | ||
89 | + # print('sim : {} '.format(logits)) | ||
90 | + # print('probs : {} '.format(probs)) | ||
91 | + # probs = np.ones(config.NUM_ACTIONS) | ||
92 | + | ||
93 | + logits[mask] = -100 | ||
94 | + | ||
95 | + # SOFTMAX | ||
96 | + odds = np.exp(logits) | ||
97 | + probs = odds / np.sum(odds) ###put this just before the for? | ||
98 | + | ||
99 | + return probs, allowedActions | ||
100 | + | ||
101 | + def evaluateLeaf(self, leaf, value, done, breadcrumbs): | ||
102 | + | ||
103 | + if done == 0: | ||
104 | + | ||
105 | + probs, allowedActions = self.get_preds(leaf.state) | ||
106 | + probs = probs[allowedActions] | ||
107 | + | ||
108 | + for idx, act in enumerate(allowedActions): | ||
109 | + newState, _, _ = leaf.state.takeAction(act) | ||
110 | + if newState.id not in self.mcts.tree: | ||
111 | + node = mc.Node(newState) | ||
112 | + self.mcts.add_node(node) | ||
113 | + else: | ||
114 | + node = self.mcts.tree[newState.id] | ||
115 | + | ||
116 | + newEdge = mc.Edge(leaf, node, probs[idx], act) | ||
117 | + leaf.edges.append((act, newEdge)) | ||
118 | + | ||
119 | + return (value, breadcrumbs) | ||
120 | + | ||
121 | + def getAV(self, tau): | ||
122 | + edges = self.mcts.root.edges | ||
123 | + pi = np.zeros(self.action_size, dtype=np.integer) | ||
124 | + values = np.zeros(self.action_size, dtype=np.float32) | ||
125 | + | ||
126 | + for action, edge in edges: | ||
127 | + pi[action] = pow(edge.stats['N'], 1 / tau) | ||
128 | + values[action] = edge.stats['Q'] | ||
129 | + | ||
130 | + pi = pi / (np.sum(pi) * 1.0) | ||
131 | + return pi, values | ||
132 | + | ||
133 | + def chooseAction(self, pi, values, tau): | ||
134 | + if tau == 0: | ||
135 | + actions = np.argwhere(pi == max(pi)) | ||
136 | + action = random.choice(actions)[0] | ||
137 | + else: | ||
138 | + action_idx = np.random.multinomial(1, pi) | ||
139 | + action = np.where(action_idx == 1)[0][0] | ||
140 | + value = values[action] | ||
141 | + return action, value | ||
142 | + | ||
143 | + def replay(self, ltmemory): | ||
144 | + | ||
145 | + for i in range(config.TRAINING_LOOPS): | ||
146 | + minibatch = random.sample(ltmemory, min(config.BATCH_SIZE, len(ltmemory))) | ||
147 | + | ||
148 | + training_states = np.array([self.model.convertToModelInput(row['state']) for row in minibatch]) | ||
149 | + training_targets = np.array([row['AV'] for row in minibatch]) | ||
150 | + | ||
151 | + fit = self.model.fit(training_states, training_targets, epochs=config.EPOCHS, verbose=1, validation_split=0, | ||
152 | + batch_size=32) | ||
153 | + time.sleep(1.0) | ||
154 | + print('\n') | ||
155 | + | ||
156 | + def predict(self, inputToModel): | ||
157 | + preds = self.model.predict(inputToModel) | ||
158 | + return preds | ||
159 | + | ||
160 | + def buildMCTS(self, state): | ||
161 | + self.root = mc.Node(state) | ||
162 | + self.mcts = mc.MCTS(self.root, self.cpuct) | ||
163 | + | ||
164 | + def changeRootMCTS(self, state): | ||
165 | + self.mcts.root = self.mcts.tree[state.id] | ||
166 | + | ||
167 | + def reset(self): | ||
168 | + self.gameState = State([set() for i in range(20)], 0, | ||
169 | + np.ones((20, config.NUM_ACTIONS), dtype=np.int), | ||
170 | + ['screen_out', 'tit_white', 'desc_newsfeed', 'hide', 'mArticle', 'txt_date', | ||
171 | + 'main_summary', 'service_info', 'ico_corp', 'tit_black', | ||
172 | + 'summary_info', 'link_newsfeed', 'img_thumb', 'now_stock', 'kakaoWrap', 'kakaoContent', | ||
173 | + 'main_service', 'num_stock', 'main_stock', 'tit_newsfeed'], self.browser, | ||
174 | + self.similarity) | ||
175 | + return self.gameState | ||
176 | + | ||
177 | + def step(self, action): | ||
178 | + next_state, value, done = self.gameState.takeAction(action) | ||
179 | + self.gameState = next_state | ||
180 | + info = None | ||
181 | + return next_state, value, done, info |
project/browser.py
0 → 100644
1 | +import datetime | ||
2 | +import os | ||
3 | +from selenium import webdriver | ||
4 | +from selenium.common.exceptions import WebDriverException | ||
5 | +import config | ||
6 | + | ||
7 | + | ||
8 | +class Browser(object): | ||
9 | + def __init__(self): | ||
10 | + pass | ||
11 | + | ||
12 | + def capture(self): | ||
13 | + pass | ||
14 | + | ||
15 | + | ||
16 | +class Chrome(Browser): | ||
17 | + def __init__(self, html): | ||
18 | + self.html = html | ||
19 | + self.browser = self._runChrome() | ||
20 | + | ||
21 | + def _runChrome(self): | ||
22 | + options = webdriver.ChromeOptions() | ||
23 | + options.add_argument('headless') | ||
24 | + options.add_argument('no-sandbox') | ||
25 | + driver = webdriver.Chrome(options=options) | ||
26 | + driver.set_window_size(config.WIDTH, config.HEIGHT) | ||
27 | + return driver | ||
28 | + | ||
29 | + def capture(self): | ||
30 | + """ | ||
31 | + 현재 상태를 크롬으로 캡쳐한 파일 반환 | ||
32 | + :return: image_path | ||
33 | + """ | ||
34 | + date_time = datetime.datetime.now() | ||
35 | + date_time = date_time.strftime('%Y-%m-%d-%H-%M-%S') | ||
36 | + file_path = 'image/' + str(date_time) + '.png' | ||
37 | + now_path = 'file://' + os.path.join(os.getcwd(), self.html) | ||
38 | + try: | ||
39 | + self.browser.get(now_path) | ||
40 | + except WebDriverException: | ||
41 | + self.browser = self._runChrome() | ||
42 | + self.browser.get(now_path) | ||
43 | + self.browser.save_screenshot(file_path) | ||
44 | + return file_path |
project/config.py
0 → 100644
1 | +#### SELF PLAY | ||
2 | +EPISODES = 30 | ||
3 | +MCTS_SIMS = 50 | ||
4 | +MEMORY_SIZE = 30000 | ||
5 | +TURNS_UNTIL_TAU0 = 10 # turn on which it starts playing deterministically | ||
6 | +CPUCT = 1 | ||
7 | +EPSILON = 0.2 | ||
8 | +ALPHA = 0.8 | ||
9 | + | ||
10 | +INPUT_SHAPE = (20, 21) | ||
11 | + | ||
12 | +#### RETRAINING | ||
13 | +BATCH_SIZE = 256 | ||
14 | +EPOCHS = 50 | ||
15 | +REG_CONST = 0.0001 | ||
16 | +LEARNING_RATE = 0.1 | ||
17 | +MOMENTUM = 0.9 | ||
18 | +TRAINING_LOOPS = 10 | ||
19 | + | ||
20 | +HIDDEN_CNN_LAYERS = [ | ||
21 | + {'filters': 75, 'kernel_size': 4} | ||
22 | + , {'filters': 75, 'kernel_size': 4} | ||
23 | + , {'filters': 75, 'kernel_size': 4} | ||
24 | + , {'filters': 75, 'kernel_size': 4} | ||
25 | + , {'filters': 75, 'kernel_size': 4} | ||
26 | + , {'filters': 75, 'kernel_size': 4} | ||
27 | +] | ||
28 | + | ||
29 | +HIDDEN_2D_CNN_LAYERS = [ | ||
30 | + {'filters': 75, 'kernel_size': (4, 4)} | ||
31 | + , {'filters': 75, 'kernel_size': (4, 4)} | ||
32 | + , {'filters': 75, 'kernel_size': (4, 4)} | ||
33 | + , {'filters': 75, 'kernel_size': (4, 4)} | ||
34 | + , {'filters': 75, 'kernel_size': (4, 4)} | ||
35 | + , {'filters': 75, 'kernel_size': (4, 4)} | ||
36 | +] | ||
37 | + | ||
38 | +#### EVALUATION | ||
39 | +EVAL_EPISODES = 20 | ||
40 | +SCORING_THRESHOLD = 1.3 | ||
41 | + | ||
42 | +ACTION_MEANING = { | ||
43 | + 0: 'overflow: hidden; position: absolute; width: 0; height: 0; line-height: 0; text-indent: -9999px', | ||
44 | + 1: 'position: absolute; top: 15px; left: 20px; color: #fff; font-weight: normal; font-size: 14px; font-family: "NotoSans Regular", "Malgun Gothic", "맑은 고딕", "Apple SD Gothic Neo", "돋움", dotum, sans-serif', | ||
45 | + 2: 'padding-top: 8px; font-size: 14px; line-height: 22px; color: #666; display: -webkit-box; overflow: hidden; max-height: 66px; -webkit-box-orient: vertical; -webkit-line-clamp: 3', | ||
46 | + 3: 'display: none', | ||
47 | + 4: 'position: relative; width: 1160px; margin: 0 auto; padding-bottom: 180px; width: 100%', | ||
48 | + 5: 'display: block; padding-top: 28px; font-size: 12px; clear: both', | ||
49 | + 6: 'overflow: hidden; max-width: 1800px; margin: 0 auto; padding: 35px 60px 0; padding-left: 20px; padding-right: 20px', | ||
50 | + 7: 'float: left; width: 33.33%; padding: 0 20px; box-sizing: border-box', | ||
51 | + 8: 'margin: 22px 0 0 21px; width: 18px; height: 14px; margin-top: 21px; background-position: -40px 0; display: inline-block; background-position: -70px 0; width: 17px; height: 3px; margin: 20px 0 0 4px; background-position: -90px 0', | ||
52 | + 9: 'position: absolute; top: 15px; left: 20px; color: #1e1e1e; font-weight: normal; font-size: 14px; font-family: "NotoSans Regular", "Malgun Gothic", "맑은 고딕", "Apple SD Gothic Neo", "돋움", dotum, sans-serif', | ||
53 | + 10: 'float: left; width: 33.33%; padding: 0 20px; box-sizing: border-box', | ||
54 | + 11: 'display: block; position: relative', | ||
55 | + 12: 'width: 100%; margin-bottom: 14px', | ||
56 | + 13: 'float: left', | ||
57 | + 14: 'position: relative; min-width: 1320px; background-color: #fff', | ||
58 | + 15: 'padding-top: 122px; padding-top: 0', | ||
59 | + 16: 'overflow: hidden; max-width: 1800px; margin: 0 auto; padding: 76px 60px 0; padding-left: 20px; padding-right: 20px;', | ||
60 | + 17: 'font-weight: bold; font-size: 28px; line-height: 45px; padding-left: 2px; color: #003cff; font-size: 32px; font-weight: bold; color: #1e1e1e; display: inline-block; font-size: 20px; vertical-align: top', | ||
61 | + 18: 'padding-top: 33px; color: #1e1e1e', | ||
62 | + 19: 'text-decoration: underline; display: block; padding-top: 18px; font-size: 20px; line-height: 32px; color: #1e1e1e', | ||
63 | +} | ||
64 | +WIDTH = 800 | ||
65 | +HEIGHT = 600 | ||
66 | + | ||
67 | +IMAGE_SIZE = (224, 224) | ||
68 | + | ||
69 | +LAST_ACTION = len(ACTION_MEANING) | ||
70 | +NUM_ACTIONS = len(ACTION_MEANING) + 1 |
project/main.py
0 → 100644
1 | +import pickle | ||
2 | +import random | ||
3 | + | ||
4 | +from DNN.model import ResidualCNN | ||
5 | +from agent import Agent | ||
6 | +import config | ||
7 | +import os | ||
8 | + | ||
9 | + | ||
10 | +def playMatches(web, EPISODES, turns_until_tau0, memory=None): | ||
11 | + # scores = {player1.name:0, "drawn": 0, player2.name:0} | ||
12 | + # sp_scores = {'sp':0, "drawn": 0, 'nsp':0} | ||
13 | + # points = {player1.name:[], player2.name:[]} | ||
14 | + | ||
15 | + for e in range(EPISODES): | ||
16 | + | ||
17 | + state = web.reset() | ||
18 | + | ||
19 | + done = 0 | ||
20 | + web.mcts = None | ||
21 | + turn = 0 | ||
22 | + while done == 0: | ||
23 | + turn = turn + 1 | ||
24 | + | ||
25 | + #### Run the MCTS algo and return an action | ||
26 | + | ||
27 | + action, pi, value = web.act(state, 1) | ||
28 | + | ||
29 | + if memory is not None: | ||
30 | + # Commit the move to memory | ||
31 | + memory.commit_stmemory(state, pi) | ||
32 | + | ||
33 | + # Do the action | ||
34 | + state, value, done, _ = web.step(action) | ||
35 | + | ||
36 | + if done == 1: | ||
37 | + if memory is not None: | ||
38 | + # If the game is finished, assign the values correctly to the game moves | ||
39 | + for move in memory.stmemory: | ||
40 | + move['value'] = value | ||
41 | + | ||
42 | + print('done') | ||
43 | + memory.commit_ltmemory() | ||
44 | + return memory | ||
45 | + | ||
46 | + | ||
47 | +iteration = 0 | ||
48 | +cnn = ResidualCNN(config.REG_CONST, config.LEARNING_RATE, config.INPUT_SHAPE, config.NUM_ACTIONS, | ||
49 | + config.HIDDEN_CNN_LAYERS) | ||
50 | +current_player = Agent('current_player', 20, config.NUM_ACTIONS, config.MCTS_SIMS, config.CPUCT, cnn) | ||
51 | +memory = None | ||
52 | +run_folder = './run/' | ||
53 | +if not os.path.exists(run_folder): | ||
54 | + os.mkdir(run_folder) | ||
55 | + | ||
56 | +while 1: | ||
57 | + iteration += 1 | ||
58 | + print('ITERATION NUMBER ' + str(iteration)) | ||
59 | + memory = playMatches(current_player, config.EPISODES, turns_until_tau0=config.TURNS_UNTIL_TAU0, memory=memory) | ||
60 | + memory.clear_stmemory() | ||
61 | + | ||
62 | + if len(memory.ltmemory) >= config.MEMORY_SIZE: | ||
63 | + | ||
64 | + ######## RETRAINING ######## | ||
65 | + print('RETRAINING...') | ||
66 | + current_player.replay(memory.ltmemory) | ||
67 | + print('') | ||
68 | + | ||
69 | + if iteration % 5 == 0: | ||
70 | + pickle.dump(memory, open(run_folder + "memory/memory" + str(iteration).zfill(4) + ".p", "wb")) | ||
71 | + | ||
72 | + memory_samp = random.sample(memory.ltmemory, min(1000, len(memory.ltmemory))) | ||
73 | + | ||
74 | + else: | ||
75 | + print('MEMORY SIZE: ' + str(len(memory.ltmemory))) |
project/memory.py
0 → 100644
1 | +import numpy as np | ||
2 | +from collections import deque | ||
3 | + | ||
4 | +import config | ||
5 | + | ||
6 | + | ||
7 | +class Memory: | ||
8 | + def __init__(self, memory_size): | ||
9 | + self.MEMORY_SIZE = config.MEMORY_SIZE or memory_size | ||
10 | + self.ltmemory = deque(maxlen=self.MEMORY_SIZE) | ||
11 | + self.stmemory = deque(maxlen=self.MEMORY_SIZE) | ||
12 | + | ||
13 | + # 학습할 때 | ||
14 | + # 인덱스 | ||
15 | + # state_check | ||
16 | + | ||
17 | + # AV는 label | ||
18 | + | ||
19 | + def commit_stmemory(self, state, action_value): | ||
20 | + self.stmemory.append({ | ||
21 | + 'state': state.state_check | ||
22 | + , 'AV': action_value | ||
23 | + }) | ||
24 | + | ||
25 | + def commit_ltmemory(self): | ||
26 | + for i in self.stmemory: | ||
27 | + self.ltmemory.append(i) | ||
28 | + self.clear_stmemory() | ||
29 | + | ||
30 | + def clear_stmemory(self): | ||
31 | + self.stmemory = deque(maxlen=config.MEMORY_SIZE) |
project/similarity.py
0 → 100644
1 | +import keras | ||
2 | +import numpy as np | ||
3 | +from keras.applications.vgg16 import preprocess_input, VGG16 | ||
4 | +from keras.preprocessing import image | ||
5 | +from numpy.linalg import norm | ||
6 | + | ||
7 | +keras.backend.set_image_data_format('channels_first') | ||
8 | + | ||
9 | +def preprocess_image(img_path): | ||
10 | + img = image.load_img(img_path, target_size=(224, 224)) | ||
11 | + img_data = image.img_to_array(img, data_format='channels_first') | ||
12 | + img_data = np.expand_dims(img_data, axis=0) | ||
13 | + img_data = preprocess_input(img_data) | ||
14 | + return img_data | ||
15 | + | ||
16 | + | ||
17 | +class Similarity(object): | ||
18 | + def __init__(self, dest_file_path): | ||
19 | + self.model = VGG16(weights='imagenet', include_top=False) | ||
20 | + self.dest = preprocess_image(dest_file_path) | ||
21 | + | ||
22 | + def _feature(self, image): | ||
23 | + image_features = self.model.predict(image) | ||
24 | + image_features = image_features.reshape(7 * 7 * 512) | ||
25 | + return image_features | ||
26 | + | ||
27 | + def similarity(self, before_image): | ||
28 | + """ | ||
29 | + capture를 통해 얻은 이미지와 목표 이미지간 유사도를 계산 | ||
30 | + :return: 현재 이미지와 목표 이미지 유사도 | ||
31 | + """ | ||
32 | + before = self._feature(before_image) | ||
33 | + after = self._feature(self.dest) | ||
34 | + return np.dot(before, after) / (norm(before) * norm(after)) | ||
35 | +# image path | ||
36 | +# image.. |
project/state.py
0 → 100644
1 | +import base64 | ||
2 | +import config | ||
3 | +import numpy as np | ||
4 | +import copy | ||
5 | + | ||
6 | +from similarity import preprocess_image | ||
7 | + | ||
8 | + | ||
9 | +class State(object): | ||
10 | + def __init__(self, board, index, state_check, class_names, browser, sim_func): | ||
11 | + self.class_names = class_names # 클래스 이름들 | ||
12 | + self.board = board # 현재 상황 (set list) | ||
13 | + self.index = index # 현재 위치 | ||
14 | + self.id = self._convertStateToId() # state를 식별할 id | ||
15 | + self.state_check = state_check | ||
16 | + self.allowedActions = self._available_action() | ||
17 | + self.browser = browser | ||
18 | + self._make_css(state_check) | ||
19 | + self.sim_func = sim_func | ||
20 | + self.now_image = preprocess_image(self.browser.capture()) | ||
21 | + self.similarity = sim_func.similarity(self.now_image) | ||
22 | + # np.ones((len(self.board), config.NUM_ACTIONS), dtype=np.int) # 실제 확인 | ||
23 | + | ||
24 | + def _available_action(self): | ||
25 | + return np.nonzero(self.state_check[self.index])[0] | ||
26 | + | ||
27 | + def _convertStateToId(self): | ||
28 | + id = '{}'.format(self.index).join(map(lambda f: str(f) if len(f) != 0 else '{}', self.board)) | ||
29 | + b64 = base64.b64encode(id.encode('utf-8')) | ||
30 | + return str(b64) | ||
31 | + | ||
32 | + def _checkForEndGame(self): | ||
33 | + if self.index >= len(self.board) - 2: | ||
34 | + return True | ||
35 | + return False | ||
36 | + | ||
37 | + def takeAction(self, act): | ||
38 | + """ | ||
39 | + 액션대로 상태를 변경 | ||
40 | + :param act: | ||
41 | + :returns newState: | ||
42 | + value: | ||
43 | + done: | ||
44 | + """ | ||
45 | + if act >= config.NUM_ACTIONS: | ||
46 | + raise IndexError | ||
47 | + | ||
48 | + newBoard = copy.deepcopy(self.board) | ||
49 | + newStateCheck = copy.deepcopy(self.state_check) | ||
50 | + value = 0 | ||
51 | + done = 0 | ||
52 | + index = self.index | ||
53 | + if act == config.LAST_ACTION: | ||
54 | + index += 1 | ||
55 | + else: | ||
56 | + newBoard[index].add(act) | ||
57 | + newStateCheck[index][act] = 0 | ||
58 | + newState = State(newBoard, index, newStateCheck, self.class_names, self.browser, self.sim_func) | ||
59 | + if newState._checkForEndGame(): | ||
60 | + print('done') | ||
61 | + done = 1 | ||
62 | + if newState.similarity >= 0.8: | ||
63 | + done = 1 | ||
64 | + value = 1 | ||
65 | + print('good') | ||
66 | + | ||
67 | + return newState, value, done | ||
68 | + | ||
69 | + def _make_css(self, state_check): | ||
70 | + """ | ||
71 | + state_check로 css 작성 | ||
72 | + :return: file_name | ||
73 | + """ | ||
74 | + file_name = 'main.css' | ||
75 | + with open(file_name, 'w') as f: | ||
76 | + for class_index, actions in enumerate(state_check): | ||
77 | + class_index = int(class_index) | ||
78 | + f.write(".{}".format(self.class_names[class_index])) | ||
79 | + f.write("{") | ||
80 | + for action_index, action in enumerate(actions): | ||
81 | + if action == 0: | ||
82 | + f.write(" {}; ".format(config.ACTION_MEANING[action_index])) | ||
83 | + f.write("}\n") | ||
84 | + | ||
85 | + return file_name | ||
86 | + | ||
87 | + def reset(self): | ||
88 | + """ | ||
89 | + 상태 초기화 | ||
90 | + :return: | ||
91 | + """ | ||
92 | + self.index = 0 | ||
93 | + self.state_check = np.ones((len(self.board), config.NUM_ACTIONS)) |
-
Please register or login to post a comment