Example #1
        /*
         * Phase 1: Selection
         * Starting from the given state, follow the child with the highest UCB1
         * value until reaching a node that is not fully expanded or is a leaf
         */
        public MonteCarloNode Select(GameState state)
        {
            if (!this.nodes.ContainsKey(state.GetId()))
            {
                UnityEngine.Debug.LogError("Key not found in the map: " + String.Join(",", this.nodes.Keys) + ", key = " + state.GetId());
            }
            MonteCarloNode node = this.nodes[state.GetId()];

            while (node.IsFullyExpanded() && !node.IsLeaf())
            {
                List <int> actions    = node.AllActions();
                int        bestAction = -1;
                double     bestUCB1   = Double.NegativeInfinity;

                foreach (int action in actions)
                {
                    double childUCB1 = node.ChildNode(action).GetUCB1(this.UCB1ExploreParam);
                    if (childUCB1 > bestUCB1)
                    {
                        bestAction = action;
                        bestUCB1   = childUCB1;
                    }
                }
                node = node.ChildNode(bestAction);
            }
            return(node);
        }
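
Select() above ranks children with GetUCB1(), which is not included in these examples. The following is a minimal sketch only, assuming the node fields from Example #9, a using System directive, and the standard UCB1 formula wins/plays + c * sqrt(ln(parentPlays) / plays):

        // Hypothetical sketch of GetUCB1(); the real implementation is not shown in
        // these examples. It assumes the standard UCB1 formula and that the node has
        // a parent and has been played at least once (no division by zero).
        public double GetUCB1(double exploreParam)
        {
            double exploitation = (double)this.numberOfWins / this.numberOfPlays;
            double exploration  = Math.Sqrt(Math.Log(this.parent.numberOfPlays) / this.numberOfPlays);

            return(exploitation + exploreParam * exploration);
        }
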
Example #2
        /*
         * Phase 3: Simulation
         * From the given node, play random legal actions until a terminal state
         * or the depth limit is reached, then return the winner
         */
        public PlayerKind Simulate(MonteCarloNode node, int maxDepth)
        {
            GameState  state  = node.state;
            // Clone both players so the playout does not mutate the stored node state
            IPlayer    player = state.GetPlayer().ClonePlayer();
            IPlayer    enemy  = state.GetEnemy().ClonePlayer();
            PlayerKind turn   = state.GetTurn();
            PlayerKind winner;
            int        depth = 0;

            // Continue until someone wins or the depth limit is reached
            while ((winner = this.game.Winner(player, enemy)) == PlayerKind.NONE && depth < maxDepth)
            {
                int action = this.game.RandomLegalAction(player, enemy, turn);
                turn = this.game.MakeAction(player, enemy, action, turn);
                this.game.UpdatePlayers(player, enemy);
                depth++;
            }

            // Decide the winner heuristically if no one has won within the depth limit
            if (winner == PlayerKind.NONE)
            {
                winner = CalculateWinner(state.GetPlayer(), player, state.GetEnemy(), enemy);
            }

            return(winner);
        }
Example #3
 public void MakeNode(GameState state)
 {
     if (!this.nodes.ContainsKey(state.GetId()))
     {
         bool[]         unexpandedActions = this.game.LegalActions(state);
         MonteCarloNode node = new MonteCarloNode(null, -1, state, unexpandedActions);
         this.nodes[state.GetId()] = node;
     }
 }
Example #4
        /*
         * Get the MonteCarloNode corresponding to the given action.
         */
        public MonteCarloNode ChildNode(int action)
        {
            // Check the key explicitly: the dictionary indexer alone would throw an
            // unhelpful KeyNotFoundException for an action that is not legal here
            if (!this.children.ContainsKey(action))
            {
                throw new Exception("No such action!");
            }
            MonteCarloNode child = this.children[action];

            if (child == null)
            {
                throw new Exception("Child not expanded!");
            }

            return(child);
        }
Example #5
        /*
         * Expand the specified child action and return the new child node.
         * Store the new node in the children dictionary, replacing the null
         * placeholder so the action is no longer counted as unexpanded.
         */
        public MonteCarloNode Expand(int action, GameState childState, bool[] allActions)
        {
            if (!this.children.ContainsKey(action))
            {
                throw new Exception("No such action!");
            }
            MonteCarloNode childNode = new MonteCarloNode(this, action, childState, allActions);

            this.children[action] = childNode;

            return(childNode);
        }
Example #6
 /*
  * Phase 4: Backpropagation
  * From given node, propagate plays and winner to ancestors' statistics
  */
 public void Backpropagate(MonteCarloNode node, PlayerKind winner)
 {
     while (node != null)
     {
         node.numberOfPlays++;
         // Wins are counted from PLAYER's point of view on every ancestor node
         if (winner == PlayerKind.PLAYER)
         {
             node.numberOfWins++;
         }
         // Update parent
         node = node.parent;
     }
 }
Example #7
        /*
         * Phase 2: Expansion
         * Expand a random unexpanded child of the given node and return it
         */
        public MonteCarloNode Expand(MonteCarloNode node)
        {
            // Select random action
            List <int> actions    = node.UnexpandedActions();
            int        action     = this.game.RandomAction(actions);
            GameState  childState = this.game.NextState(node.state, action);

            bool[]         childActions = this.game.LegalActions(childState);
            MonteCarloNode childNode    = node.Expand(action, childState, childActions);

            this.nodes[childState.GetId()] = childNode;
            return(childNode);
        }
Example #8
        public List <int> GetActionsWithoutMove(MonteCarloNode node)
        {
            List <int> allActions = new List <int>();

            foreach (MonteCarloNode child in node.children.Values)
            {
                // Skip children that are not expanded yet (null) and the move action
                if (child != null && child.action != playerHelper.MOVE_INDEX)
                {
                    allActions.Add(child.action);
                }
            }

            return(allActions);
        }
Example #9
        public MonteCarloNode(MonteCarloNode parent, int action, GameState state, bool[] allActions)
        {
            this.action = action;
            this.state  = state;

            // Monte Carlo stuff
            this.numberOfPlays = 0;
            this.numberOfWins  = 0;

            // Tree stuff
            this.parent   = parent;
            this.children = new Dictionary <int, MonteCarloNode>();

            for (int i = 0; i < allActions.Length; i++)
            {
                if (allActions[i])
                {
                    this.children[i] = null;
                }
            }
        }
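
Select() and Expand() also rely on a few node helpers (IsFullyExpanded, IsLeaf, AllActions, UnexpandedActions) that do not appear in these examples. Minimal sketches are given below, assuming the convention from the constructor above: a null entry in children marks a legal but not yet expanded action.

        // Hypothetical sketches; the real implementations are not part of these
        // examples. children[action] == null is taken to mean "legal but unexpanded".
        public bool IsFullyExpanded()
        {
            // Fully expanded once every legal action has a child node
            foreach (MonteCarloNode child in this.children.Values)
            {
                if (child == null)
                {
                    return(false);
                }
            }
            return(true);
        }

        public bool IsLeaf()
        {
            // No legal actions at all means a terminal node
            return(this.children.Count == 0);
        }

        public List <int> AllActions()
        {
            return(new List <int>(this.children.Keys));
        }

        public List <int> UnexpandedActions()
        {
            List <int> actions = new List <int>();

            foreach (KeyValuePair <int, MonteCarloNode> pair in this.children)
            {
                if (pair.Value == null)
                {
                    actions.Add(pair.Key);
                }
            }
            return(actions);
        }
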
Example #10
        public void RunSearch()
        {
            GameState state = game.GetInitialState();

            this.MakeNode(state);
            int totalSims = 0;

            // Run until the simulation budget is used up
            while (totalSims <= maxSimulation)
            {
                MonteCarloNode node   = this.Select(state);
                PlayerKind     winner = this.game.Winner(node.state);

                if (!node.IsLeaf() && winner == PlayerKind.NONE)
                {
                    node   = this.Expand(node);
                    winner = this.Simulate(node, maxDepth);
                }
                this.Backpropagate(node, winner);

                totalSims++;
            }
        }
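
RunSearch() only accumulates statistics; these examples do not show how the final action is read out of the tree. Below is a minimal sketch of one possible follow-up step (the method name GetBestAction is assumed, not taken from the original code), using the common "robust child" rule of picking the root child with the most plays.

        // Hypothetical sketch of choosing an action after RunSearch(); not part of
        // the original examples. Picks the most-played child of the root node.
        public int GetBestAction()
        {
            GameState      state = game.GetInitialState();
            MonteCarloNode root  = this.nodes[state.GetId()];

            int bestAction = -1;
            int mostPlays  = -1;

            foreach (KeyValuePair <int, MonteCarloNode> pair in root.children)
            {
                if (pair.Value != null && pair.Value.numberOfPlays > mostPlays)
                {
                    bestAction = pair.Key;
                    mostPlays  = pair.Value.numberOfPlays;
                }
            }
            return(bestAction);
        }
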
Example #11
 /*
  * Get all legal actions except move action from root node.
  */
 public List <int> GetActionsWithoutMove(MonteCarloNode node)
 {
     return(this.game.GetActionsWithoutMove(node));
 }
Example #12
        public float CalculateMCTSActionReward(float[] vectorAction)
        {
            float          mctsReward  = 0;
            MonteCarloNode rootNode    = mcts.GetRootNode();
            List <int>     mctsActions = mcts.GetActionsWithoutMove(rootNode);
            ISet <int>     annActions  = mcts.ConvertMCTSActions(vectorAction);

            double MCTSBestUCB = Double.NegativeInfinity;
            double ANNBestUCB  = Double.NegativeInfinity;
            double UCBMin      = Double.PositiveInfinity;
            double UCB         = 0;

            // Find best UCB for MCTS and ANN actions
            foreach (int action in mctsActions)
            {
                MonteCarloNode childNode = rootNode.ChildNode(action);
                if (childNode != null)
                {
                    UCB = childNode.GetUCB1(UCB1ExploreParam);
                    // Set MCTS action max UCB
                    if (UCB > MCTSBestUCB)
                    {
                        MCTSBestUCB = UCB;
                    }
                    // Set ANN action max UCB
                    if (annActions.Contains(action) && UCB > ANNBestUCB)
                    {
                        ANNBestUCB = UCB;
                    }
                    // Set min UCB
                    if (UCB < UCBMin)
                    {
                        UCBMin = UCB;
                    }
                }
            }

            // No reward is given if no suitable action was found
            // (move actions were already filtered out above)
            if (ANNBestUCB != Double.NegativeInfinity)
            {
                if (ANNBestUCB == MCTSBestUCB)
                {
                    mctsReward = 1;
                }
                else
                {
                    mctsReward = -1;
                }

                /*
                 * // Prevent divide by zero by assigning very small values
                 * UCBMin = UCBMin == MCTSBestUCB ? 0 : UCBMin;
                 * MCTSBestUCB = MCTSBestUCB == 0 ? 000000000.1d : MCTSBestUCB;
                 * ANNBestUCB = ANNBestUCB == 0 ? 000000000.1d : ANNBestUCB;
                 * // Normalize the ANN UCB [0,1] -> (currentValue - minValue) / (maxValue - minValue)
                 * double normalizedANNRate = (ANNBestUCB - UCBMin) / (MCTSBestUCB - UCBMin);
                 * double differenceFromMax = 1 - normalizedANNRate;
                 * double diffSquare = Math.Pow(differenceFromMax, 2);
                 * mctsReward = (float)(1.3d * Math.Exp(-5.84d * diffSquare) - 0.01d);
                 */
            }
            else if (mctsActions.Count > 0)
            {
                // Give a negative reward when MCTS recommends non-move actions but none of the ANN's actions match them
                mctsReward = -1f;
            }
            if (name.Equals(BattleArena.RED_AGENT_TAG) && mctsReward != 0)
            {
                Debug.Log(name + " " + mctsReward + " reward given: vectorActions=" + vectorAction[0] + "," + vectorAction[1] + "," + vectorAction[2] +
                          " convertedActions=" + String.Join(",", annActions) + " mctsActions=" + String.Join(",", mctsActions));
            }
            return(mctsReward);
        }
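
CalculateMCTSActionReward() also calls mcts.GetRootNode() and mcts.ConvertMCTSActions(), neither of which appears in these examples. ConvertMCTSActions is game-specific and is not sketched here; for GetRootNode, a minimal sketch, assuming the root is simply the node stored for the initial game state (as created in RunSearch):

        // Hypothetical sketch of GetRootNode(); the real implementation is not shown.
        // Assumes the root is the node created for the initial state in RunSearch().
        public MonteCarloNode GetRootNode()
        {
            GameState state = game.GetInitialState();

            return(this.nodes[state.GetId()]);
        }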