Building your own Alpha Zero Part 2: Decisions

“I’ve looked forward in time to see all of the possible outcomes of a coming conflict.” - Doctor Stephen Vincent Strange, a superhero from the Marvel Cinematic Universe.

If only we could predict all of the outcomes of our actions ahead of time! Unfortunately, we can’t, and neither can machines unless the task they are trying to solve is really simple. However, machines can still predict future states of a game a lot faster than any human. This ability enables AI to outperform humans in checkers, chess, go, and various other games. Alpha Zero, too, harnesses the power of prediction, using an algorithm called Monte Carlo Tree Search (MCTS).

This algorithm lets an AI look into the most promising possible futures instead of trying to imagine every possible situation. Alpha Zero combines MCTS with a deep neural network that evaluates game states, which is how it determines which futures are the most promising. Given enough training data, this approach lets Alpha Zero efficiently find the best path to victory from any game state.

The best way to learn how MCTS works is to implement it. But before we start our implementation, let’s do some basic coding first.

Games and Agent

We are going to define two simple interfaces in Python: a Game and an Agent. The Game interface lets us create a game session, get observations of the current state, take actions, and get the result of the session. The Agent interface lets us predict a policy and a value from the current game state.

from abc import ABC, abstractmethod


class Game(ABC):
    @abstractmethod
    def get_cur_player(self):
        """
        Returns:
            int: current player idx
        """
        pass

    @abstractmethod
    def get_action_size(self):
        """
        Returns:
            int: number of all possible actions
        """
        pass

    @abstractmethod
    def get_valid_actions(self, player):
        """
        Input:
            player: player

        Returns:
            validActions: a binary vector of length self.get_action_size(), 1 for
                        moves that are valid from the current board and player,
                        0 for invalid moves
        """
        pass

    @abstractmethod
    def take_action(self, action):
        """
        Input:
            action: action taken by the current player

        Returns:
            double: score of current player on the current turn
            int: player who plays in the next turn
        """
        pass

    @abstractmethod
    def get_observation_size(self):
        """
        Returns:
            (x,y,z): a tuple of observation dimensions
        """
        pass

    @abstractmethod
    def get_observation(self, player):
        """
        Input:
            player: current player

        Returns:
            observation matrix which will serve as an input to agent.predict
        """
        pass

    @abstractmethod
    def get_observation_str(self, observation):
        """
        Input:
            observation: observation

        Returns:
            string: a quick conversion of state to a string format.
                    Required by MCTS for hashing.
        """
        pass

    @abstractmethod
    def is_ended(self):
        """
        This method must return True if is_draw returns True
        Returns:
            boolean: False if game has not ended. True otherwise
        """
        pass

    @abstractmethod
    def is_draw(self):
        """
        Returns:
            boolean: True if game ended in a draw, False otherwise
        """
        pass

    @abstractmethod
    def get_score(self, player):
        """
        Input:
            player: current player

        Returns:
            double: reward in [-1, 1] for player if game has ended
        """
        pass

    @abstractmethod
    def clone(self):
        """
        Returns:
            Game: a deep clone of current Game object
        """
        pass


class Agent(ABC):

    @abstractmethod
    def predict(self, game, game_player):
        """
        Returns:
            policy, value: stochastic policy and a continuous value of a game observation
        """
        pass

Take a minute or two to get familiar with these interfaces. When you are ready, we’ll dive right into the implementation of MCTS!
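
Before moving on, here is a minimal illustration (not part of the original code, just a hypothetical example) of what implementing the Agent interface can look like: a random agent that spreads its policy uniformly over the valid actions and reports a neutral value. Any object with a predict method like this can later be plugged into the MCTS-based agent below.

import numpy as np

class AgentRandom(Agent):
    def predict(self, game, game_player):
        # uniform policy over the valid actions, neutral value estimate
        valid_actions = np.array(game.get_valid_actions(game_player), dtype=float)
        policy = valid_actions / np.sum(valid_actions)
        return policy, 0.0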

MCTS-based Agent

Now we are going to build an MCTS-based Agent. I know it’s weird to implement an algorithm before explaining it, but trust me on this one; it’s not as hard as it looks. Reading the implementation will give you a basic idea of how the algorithm works, and it will all make sense once I explain it more formally.

import math

import numpy as np

EPS = 1e-8  # small constant so the exploration term is non-zero for freshly expanded states


class AgentMCTS():

    NO_EXPLORATION = 0
    EXPLORATION_TEMP = 1

    def __init__(self, agent, exp_rate=EXPLORATION_TEMP, cpuct=1, num_simulations=10):
        self.agent = agent
        self.cpuct = cpuct
        self.exp_rate = exp_rate
        self.num_simulations = num_simulations
        self.Qsa = {}  # stores Q values for s,a
        self.Nsa = {}  # stores times edge s,a was visited
        self.Ns = {}  # stores times observation s was visited
        self.Ps = {}  # stores initial policy (returned by neural net)

    def predict(self, game, game_player):
        """
        Returns:
            policy, value: stochastic policy and a continuous value of a game observation
        """
        observation = game.get_observation(game_player)
        observation_str = game.get_observation_str(observation)

        for i in range(self.num_simulations):
            game_clone = game.clone()
            _, value = self.search(game_clone, game_player)

        counts = [self.Nsa[(observation_str, a)] if (observation_str, a) in self.Nsa else 0 for a in
                  range(game.get_action_size())]

        if self.exp_rate == AgentMCTS.NO_EXPLORATION:
            bestA = np.argmax(counts)
            policy = [0] * len(counts)
            policy[bestA] = 1
        else:
            counts = [x ** (1. / self.exp_rate) for x in counts]
            policy = [x / float(sum(counts)) for x in counts]

        return policy, value

    def search(self, game, game_player, first_search=False):
        """
        Returns:
            is_draw, value: a draw flag and a continuous value of a game observation
        """

        # check for a terminal state
        if game.is_ended():
            if game.is_draw():
                return True, -1
            return False, game.get_score(game_player)

        observation = game.get_observation(game_player)
        observation_str = game.get_observation_str(observation)
        valid_actions = game.get_valid_actions(game_player)

        # check if this observation is a previously unknown state
        if (observation_str not in self.Ps):
            # get an initial estimation of a policy and a value with a neural network-based agent
            self.Ps[observation_str], value = self.agent.predict(game, game_player)
            self.Ps[observation_str] = self.Ps[observation_str] * valid_actions  # masking invalid moves
            self.Ps[observation_str] /= np.sum(self.Ps[observation_str])  # renormalize
            self.Ns[observation_str] = 0
            return False, value

        cur_best = -float('inf')
        best_act = -1

        # pick the action with the highest upper confidence bound
        for action in range(game.get_action_size()):
            if valid_actions[action]:
                if (observation_str, action) in self.Qsa:
                    u = self.Qsa[(observation_str, action)] \
                        + self.cpuct * self.Ps[observation_str][action] * math.sqrt(self.Ns[observation_str]) / \
                        (1 + self.Nsa[(observation_str, action)])
                else:
                    u = self.cpuct * self.Ps[observation_str][action] * math.sqrt(self.Ns[observation_str] + EPS)

                if u > cur_best:
                    cur_best = u
                    best_act = action

        # take the best action and keep searching down the tree
        _, next_player = game.take_action(best_act)
        draw_result, value = self.search(game, next_player)

        # update values after search is done
        if game_player != next_player and not draw_result:
            value = -value

        # update the running average Q value and the visit counts for the chosen edge
        if (observation_str, best_act) in self.Qsa:
            self.Qsa[(observation_str, best_act)] = (self.Nsa[(observation_str, best_act)] * self.Qsa[
                (observation_str, best_act)] + value) / (self.Nsa[(observation_str, best_act)] + 1)
            self.Nsa[(observation_str, best_act)] += 1
        else:
            self.Qsa[(observation_str, best_act)] = value
            self.Nsa[(observation_str, best_act)] = 1

        self.Ns[observation_str] += 1
        return False, value

Woah! That was a lot of code to process. It’s about time to explain what is going on here.

Monte Carlo Tree Search

MCTS is a policy search algorithm that balances exploration with exploitation to output an improved policy after a number of simulations of the game. MCTS builds a tree where nodes are different observations, and a directed edge exists between two nodes if a valid action can cause the state to transition from one node to another.

For each edge, we maintain a Q value, denoted Q(s, a), which is the expected reward for taking action a from state s, and a visit count N(s, a), which is the number of times we took action a from state s across simulations.

The hyperparameters of MCTS are:

  • number of simulations is the parameter that controls how many previously unexplored nodes we want to visit every time we call the agent’s predict function;
  • cpuct is the parameter controlling the degree of exploration;
  • exploration rate is the parameter controlling the distribution of the final policy. Setting it to a high value gives an almost uniform distribution, while setting it to 0 makes us always select the best action (see the short sketch after this list).
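
To get an intuition for the exploration rate, here is a small self-contained sketch (using made-up visit counts) of the same counts-to-policy conversion that AgentMCTS.predict performs at the end of its simulations:

    import numpy as np

    def counts_to_policy(counts, exp_rate):
        # mirrors the final step of AgentMCTS.predict
        if exp_rate == 0:  # NO_EXPLORATION: one-hot on the most visited action
            policy = [0.0] * len(counts)
            policy[int(np.argmax(counts))] = 1.0
            return policy
        counts = [c ** (1.0 / exp_rate) for c in counts]
        return [c / float(sum(counts)) for c in counts]

    counts = [10, 30, 60]                # hypothetical visit counts for three actions
    print(counts_to_policy(counts, 0))   # [0.0, 0.0, 1.0] - always the best action
    print(counts_to_policy(counts, 1))   # [0.1, 0.3, 0.6] - proportional to visit counts
    print(counts_to_policy(counts, 10))  # roughly [0.30, 0.34, 0.36] - close to uniform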

The process of exploration is iterative. It starts with getting an observation from the game and hashing it with the get_observation_str method:

    observation = game.get_observation(game_player)
    observation_str = game.get_observation_str(observation)

We then run a number of simulations, each time exploring our tree with the search method until we find a previously unexplored or terminal node. Note that we must pass a copy of the current game state to the search method, as search modifies the game state during exploration.

    for i in range(self.num_simulations):
        game_clone = game.clone()
        _, value = self.search(game_clone, game_player)

So, the search begins…

If we reach a terminal node, we just need to propagate a value for the current player:

    if game.is_ended():
        if game.is_draw():
            return True, -1
        return False, game.get_score(game_player)

If we reach a previously unexplored node, we get P(s), the prior probabilities of taking each action from state s according to the policy returned by our neural network. Note that the neural network agent uses the same Agent interface we defined earlier.

    observation = game.get_observation(game_player)
    observation_str = game.get_observation_str(observation)
    valid_actions = game.get_valid_actions(game_player)

    # check if this observation is a previously unknown state
    if (observation_str not in self.Ps):
        # get an initial estimation of a policy and a value with a neural network-based agent
        self.Ps[observation_str], value = self.agent.predict(game, game_player)
        self.Ps[observation_str] = self.Ps[observation_str] * valid_actions  # masking invalid moves
        self.Ps[observation_str] /= np.sum(self.Ps[observation_str])  # renormalize
        self.Ns[observation_str] = 0
        return False, value

If we encounter a known state, we calculate U(s, a) for every valid edge (action -> state transition), which is an upper confidence bound on the Q value of that edge, and then pick the action with the highest bound.
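
Written out explicitly (this simply restates what the snippet below computes), the bound is the PUCT-style rule used by Alpha Zero:

    U(s, a) = Q(s, a) + cpuct * P(s, a) * sqrt(N(s)) / (1 + N(s, a))

For an edge we have not visited yet, Q(s, a) is treated as 0 and a small EPS is added under the square root, so that the prior P(s, a) still drives exploration even when N(s) is 0: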

    cur_best = -float('inf')
    best_act = -1

    # pick the action with the highest upper confidence bound
    for action in range(game.get_action_size()):
        if valid_actions[action]:
            if (observation_str, action) in self.Qsa:
                u = self.Qsa[(observation_str, action)] \
                    + self.cpuct * self.Ps[observation_str][action] * math.sqrt(self.Ns[observation_str]) / \
                    (1 + self.Nsa[(observation_str, action)])
            else:
                u = self.cpuct * self.Ps[observation_str][action] * math.sqrt(self.Ns[observation_str] + EPS)

            if u > cur_best:
                cur_best = u
                best_act = action

Once we have finished our MCTS simulations, the N(s, a) values provide a good approximation of the policy from each state. The only thing left is to apply our exploration rate parameter to control the distribution of the policy values:

    counts = [self.Nsa[(observation_str, a)] if (observation_str, a) in self.Nsa else 0 for a in
              range(game.get_action_size())]

    if self.exp_rate == AgentMCTS.NO_EXPLORATION:
        bestA = np.argmax(counts)
        policy = [0] * len(counts)
        policy[bestA] = 1
    else:
        counts = [x ** (1. / self.exp_rate) for x in counts]
        policy = [x / float(sum(counts)) for x in counts]

    return policy, value

 

Congratulations! You now understand how Monte Carlo Tree Search works! However, there is one more thing left before we can make it work.

Neural Network

Do you remember these lines from our AgentMCTS?

    def __init__(self, agent, exp_rate=EXPLORATION_TEMP, cpuct=1, num_simulations=10):
        self.agent = agent
    ...
    self.Ps[observation_str], value = self.agent.predict(game, game_player)

That is exactly how we are going to use our neural network here. It produces a policy and a value, so the Agent interface can be reused for the neural network too. Let’s write a neural-network-based Agent! We’ll use the Keras library to keep our code as clean as possible. First, we define our abstract neural network:

import numpy as np
import tensorflow as tf


class NNet():
    def __init__(self, observation_size_x, observation_size_y, observation_size_z, action_size):
        self.observation_size_x = observation_size_x
        self.observation_size_y = observation_size_y
        self.observation_size_z = observation_size_z
        self.action_size = action_size

        self.model = self.build_model()

        self.graph = tf.get_default_graph()

    def build_model(self):
        '''
        Returns:
            model: a Keras model. The model takes an observation with shape
                   (self.observation_size_x, self.observation_size_y, self.observation_size_z)
                   and outputs a policy vector of size self.action_size and a value
        '''
        pass

    def predict(self, observation):
        with self.graph.as_default():
            self.model._make_predict_function()
            pi, v = self.model.predict(observation)

            if np.isscalar(v[0]):
                return pi[0], v[0]
            else:
                return pi[0], v[0][0]

And now we define our Agent which will use the neural network:

class AgentNNet(Agent):
    def __init__(self, nnet):
        self.nnet = nnet

    def predict(self, game, game_player):
        observation = game.get_observation(game_player)
        observation = observation[np.newaxis, :, :]  # add a batch dimension for the network

        return self.nnet.predict(observation)

 

In case you are wondering how to make a neural network output both a policy and a value, here is an example of a simple convolutional neural net for Alpha Zero:

from keras.layers import Activation, BatchNormalization, Conv2D, Dense, Dropout, Flatten, Input, Reshape
from keras.models import Model
from keras.optimizers import Adam


class ConvNNet(NNet):

    def build_model(self):
        num_channels = 512
        learning_rate = 0.0001

        input_boards = Input(shape=(self.observation_size_x, self.observation_size_y, self.observation_size_z))  # s: batch_size x board_x x board_y x board_z

        x_image = Reshape((self.observation_size_x, self.observation_size_y, self.observation_size_z))(input_boards)  # batch_size x board_x x board_y x board_z

        h_conv1 = Activation('relu')(BatchNormalization(axis=3)(
            Conv2D(num_channels, 3, padding='same')(x_image)))  # batch_size x board_x x board_y x num_channels
        h_conv2 = Activation('relu')(BatchNormalization(axis=3)(
            Conv2D(num_channels, 3, padding='same')(h_conv1)))  # batch_size x board_x x board_y x num_channels
        h_conv3 = Activation('relu')(BatchNormalization(axis=3)(Conv2D(num_channels, 3, padding='valid')(
            h_conv2)))  # batch_size x (board_x-2) x (board_y-2) x num_channels
        h_conv4 = Activation('relu')(BatchNormalization(axis=3)(Conv2D(num_channels, 3, padding='valid')(
            h_conv3)))  # batch_size x (board_x-4) x (board_y-4) x num_channels

        h_conv4_flat = Flatten()(h_conv4)

        s_fc1 = Dropout(0.3)(
            Activation('relu')(BatchNormalization(axis=1)(Dense(1024)(h_conv4_flat))))  # batch_size x 1024
        s_fc2 = Dropout(0.3)(
            Activation('relu')(BatchNormalization(axis=1)(Dense(512)(s_fc1))))  # batch_size x 512

        pi = Dense(self.action_size, activation='softmax', name='pi')(s_fc2)  # batch_size x self.action_size
        v = Dense(1, activation='tanh', name='v')(s_fc2)  # batch_size x 1

        # two heads, two losses: cross-entropy for the policy and MSE for the value
        model = Model(inputs=input_boards, outputs=[pi, v])
        model.compile(loss=['categorical_crossentropy', 'mean_squared_error'], optimizer=Adam(learning_rate))

        return model

 

Bringing it all together

Today we've defined our core classes: Game, Agent, and NNet. We implemented a neural-network-based Agent and a Monte Carlo Tree Search-based Agent! So, to initialize our Monte Carlo agent, we just need to write:

    game = GameImplementation()
    observation_size = game.get_observation_size()

    nnet = ConvNNet(observation_size[0], observation_size[1], observation_size[2], game.get_action_size())
    agent_nnet = AgentNNet(nnet)

    agent_mcts = AgentMCTS(agent_nnet)

 

And our agent_mcts can predict the policy and value of a game state like this:

    policy, value = agent_mcts.predict(game, game.get_cur_player())
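
From here, a typical next step (shown only as a hypothetical usage sketch, assuming numpy is imported as np) would be to sample the actual move from the returned policy and play it:

    action = np.random.choice(len(policy), p=policy)  # invalid actions get zero probability from the visit counts
    game.take_action(action)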

 

What's next?

Awesome! We now have our core classes almost ready for building an Alpha Zero training pipeline. Next time we are going to show you how to train an agent to play English draughts using Monte Carlo Tree Search and a deep residual network. The game has roughly 500,995,484,682,338,672,639 possible positions, so the task is going to be a fun one.

Stay tuned!
