# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
.idea
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
import tensorflow as tf


def softmax_cross_entropy_with_logits(y_true, y_pred):
    p = y_pred
    pi = y_true
    # Mask out illegal actions: wherever the target distribution is exactly 0,
    # force the logit to -100 so softmax gives that action negligible mass.
    zero = tf.zeros(shape=tf.shape(pi), dtype=tf.float32)
    where = tf.equal(pi, zero)
    negatives = tf.fill(tf.shape(pi), -100.0)
    p = tf.where(where, negatives, p)
    loss = tf.nn.softmax_cross_entropy_with_logits(labels=pi, logits=p)
    return loss
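# A minimal sketch (toy values, not from the repo) of how the masking above
# behaves: wherever the target distribution pi is exactly 0 (an illegal
# action), the matching logit is replaced with -100 before the cross entropy
# is taken, so the illegal action contributes almost no probability mass.
if __name__ == '__main__':
    example_pi = tf.constant([[0.0, 0.7, 0.3]])  # action 0 is illegal
    example_p = tf.constant([[5.0, 1.0, 2.0]])   # raw logits from the policy head
    example_loss = softmax_cross_entropy_with_logits(example_pi, example_p)
    # Under TF2 eager execution example_loss holds its value directly; a TF1
    # runtime (which this repo appears to target) would need a session run.
    print(example_loss)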
import config
import numpy as np
from keras.models import load_model, Model
from keras.layers import Input, Dense, Conv1D, Flatten, BatchNormalization, LeakyReLU, add
from keras.optimizers import SGD
from keras import regularizers
from DNN.loss import softmax_cross_entropy_with_logits


class GeneralModel(object):
    def __init__(self, reg_const, learning_rate, input_dim, output_dim):
        self.reg_const = reg_const
        self.learning_rate = learning_rate
        self.input_dim = input_dim
        self.output_dim = output_dim

    def predict(self, x):
        return self.model.predict(x)

    def fit(self, states, targets, epochs, verbose, validation_split, batch_size):
        return self.model.fit(states, targets, epochs=epochs, verbose=verbose,
                              validation_split=validation_split, batch_size=batch_size)

    def write(self, game, version):
        self.model.save('models/version' + "{0:0>4}".format(version) + '.h5')

    def read(self, game, run_number, version):
        return load_model(str(run_number).zfill(4) + "/models/version" + "{0:0>4}".format(version) + '.h5',
                          custom_objects={'softmax_cross_entropy_with_logits': softmax_cross_entropy_with_logits})
class ResidualCNN(GeneralModel):
    def __init__(self, reg_const, learning_rate, input_dim, output_dim, hidden_layers):
        GeneralModel.__init__(self, reg_const, learning_rate, input_dim, output_dim)
        self.hidden_layers = hidden_layers
        self.num_layers = len(hidden_layers)
        self.model = self._build_model()

    def residual_layer(self, input_block, filters, kernel_size):
        x = self.conv_layer(input_block, filters, kernel_size)
        x = Conv1D(
            filters=filters
            , kernel_size=kernel_size
            , data_format="channels_last"
            , padding='same'
            , use_bias=False
            , activation='linear'
            , kernel_regularizer=regularizers.l2(self.reg_const)
        )(x)
        x = BatchNormalization(axis=1)(x)
        x = add([input_block, x])
        x = LeakyReLU()(x)
        return x

    def conv_layer(self, x, filters, kernel_size):
        x = Conv1D(
            filters=filters
            , kernel_size=kernel_size
            , data_format="channels_last"
            , padding='same'
            , use_bias=False
            , activation='linear'
            , kernel_regularizer=regularizers.l2(self.reg_const)
        )(x)
        # Note: with channels_last data the channel axis is -1; axis=1 here
        # (used throughout this file) normalizes over the sequence axis instead.
        x = BatchNormalization(axis=1)(x)
        x = LeakyReLU()(x)
        return x

    def value_head(self, x):
        x = Conv1D(
            filters=2
            , kernel_size=1
            , data_format="channels_last"
            , padding='same'
            , use_bias=False
            , activation='linear'
            , kernel_regularizer=regularizers.l2(self.reg_const)
        )(x)
        x = BatchNormalization(axis=1)(x)
        x = LeakyReLU()(x)
        x = Flatten()(x)
        x = Dense(
            self.output_dim
            , use_bias=False
            , activation='linear'
            , kernel_regularizer=regularizers.l2(self.reg_const)
            , name='value_head'
        )(x)
        return x

    def policy_head(self, x):
        x = Conv1D(
            filters=2
            , kernel_size=1
            , data_format="channels_last"
            , padding='same'
            , use_bias=False
            , activation='linear'
            , kernel_regularizer=regularizers.l2(self.reg_const)
        )(x)
        x = BatchNormalization(axis=1)(x)
        x = LeakyReLU()(x)
        x = Flatten()(x)
        x = Dense(
            self.output_dim
            , use_bias=False
            , activation='linear'
            , kernel_regularizer=regularizers.l2(self.reg_const)
            , name='policy_head'
        )(x)
        return x
    def _build_model(self):
        # The input is a plain array of shape input_dim, not an image.
        main_input = Input(shape=self.input_dim, name='main_input')
        x = self.conv_layer(main_input, self.hidden_layers[0]['filters'], self.hidden_layers[0]['kernel_size'])
        if len(self.hidden_layers) > 1:
            for h in self.hidden_layers[1:]:
                x = self.residual_layer(x, h['filters'], h['kernel_size'])
        vh = self.value_head(x)  # built but not wired into the model outputs below
        ph = self.policy_head(x)
        model = Model(inputs=[main_input], outputs=[ph])
        model.compile(loss=softmax_cross_entropy_with_logits,
                      optimizer=SGD(lr=self.learning_rate, momentum=config.MOMENTUM),
                      loss_weights=[0.5]
                      )
        return model

    def convertToModelInput(self, state):
        # state.state_check has shape (20, NUM_ACTIONS), matching input_dim
        inputToModel = state.state_check  # np.append(state.binary, [(state.playerTurn + 1)/2] * self.input_dim[1] * self.input_dim[2])
        inputToModel = np.reshape(inputToModel, self.input_dim)
        return inputToModel
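# A minimal build sketch using the values from config.py elsewhere in this
# repo: a (20, 21) input, six residual blocks, and a policy head sized to
# NUM_ACTIONS. model.summary() is standard Keras and prints the layer stack.
if __name__ == '__main__':
    net = ResidualCNN(config.REG_CONST, config.LEARNING_RATE, config.INPUT_SHAPE,
                      config.NUM_ACTIONS, config.HIDDEN_CNN_LAYERS)
    net.model.summary()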
MIT License
Copyright (c) 2019 seongahjo
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
import numpy as np
import logging
import config


class Node(object):
    def __init__(self, state):
        self.state = state
        self.id = state.id
        self.edges = []

    def is_leaf(self):
        return len(self.edges) == 0


class Edge(object):
    def __init__(self, in_node, out_node, prior, action):
        self.id = in_node.state.id + '|' + out_node.state.id
        self.inNode = in_node
        self.outNode = out_node
        self.action = action
        self.stats = {
            'N': 0,       # visit count
            'W': 0,       # total value accumulated through this edge
            'Q': 0,       # mean value, W / N (the win rate)
            'P': prior,   # prior probability from the network
        }


class MCTS(object):
    def __init__(self, root, cpuct):
        self.root = root
        self.tree = {}
        self.cpuct = cpuct
        self.add_node(root)

    def __len__(self):
        return len(self.tree)

    def moveToLeaf(self):
        breadcrumbs = []
        current_node = self.root
        done = 0
        value = 0
        while not current_node.is_leaf():
            max_qu = -99999
            if current_node == self.root:
                epsilon = config.EPSILON
                nu = np.random.dirichlet([config.ALPHA] * len(current_node.edges))
            else:
                epsilon = 0
                nu = [0] * len(current_node.edges)
            Nb = 0
            for action, edge in current_node.edges:
                Nb = Nb + edge.stats['N']
            for idx, (action, edge) in enumerate(current_node.edges):
                U = self.cpuct * \
                    ((1 - epsilon) * edge.stats['P'] + epsilon * nu[idx]) * \
                    np.sqrt(Nb) / (1 + edge.stats['N'])
                Q = edge.stats['Q']
                if Q + U > max_qu:
                    max_qu = Q + U
                    simulation_action = action
                    simulation_edge = edge
            new_state, value, done = current_node.state.takeAction(simulation_action)
            current_node = simulation_edge.outNode
            breadcrumbs.append(simulation_edge)
        return current_node, value, done, breadcrumbs

    def back_fill(self, value, breadcrumbs):
        for edge in breadcrumbs:
            edge.stats['N'] = edge.stats['N'] + 1
            edge.stats['W'] = edge.stats['W'] + value
            edge.stats['Q'] = edge.stats['W'] / edge.stats['N']

    def add_node(self, node):
        self.tree[node.id] = node
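# A worked example (toy numbers) of the selection rule used in moveToLeaf:
#   U = cpuct * ((1 - epsilon) * P + epsilon * nu) * sqrt(Nb) / (1 + N)
# where Nb is the visit count summed over all sibling edges; at each step the
# child with the largest Q + U is followed.
if __name__ == '__main__':
    cpuct, epsilon = 1.0, 0.2
    P, nu = 0.5, 0.1  # network prior and a Dirichlet noise sample
    N, Nb = 3, 10     # this edge's visits and the parent's total visits
    U = cpuct * ((1 - epsilon) * P + epsilon * nu) * np.sqrt(Nb) / (1 + N)
    Q = 0.25          # mean backed-up value W / N
    print(Q + U)      # about 0.582; the largest Q + U wins the selection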
## D.sAIgn
Web design automation from a screenshot by reinforcement learning
**Currently only training is available**
## Features
* Monte Carlo Tree Search
* DQN
## Installation
* `Python 3`
## Usage
* configure `config.py`
* `python main.py`
## Roadmap
- [ ] run D.sAIgn with a trained model
- [ ] achieve a remarkable result
- [ ] speed up training with multiprocessing
- [ ] identify the most commonly used CSS properties
- [ ] offer a pretrained model
## License
D.sAIgn is released under the [MIT License]
[MIT License]:https://github.com/seongahjo/D.sAIgn/blob/master/LICENSE
__all__ = ['config']
# %matplotlib inline
import numpy as np
import random
import config
import MCTS as mc
import time
from browser import Chrome
from similarity import Similarity
from state import State


class Agent():
    def __init__(self, name, state_size, action_size, mcts_simulations, cpuct, model=None):
        self.name = name
        self.action_size = action_size
        self.cpuct = cpuct
        self.similarity = Similarity('result.png')
        self.MCTSsimulations = mcts_simulations
        self.model = model
        self.browser = Chrome('kakao.html')
        self.mcts = None
        self.gameState = State([set() for i in range(20)],
                               0,
                               np.ones((20, config.NUM_ACTIONS), dtype=int),
                               ['screen_out', 'tit_white', 'desc_newsfeed', 'hide', 'mArticle', 'txt_date',
                                'main_summary', 'service_info', 'ico_corp', 'tit_black',
                                'summary_info', 'link_newsfeed', 'img_thumb', 'now_stock', 'kakaoWrap', 'kakaoContent',
                                'main_service', 'num_stock', 'main_stock', 'tit_newsfeed'], self.browser,
                               self.similarity)
        self.state_size = 20
        self.train_overall_loss = []
        self.train_value_loss = []
        self.train_policy_loss = []
        self.val_overall_loss = []
        self.val_value_loss = []
        self.val_policy_loss = []
    def simulate(self):
        ##### MOVE TO THE LEAF NODE
        leaf, value, done, breadcrumbs = self.mcts.moveToLeaf()
        ##### EVALUATE THE LEAF NODE
        value, breadcrumbs = self.evaluateLeaf(leaf, value, done, breadcrumbs)
        ##### BACKFILL THE VALUE THROUGH THE TREE
        self.mcts.back_fill(value, breadcrumbs)

    def act(self, state, tau):
        if self.mcts is None or state.id not in self.mcts.tree:
            self.buildMCTS(state)
        else:
            self.changeRootMCTS(state)
        #### run the simulations
        for sim in range(self.MCTSsimulations):
            self.simulate()
        #### get action values
        pi, values = self.getAV(1)
        #### pick the action
        action, value = self.chooseAction(pi, values, tau)
        print('action {} '.format(action))
        nextState, _, _ = state.takeAction(action)
        return action, pi, value
    def get_preds(self, state):
        # predict from the leaf state
        inputToModel = np.array([self.model.convertToModelInput(state)])
        preds = self.model.predict(inputToModel)
        logits_array = preds
        logits = logits_array[0]
        allowedActions = state.allowedActions
        # mask disallowed actions before the softmax
        mask = np.ones(config.NUM_ACTIONS, dtype=bool)
        mask[allowedActions] = False
        # sim = np.zeros(config.NUM_ACTIONS, dtype=np.float)
        # for act in allowedActions:
        #     new_state, _, _ = state.takeAction(act)
        #     sim[act] += new_state.similarity * 100
        # logits = sim
        # print('sim : {} '.format(logits))
        # print('probs : {} '.format(probs))
        # probs = np.ones(config.NUM_ACTIONS)
        logits[mask] = -100
        # softmax over the remaining logits
        odds = np.exp(logits)
        probs = odds / np.sum(odds)
        return probs, allowedActions
    def evaluateLeaf(self, leaf, value, done, breadcrumbs):
        if done == 0:
            probs, allowedActions = self.get_preds(leaf.state)
            probs = probs[allowedActions]
            for idx, act in enumerate(allowedActions):
                newState, _, _ = leaf.state.takeAction(act)
                if newState.id not in self.mcts.tree:
                    node = mc.Node(newState)
                    self.mcts.add_node(node)
                else:
                    node = self.mcts.tree[newState.id]
                newEdge = mc.Edge(leaf, node, probs[idx], act)
                leaf.edges.append((act, newEdge))
        return (value, breadcrumbs)

    def getAV(self, tau):
        edges = self.mcts.root.edges
        pi = np.zeros(self.action_size, dtype=np.float32)  # float, so N ** (1 / tau) is not truncated
        values = np.zeros(self.action_size, dtype=np.float32)
        for action, edge in edges:
            pi[action] = pow(edge.stats['N'], 1 / tau)
            values[action] = edge.stats['Q']
        pi = pi / (np.sum(pi) * 1.0)
        return pi, values

    def chooseAction(self, pi, values, tau):
        if tau == 0:
            actions = np.argwhere(pi == max(pi))
            action = random.choice(actions)[0]
        else:
            action_idx = np.random.multinomial(1, pi)
            action = np.where(action_idx == 1)[0][0]
        value = values[action]
        return action, value
    def replay(self, ltmemory):
        for i in range(config.TRAINING_LOOPS):
            minibatch = random.sample(ltmemory, min(config.BATCH_SIZE, len(ltmemory)))
            # ltmemory rows store the raw state_check array (see memory.py),
            # so reshape it directly rather than treating it as a State object
            training_states = np.array([np.reshape(row['state'], config.INPUT_SHAPE) for row in minibatch])
            training_targets = np.array([row['AV'] for row in minibatch])
            fit = self.model.fit(training_states, training_targets, epochs=config.EPOCHS, verbose=1,
                                 validation_split=0, batch_size=32)
            time.sleep(1.0)
        print('\n')

    def predict(self, inputToModel):
        preds = self.model.predict(inputToModel)
        return preds
    def buildMCTS(self, state):
        self.root = mc.Node(state)
        self.mcts = mc.MCTS(self.root, self.cpuct)

    def changeRootMCTS(self, state):
        self.mcts.root = self.mcts.tree[state.id]

    def reset(self):
        self.gameState = State([set() for i in range(20)], 0,
                               np.ones((20, config.NUM_ACTIONS), dtype=int),
                               ['screen_out', 'tit_white', 'desc_newsfeed', 'hide', 'mArticle', 'txt_date',
                                'main_summary', 'service_info', 'ico_corp', 'tit_black',
                                'summary_info', 'link_newsfeed', 'img_thumb', 'now_stock', 'kakaoWrap', 'kakaoContent',
                                'main_service', 'num_stock', 'main_stock', 'tit_newsfeed'], self.browser,
                               self.similarity)
        return self.gameState

    def step(self, action):
        next_state, value, done = self.gameState.takeAction(action)
        self.gameState = next_state
        info = None
        return next_state, value, done, info
import datetime
import os
from selenium import webdriver
from selenium.common.exceptions import WebDriverException
import config


class Browser(object):
    def __init__(self):
        pass

    def capture(self):
        pass


class Chrome(Browser):
    def __init__(self, html):
        self.html = html
        self.browser = self._runChrome()

    def _runChrome(self):
        options = webdriver.ChromeOptions()
        options.add_argument('headless')
        options.add_argument('no-sandbox')
        driver = webdriver.Chrome(options=options)
        driver.set_window_size(config.WIDTH, config.HEIGHT)
        return driver

    def capture(self):
        """
        Capture the current page with headless Chrome and return the file path.
        :return: image_path
        """
        date_time = datetime.datetime.now()
        date_time = date_time.strftime('%Y-%m-%d-%H-%M-%S')
        file_path = 'image/' + str(date_time) + '.png'
        now_path = 'file://' + os.path.join(os.getcwd(), self.html)
        try:
            self.browser.get(now_path)
        except WebDriverException:
            # restart the driver if it has died, then retry once
            self.browser = self._runChrome()
            self.browser.get(now_path)
        self.browser.save_screenshot(file_path)
        return file_path
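# A minimal usage sketch; assumes chromedriver is on PATH, kakao.html exists
# in the working directory, and an 'image/' folder is present for
# save_screenshot to write into. capture() reloads the page on every call, so
# edits made to main.css between calls show up in the next screenshot.
if __name__ == '__main__':
    chrome = Chrome('kakao.html')
    screenshot_path = chrome.capture()  # e.g. 'image/2019-01-01-12-00-00.png'
    print(screenshot_path)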
#### SELF PLAY
EPISODES = 30
MCTS_SIMS = 50
MEMORY_SIZE = 30000
TURNS_UNTIL_TAU0 = 10 # turn on which it starts playing deterministically
CPUCT = 1
EPSILON = 0.2
ALPHA = 0.8
INPUT_SHAPE = (20, 21)
#### RETRAINING
BATCH_SIZE = 256
EPOCHS = 50
REG_CONST = 0.0001
LEARNING_RATE = 0.1
MOMENTUM = 0.9
TRAINING_LOOPS = 10
HIDDEN_CNN_LAYERS = [
{'filters': 75, 'kernel_size': 4}
, {'filters': 75, 'kernel_size': 4}
, {'filters': 75, 'kernel_size': 4}
, {'filters': 75, 'kernel_size': 4}
, {'filters': 75, 'kernel_size': 4}
, {'filters': 75, 'kernel_size': 4}
]
HIDDEN_2D_CNN_LAYERS = [
{'filters': 75, 'kernel_size': (4, 4)}
, {'filters': 75, 'kernel_size': (4, 4)}
, {'filters': 75, 'kernel_size': (4, 4)}
, {'filters': 75, 'kernel_size': (4, 4)}
, {'filters': 75, 'kernel_size': (4, 4)}
, {'filters': 75, 'kernel_size': (4, 4)}
]
#### EVALUATION
EVAL_EPISODES = 20
SCORING_THRESHOLD = 1.3
ACTION_MEANING = {
0: 'overflow: hidden; position: absolute; width: 0; height: 0; line-height: 0; text-indent: -9999px',
1: 'position: absolute; top: 15px; left: 20px; color: #fff; font-weight: normal; font-size: 14px; font-family: "NotoSans Regular", "Malgun Gothic", "맑은 고딕", "Apple SD Gothic Neo", "돋움", dotum, sans-serif',
2: 'padding-top: 8px; font-size: 14px; line-height: 22px; color: #666; display: -webkit-box; overflow: hidden; max-height: 66px; -webkit-box-orient: vertical; -webkit-line-clamp: 3',
3: 'display: none',
4: 'position: relative; width: 1160px; margin: 0 auto; padding-bottom: 180px; width: 100%',
5: 'display: block; padding-top: 28px; font-size: 12px; clear: both',
6: 'overflow: hidden; max-width: 1800px; margin: 0 auto; padding: 35px 60px 0; padding-left: 20px; padding-right: 20px',
7: 'float: left; width: 33.33%; padding: 0 20px; box-sizing: border-box',
8: 'margin: 22px 0 0 21px; width: 18px; height: 14px; margin-top: 21px; background-position: -40px 0; display: inline-block; background-position: -70px 0; width: 17px; height: 3px; margin: 20px 0 0 4px; background-position: -90px 0',
9: 'position: absolute; top: 15px; left: 20px; color: #1e1e1e; font-weight: normal; font-size: 14px; font-family: "NotoSans Regular", "Malgun Gothic", "맑은 고딕", "Apple SD Gothic Neo", "돋움", dotum, sans-serif',
10: 'float: left; width: 33.33%; padding: 0 20px; box-sizing: border-box',
11: 'display: block; position: relative',
12: 'width: 100%; margin-bottom: 14px',
13: 'float: left',
14: 'position: relative; min-width: 1320px; background-color: #fff',
15: 'padding-top: 122px; padding-top: 0',
16: 'overflow: hidden; max-width: 1800px; margin: 0 auto; padding: 76px 60px 0; padding-left: 20px; padding-right: 20px;',
17: 'font-weight: bold; font-size: 28px; line-height: 45px; padding-left: 2px; color: #003cff; font-size: 32px; font-weight: bold; color: #1e1e1e; display: inline-block; font-size: 20px; vertical-align: top',
18: 'padding-top: 33px; color: #1e1e1e',
19: 'text-decoration: underline; display: block; padding-top: 18px; font-size: 20px; line-height: 32px; color: #1e1e1e',
}
WIDTH = 800
HEIGHT = 600
IMAGE_SIZE = (224, 224)
LAST_ACTION = len(ACTION_MEANING)
NUM_ACTIONS = len(ACTION_MEANING) + 1
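# A quick check of how the action space is laid out: indices 0..19 apply the
# matching CSS rule from ACTION_MEANING, while the extra index LAST_ACTION
# means "move on to the next class" (see state.State.takeAction).
if __name__ == '__main__':
    print(LAST_ACTION, NUM_ACTIONS)  # -> 20 21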
import pickle
import random
from DNN.model import ResidualCNN
from agent import Agent
from memory import Memory  # Memory is assumed to live in memory.py alongside the other modules
import config
import os


def playMatches(web, EPISODES, turns_until_tau0, memory=None):
    # scores = {player1.name: 0, "drawn": 0, player2.name: 0}
    # sp_scores = {'sp': 0, "drawn": 0, 'nsp': 0}
    # points = {player1.name: [], player2.name: []}
    for e in range(EPISODES):
        state = web.reset()
        done = 0
        web.mcts = None
        turn = 0
        while done == 0:
            turn = turn + 1
            #### Run the MCTS algo and return an action
            action, pi, value = web.act(state, 1)
            if memory is not None:
                # Commit the move to memory
                memory.commit_stmemory(state, pi)
            # Do the action
            state, value, done, _ = web.step(action)
            if done == 1:
                if memory is not None:
                    # If the game is finished, assign the values correctly to the game moves
                    for move in memory.stmemory:
                        move['value'] = value
                    print('done')
                    memory.commit_ltmemory()
    return memory


iteration = 0
cnn = ResidualCNN(config.REG_CONST, config.LEARNING_RATE, config.INPUT_SHAPE, config.NUM_ACTIONS,
                  config.HIDDEN_CNN_LAYERS)
current_player = Agent('current_player', 20, config.NUM_ACTIONS, config.MCTS_SIMS, config.CPUCT, cnn)
memory = Memory(config.MEMORY_SIZE)  # was None, which crashed on clear_stmemory below
run_folder = './run/'
if not os.path.exists(run_folder + 'memory/'):
    os.makedirs(run_folder + 'memory/')  # pickle.dump below writes into run/memory/
while 1:
    iteration += 1
    print('ITERATION NUMBER ' + str(iteration))
    memory = playMatches(current_player, config.EPISODES, turns_until_tau0=config.TURNS_UNTIL_TAU0, memory=memory)
    memory.clear_stmemory()
    if len(memory.ltmemory) >= config.MEMORY_SIZE:
        ######## RETRAINING ########
        print('RETRAINING...')
        current_player.replay(memory.ltmemory)
        print('')
        if iteration % 5 == 0:
            pickle.dump(memory, open(run_folder + "memory/memory" + str(iteration).zfill(4) + ".p", "wb"))
            memory_samp = random.sample(memory.ltmemory, min(1000, len(memory.ltmemory)))
    else:
        print('MEMORY SIZE: ' + str(len(memory.ltmemory)))
import numpy as np
from collections import deque
import config


class Memory:
    def __init__(self, memory_size):
        self.MEMORY_SIZE = config.MEMORY_SIZE or memory_size
        self.ltmemory = deque(maxlen=self.MEMORY_SIZE)
        self.stmemory = deque(maxlen=self.MEMORY_SIZE)

    # Used at training time: 'state' stores the state_check array (the model
    # input) and 'AV' is the label (the MCTS visit distribution pi).
    def commit_stmemory(self, state, action_value):
        self.stmemory.append({
            'state': state.state_check
            , 'AV': action_value
        })

    def commit_ltmemory(self):
        for i in self.stmemory:
            self.ltmemory.append(i)
        self.clear_stmemory()

    def clear_stmemory(self):
        self.stmemory = deque(maxlen=config.MEMORY_SIZE)
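# A minimal usage sketch (toy data): stmemory buffers one game's
# (state_check, pi) pairs, and commit_ltmemory flushes them into the
# long-term replay buffer. _FakeState is a hypothetical stand-in for
# state.State, which is what commit_stmemory normally receives.
if __name__ == '__main__':
    class _FakeState(object):
        state_check = np.ones((20, config.NUM_ACTIONS))

    mem = Memory(config.MEMORY_SIZE)
    mem.commit_stmemory(_FakeState(), np.zeros(config.NUM_ACTIONS))
    mem.commit_ltmemory()
    print(len(mem.ltmemory), len(mem.stmemory))  # -> 1 0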
import keras
import numpy as np
from keras.applications.vgg16 import preprocess_input, VGG16
from keras.preprocessing import image
from numpy.linalg import norm

keras.backend.set_image_data_format('channels_first')


def preprocess_image(img_path):
    img = image.load_img(img_path, target_size=(224, 224))
    img_data = image.img_to_array(img, data_format='channels_first')
    img_data = np.expand_dims(img_data, axis=0)
    img_data = preprocess_input(img_data)
    return img_data


class Similarity(object):
    def __init__(self, dest_file_path):
        self.model = VGG16(weights='imagenet', include_top=False)
        self.dest = preprocess_image(dest_file_path)

    def _feature(self, img):  # renamed from 'image' to avoid shadowing the keras module
        image_features = self.model.predict(img)
        image_features = image_features.reshape(7 * 7 * 512)
        return image_features

    def similarity(self, before_image):
        """
        Compute the similarity between an image obtained via capture and the target image.
        :return: cosine similarity between the current and target images
        """
        before = self._feature(before_image)
        after = self._feature(self.dest)
        return np.dot(before, after) / (norm(before) * norm(after))
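# A minimal usage sketch (file names are illustrative; VGG16 downloads its
# ImageNet weights on first use): compare a captured screenshot against the
# target image by cosine similarity of conv features. The score approaches
# 1.0 as the rendered page approaches the target.
if __name__ == '__main__':
    sim = Similarity('result.png')                   # target screenshot
    current = preprocess_image('image/example.png')  # hypothetical capture
    print(sim.similarity(current))                   # cosine similarity score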
import base64
import config
import numpy as np
import copy
from similarity import preprocess_image


class State(object):
    def __init__(self, board, index, state_check, class_names, browser, sim_func):
        self.class_names = class_names  # CSS class names
        self.board = board  # current configuration (a list of sets of applied actions)
        self.index = index  # current position (which class is being styled)
        self.id = self._convertStateToId()  # id identifying this state
        self.state_check = state_check
        self.allowedActions = self._available_action()
        self.browser = browser
        self._make_css(state_check)
        self.sim_func = sim_func
        self.now_image = preprocess_image(self.browser.capture())
        self.similarity = sim_func.similarity(self.now_image)
        # state_check starts as np.ones((len(self.board), config.NUM_ACTIONS), dtype=int)

    def _available_action(self):
        return np.nonzero(self.state_check[self.index])[0]

    def _convertStateToId(self):
        id = '{}'.format(self.index).join(map(lambda f: str(f) if len(f) != 0 else '{}', self.board))
        b64 = base64.b64encode(id.encode('utf-8'))
        return str(b64)

    def _checkForEndGame(self):
        if self.index >= len(self.board) - 2:
            return True
        return False

    def takeAction(self, act):
        """
        Apply the action and return the resulting state.
        :param act: action index
        :returns newState: the state after the action
                 value: reward (1 once the similarity threshold is reached)
                 done: 1 if the episode has ended, else 0
        """
        if act >= config.NUM_ACTIONS:
            raise IndexError
        newBoard = copy.deepcopy(self.board)
        newStateCheck = copy.deepcopy(self.state_check)
        value = 0
        done = 0
        index = self.index
        if act == config.LAST_ACTION:
            index += 1  # move on to the next class
        else:
            newBoard[index].add(act)
            newStateCheck[index][act] = 0
        newState = State(newBoard, index, newStateCheck, self.class_names, self.browser, self.sim_func)
        if newState._checkForEndGame():
            print('done')
            done = 1
        if newState.similarity >= 0.8:
            done = 1
            value = 1
            print('good')
        return newState, value, done

    def _make_css(self, state_check):
        """
        Write main.css from state_check.
        :return: file_name
        """
        file_name = 'main.css'
        with open(file_name, 'w') as f:
            for class_index, actions in enumerate(state_check):
                class_index = int(class_index)
                f.write(".{}".format(self.class_names[class_index]))
                f.write("{")
                for action_index, action in enumerate(actions):
                    if action == 0:
                        f.write(" {}; ".format(config.ACTION_MEANING[action_index]))
                f.write("}\n")
        return file_name

    def reset(self):
        """
        Reset the state.
        :return:
        """
        self.index = 0
        self.state_check = np.ones((len(self.board), config.NUM_ACTIONS))
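# A worked sketch (toy data) of the CSS _make_css emits: applied actions are
# the zero entries of state_check, and each zero maps to the matching rule in
# config.ACTION_MEANING. The class name below is illustrative; the loop
# mirrors the one in _make_css.
if __name__ == '__main__':
    toy_check = np.ones((1, config.NUM_ACTIONS))
    toy_check[0][3] = 0  # pretend action 3 ('display: none') was applied
    css = '.hide{'
    for action_index, flag in enumerate(toy_check[0]):
        if flag == 0:
            css += ' {}; '.format(config.ACTION_MEANING[action_index])
    css += '}'
    print(css)  # -> .hide{ display: none; }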