# agent.py
# %matplotlib inline

import numpy as np
import random
import config
import MCTS as mc

import time

from browser import Chrome
from similarity import Similarity
from state import State


class Agent():
    """MCTS-driven agent that edits a web page (rendered in a Chrome
    browser) so that its screenshot maximises visual similarity with a
    target image (``result.png``).

    The design mirrors the classic AlphaZero agent: a policy/value
    ``model`` guides a Monte-Carlo tree search; ``act`` runs the search
    and picks an action, ``replay`` retrains the model from long-term
    memory, ``reset``/``step`` expose a gym-like environment interface.
    """

    # CSS class names forming the initial page state.  The per-element
    # action masks are indexed by position, so the order is significant.
    INITIAL_ELEMENTS = ['screen_out', 'tit_white', 'desc_newsfeed', 'hide',
                        'mArticle', 'txt_date', 'main_summary', 'service_info',
                        'ico_corp', 'tit_black', 'summary_info', 'link_newsfeed',
                        'img_thumb', 'now_stock', 'kakaoWrap', 'kakaoContent',
                        'main_service', 'num_stock', 'main_stock', 'tit_newsfeed']

    def __init__(self, name, state_size, action_size, mcts_simulations, cpuct, model=None):
        """Create an agent and its browser-backed initial game state.

        Args:
            name: identifier for this agent.
            state_size: accepted for interface compatibility but currently
                ignored — the state is hard-coded to the 20 elements of
                ``INITIAL_ELEMENTS`` (see NOTE below).
            action_size: number of discrete actions (length of pi vectors).
            mcts_simulations: simulations to run per ``act`` call.
            cpuct: exploration constant for the tree search.
            model: network exposing ``convertToModelInput`` / ``predict`` /
                ``fit``; may be None when only environment methods are used.
        """
        self.name = name
        self.action_size = action_size

        self.cpuct = cpuct
        self.similarity = Similarity('result.png')
        self.MCTSsimulations = mcts_simulations
        self.model = model
        self.browser = Chrome('kakao.html')
        self.mcts = None
        self.gameState = self._initial_state()

        # NOTE(review): kept hard-coded to 20 (matching INITIAL_ELEMENTS)
        # instead of the state_size argument to preserve existing
        # behaviour — confirm whether the argument was meant to be used.
        self.state_size = 20
        self.train_overall_loss = []
        self.train_value_loss = []
        self.train_policy_loss = []
        self.val_overall_loss = []
        self.val_value_loss = []
        self.val_policy_loss = []

    def _initial_state(self):
        """Build a fresh starting State: empty removed-class sets and
        all-ones action masks for every element.

        Shared by ``__init__`` and ``reset`` (previously duplicated).
        """
        n = len(self.INITIAL_ELEMENTS)
        return State([set() for _ in range(n)],
                     0,
                     # np.int was removed in NumPy 1.24; builtin int is the
                     # exact equivalent of the old alias.
                     np.ones((n, config.NUM_ACTIONS), dtype=int),
                     list(self.INITIAL_ELEMENTS),
                     self.browser,
                     self.similarity)

    def simulate(self):
        """Run one MCTS simulation: select a leaf, evaluate it with the
        model, and back-propagate the value through the tree."""
        # MOVE TO THE LEAF NODE
        leaf, value, done, breadcrumbs = self.mcts.moveToLeaf()
        # EVALUATE THE LEAF NODE
        value, breadcrumbs = self.evaluateLeaf(leaf, value, done, breadcrumbs)
        # BACKFILL THE VALUE THROUGH THE TREE
        self.mcts.back_fill(value, breadcrumbs)

    def act(self, state, tau):
        """Choose an action for ``state``.

        Builds (or re-roots) the search tree, runs ``MCTSsimulations``
        simulations, then samples an action from the visit distribution.

        Args:
            state: current State; must expose ``id`` and ``takeAction``.
            tau: temperature — 0 picks greedily, otherwise samples.

        Returns:
            (action, pi, value): chosen action index, the full visit
            distribution, and the Q-value of the chosen edge.
        """
        if self.mcts is None or state.id not in self.mcts.tree:
            self.buildMCTS(state)
        else:
            self.changeRootMCTS(state)

        # run the simulations
        for _ in range(self.MCTSsimulations):
            self.simulate()

        # get action values
        pi, values = self.getAV(1)

        # pick the action
        action, value = self.chooseAction(pi, values, tau)
        print('action {} '.format(action))
        # NOTE(review): return value discarded; the call is kept because
        # State.takeAction may have side effects (browser) — confirm.
        state.takeAction(action)

        return action, pi, value

    def get_preds(self, state):
        """Return (probs, allowedActions) for ``state``.

        Runs the model, masks logits of disallowed actions to -100, and
        softmaxes the result so disallowed actions get ~zero probability.
        """
        inputToModel = np.array([self.model.convertToModelInput(state)])
        preds = self.model.predict(inputToModel)
        logits = preds[0]
        allowedActions = state.allowedActions

        # mask out every action not allowed in this state
        mask = np.ones(config.NUM_ACTIONS, dtype=bool)
        mask[allowedActions] = False
        logits[mask] = -100

        # Numerically stable softmax: subtracting the max leaves the
        # probabilities unchanged but prevents overflow in exp.
        odds = np.exp(logits - np.max(logits))
        probs = odds / np.sum(odds)

        return probs, allowedActions

    def evaluateLeaf(self, leaf, value, done, breadcrumbs):
        """Expand a non-terminal leaf: create an edge (with its prior
        probability from the model) for every allowed action.

        Returns (value, breadcrumbs) unchanged for back-propagation.
        """
        if done == 0:

            probs, allowedActions = self.get_preds(leaf.state)
            probs = probs[allowedActions]

            for idx, act in enumerate(allowedActions):
                newState, _, _ = leaf.state.takeAction(act)
                # reuse an existing node for transpositions, else add one
                if newState.id not in self.mcts.tree:
                    node = mc.Node(newState)
                    self.mcts.add_node(node)
                else:
                    node = self.mcts.tree[newState.id]

                newEdge = mc.Edge(leaf, node, probs[idx], act)
                leaf.edges.append((act, newEdge))

        return (value, breadcrumbs)

    def getAV(self, tau):
        """Read (pi, values) off the root's edges.

        pi is the visit-count distribution N^(1/tau), normalised to sum
        to 1; values holds each edge's mean action value Q.
        """
        edges = self.mcts.root.edges
        # Bug fix: pi must be floating point — the previous np.integer
        # buffer truncated the fractional weights pow(N, 1/tau) whenever
        # tau != 1, and np.integer is an abstract dtype rejected by
        # modern NumPy.
        pi = np.zeros(self.action_size, dtype=np.float64)
        values = np.zeros(self.action_size, dtype=np.float32)

        for action, edge in edges:
            pi[action] = pow(edge.stats['N'], 1 / tau)
            values[action] = edge.stats['Q']

        pi = pi / np.sum(pi)
        return pi, values

    def chooseAction(self, pi, values, tau):
        """Pick an action from distribution ``pi``.

        tau == 0: greedy (ties broken uniformly at random);
        otherwise: sample from the multinomial defined by pi.

        Returns (action, value) where value is the action's Q.
        """
        if tau == 0:
            actions = np.argwhere(pi == max(pi))
            action = random.choice(actions)[0]
        else:
            action_idx = np.random.multinomial(1, pi)
            action = np.where(action_idx == 1)[0][0]
        value = values[action]
        return action, value

    def replay(self, ltmemory):
        """Retrain the model on minibatches sampled from long-term memory.

        Runs config.TRAINING_LOOPS fitting rounds of up to
        config.BATCH_SIZE samples each; each row must carry the keys
        'state' and 'AV' (the search-derived action-value target).
        """
        for _ in range(config.TRAINING_LOOPS):
            minibatch = random.sample(ltmemory, min(config.BATCH_SIZE, len(ltmemory)))

            training_states = np.array([self.model.convertToModelInput(row['state'])
                                        for row in minibatch])
            training_targets = np.array([row['AV'] for row in minibatch])

            self.model.fit(training_states, training_targets,
                           epochs=config.EPOCHS, verbose=1,
                           validation_split=0, batch_size=32)
        time.sleep(1.0)
        print('\n')

    def predict(self, inputToModel):
        """Forward ``inputToModel`` through the model and return its raw
        predictions."""
        preds = self.model.predict(inputToModel)
        return preds

    def buildMCTS(self, state):
        """Start a fresh search tree rooted at ``state``."""
        self.root = mc.Node(state)
        self.mcts = mc.MCTS(self.root, self.cpuct)

    def changeRootMCTS(self, state):
        """Re-root the existing tree at ``state`` (which must already be
        in the tree — callers check via ``state.id in self.mcts.tree``)."""
        self.mcts.root = self.mcts.tree[state.id]

    def reset(self):
        """Reset the environment to the initial State and return it."""
        self.gameState = self._initial_state()
        return self.gameState

    def step(self, action):
        """Gym-style step: apply ``action`` to the current game state.

        Returns (next_state, value, done, info) with info always None.
        """
        next_state, value, done = self.gameState.takeAction(action)
        self.gameState = next_state
        info = None
        return next_state, value, done, info