Exemplo n.º 1
0
        /// <summary>
        /// The rollout (simulation) step of the MCTS. Starts at the best-scored child of
        /// <paramref name="expandedNode"/> and follows a greedy default policy (always the child with
        /// the highest <c>Score</c>) for at most <c>RolloutDepth</c> steps, or until an end-turn node
        /// or a finished game is reached.
        /// </summary>
        /// <remarks>
        /// The nodes created during the rollout are dropped from the tree afterwards (the starting
        /// child's <c>Children</c> is reset to <c>null</c>); the returned node still reaches the root
        /// through its <c>Parent</c> chain, so it can be used for backpropagation.
        /// </remarks>
        /// <param name="expandedNode">the expanded node the simulation is performed on</param>
        /// <returns>the node the rollout stopped at, or <paramref name="expandedNode"/> itself when it
        /// has no children</returns>
        protected MctsNode Playout(MctsNode expandedNode)
        {
            if (expandedNode.Children == null || expandedNode.Children.Count < 1)
            {
                return(expandedNode);
            }

            // the child the rollout starts from; its rollout subtree is discarded at the end
            MctsNode child           = expandedNode.Children.OrderByDescending(c => c.Score).First();
            MctsNode simulationChild = child;

            // simulate game
            int i = 0;

            while (i++ < _mctsParameters.RolloutDepth &&
                   (!simulationChild.IsEndTurn && simulationChild.IsRunning))
            {
                expand(simulationChild);

                // greedy; a node without any task has nothing to roll out into, so stop here
                // instead of letting First() throw on an empty sequence
                MctsNode next = simulationChild.Children?
                                .OrderByDescending(c => c.Score)
                                .FirstOrDefault();
                if (next == null)
                {
                    break;
                }
                simulationChild = next;
            }

            if (child.Children?.Any() ?? false)
            {
                child.Children = null;
            }

            return(simulationChild);
        }
Exemplo n.º 2
0
        /// <summary>
        /// The selection step of the MCTS: starting at <paramref name="root"/>, repeatedly moves to the
        /// child with the highest UCT value until a leaf is reached. Ties go to the first such child.
        /// </summary>
        /// <param name="root">the root node the MCTS is performed on</param>
        /// <returns>the leaf node reached by following the best UCT scores</returns>
        protected MctsNode Select(MctsNode root)
        {
            MctsNode current = root;

            while (true)
            {
                if (current.IsLeaf)
                {
                    return(current);
                }

                IOrderedEnumerable <MctsNode> byUct = current.Children.OrderByDescending(c => CalculateUct(c));
                current = byUct.First();
            }
        }
Exemplo n.º 3
0
        /// <summary>
        /// Returns the depth of this node, i.e. the number of ancestors above it (0 for a root).
        /// </summary>
        /// <returns>the number of nodes on the <c>Parent</c> chain</returns>
        public int ParentCount()
        {
            int depth = 0;

            for (MctsNode ancestor = Parent; ancestor != null; ancestor = ancestor.Parent)
            {
                depth++;
            }

            return(depth);
        }
Exemplo n.º 4
0
        /// <summary>
        /// Calculates the UCT score of a node: its backed-up score plus an exploration bonus.
        /// </summary>
        /// <remarks>
        /// The exploration bonus is <c>UCTConstant * sqrt(2 * ln(parentVisits) / visits)</c>. Because the
        /// factor 2 is already inside the square root, <c>UCTConstant = 1</c> reproduces the classic UCB1
        /// term; the frequently quoted constant sqrt(2) belongs to the form <c>c * sqrt(ln N / n)</c>
        /// without that factor, and would double the exploration here.
        /// </remarks>
        /// <param name="node">the node the uct is calculated for; it must have a parent</param>
        /// <returns>the uct score</returns>
        private double CalculateUct(MctsNode node)
        {
            // exploitation: the best score backpropagated into this node so far
            double heuristic = node.TotalScore;

            // exploration weight (see remarks for how it relates to UCB1)
            double constant = _mctsParameters.UCTConstant;

            // exploration bonus: grows with the parent's visits, shrinks with this node's visits
            double uct = constant * Math.Sqrt(2 *
                                              (Math.Log(node.Parent.VisitCount) / node.VisitCount));

            return(heuristic + uct);
        }
Exemplo n.º 5
0
        /// <summary>
        /// Searches the best move for this agent: builds a root node on a copy of
        /// <paramref name="game"/>, scored only by this agent's own scoring function, and runs the
        /// MCTS until the <c>_deltaTime</c> deadline.
        /// </summary>
        /// <param name="game">the current game; it is copied, not modified</param>
        /// <returns>the node chosen by the search (the root itself when it has no children)</returns>
        public override MctsNode Simulate(POGame game)
        {
            POGame rootGame = game.getCopy();

            var ownScoring = new List <MctsNode.ScoreExt> { new  MctsNode.ScoreExt(1.0, _scoring) };
            var rootNode   = new MctsNode(_playerId, ownScoring, rootGame, null, null);

            // the leaf nodes collected by the search are not used by this method
            var unusedLeafs = new List <MctsNode>();
            return(Simulate(_deltaTime, rootNode, ref unusedLeafs));
        }
Exemplo n.º 6
0
        /// <summary>
        /// The backpropagation step of MCTS: walks from <paramref name="leafNode"/> up to the root,
        /// incrementing every node's visit count and raising its <c>TotalScore</c> to
        /// <paramref name="score"/> where that is higher (max-backup rather than a running sum).
        /// </summary>
        /// <param name="leafNode">the node the rollout ended at</param>
        /// <param name="score">the score to propagate towards the root</param>
        protected void backpropagate(MctsNode leafNode, double score)
        {
            for (MctsNode node = leafNode; node != null; node = node.Parent)
            {
                node.VisitCount++;

                // every ancestor keeps only the best score seen below it
                if (score > node.TotalScore)
                {
                    node.TotalScore = score;
                }
            }
        }
Exemplo n.º 7
0
        /// <summary>
        /// Extracts the task sequence represented by this node. The sequence is this node's own task,
        /// followed by the tasks along the path of highest-<c>TotalScore</c> children, up to and
        /// including the first end-turn node or leaf.
        /// </summary>
        /// <returns>the tasks in execution order. The first entry is this node's own <c>Task</c>;
        /// NOTE(review): that is presumably null for a node created without a task, so callers should
        /// not call this on a root node.</returns>
        public List <PlayerTask> GetSolution()
        {
            var      solutions = new List <PlayerTask>();
            MctsNode current   = this;

            while (!current.IsEndTurn && !current.IsLeaf)
            {
                solutions.Add(current.Task);
                current = current.Children.OrderByDescending(c => c.TotalScore).First();
            }

            // the loop only ends on an end-turn node or a leaf, whose task closes the sequence
            solutions.Add(current.Task);
            return(solutions);
        }
Exemplo n.º 8
0
        /// <summary>
        /// Performs MCTS iterations on the given root node until the deadline passes and returns the
        /// root's best child. Each iteration does the following:
        /// selection (UCT), expansion, greedy playout and backpropagation.
        /// </summary>
        /// <param name="simulationTime">absolute deadline in milliseconds, compared against
        /// <c>Watch.Elapsed</c>; it is not a duration relative to this call</param>
        /// <param name="root">the root node the MCTS is performed on</param>
        /// <param name="leafNodes">collects every selected end-turn node, without duplicates. Entries
        /// already in the list are kept; a null list is replaced by a new one.</param>
        /// <returns>the root child with the highest <c>TotalScore</c>, or <paramref name="root"/>
        /// itself when it has no children</returns>
        protected MctsNode Simulate(double simulationTime, MctsNode root, ref List <MctsNode> leafNodes)
        {
            if (leafNodes == null)
            {
                leafNodes = new List <MctsNode>();
            }

            // set view of leafNodes for O(1) duplicate checks (List.Contains is O(n) per iteration)
            var knownLeafNodes = new HashSet <MctsNode>(leafNodes);

            while (Watch.Elapsed.TotalMilliseconds <= simulationTime)
            {
                // SELECTION using UCT as tree policy
                MctsNode selectedNode = Select(root);

                MctsNode leafNode = selectedNode;
                if (!leafNode.IsEndTurn)
                {
                    // EXPANSION
                    expand(selectedNode);
                    // SIMULATION using greedy as default policy
                    leafNode = Playout(selectedNode);
                }

                // BACKPROPAGATION of the reached node's own score
                backpropagate(leafNode, leafNode.Score);

                // end-turn nodes are never expanded above; remember each of them once
                if (selectedNode.IsEndTurn && knownLeafNodes.Add(selectedNode))
                {
                    leafNodes.Add(selectedNode);
                }
            }

            // the root may still be unexpanded, e.g. when the deadline had already passed
            if (root.Children == null)
            {
                expand(root);
            }

            if (root.Children.Count < 1)
            {
                return(root);
            }

            return(root.Children
                   .OrderByDescending(c => c.TotalScore)
                   .First());
        }
Exemplo n.º 9
0
        /// <summary>
        /// Creates a search node for <paramref name="playerId"/>. The node stores a copy of
        /// <paramref name="game"/>. When <paramref name="task"/> is not null, the task is applied via
        /// <c>Game.Simulate</c> and the node is scored on the resulting game state.
        /// </summary>
        /// <param name="playerId">id of the player this node is evaluated for</param>
        /// <param name="scorings">weighted scoring functions; stored as-is (not copied)</param>
        /// <param name="game">the game state before <paramref name="task"/>; a copy is stored</param>
        /// <param name="task">the task leading to this node, or null for a root node</param>
        /// <param name="parent">the parent node, or null for a root node</param>
        public MctsNode(int playerId, List <ScoreExt> scorings, POGame game, PlayerTask task, MctsNode parent)
        {
            _parent   = parent;
            _scorings = scorings;
            _playerId = playerId;
            _game     = game.getCopy();
            _task     = task;

            // NOTE(review): presumably starts at 1 so the UCT bonus never divides by a zero visit count
            VisitCount = 1;

            // without a task (root node) the copied game is kept unchanged and the node is not scored
            if (Task != null)
            {
                // apply the task; the result maps the task to its successor game
                Dictionary <PlayerTask, POGame> dir = Game.Simulate(new List <PlayerTask> {
                    Task
                });
                POGame newGame = dir[Task];

                Game = newGame;
                // a null successor means the task could not be simulated: the node is marked as
                // end-of-turn and left unscored. TODO: consider penalising the score instead
                if (Game == null)
                {
                    _endTurn = 1;
                }
                else
                {
                    // 0 while the game is running; otherwise 1 if PlayerController has WON, else -1
                    _gameState = Game.State == SabberStoneCore.Enums.State.RUNNING ? 0
                                                : (PlayerController.PlayState == PlayState.WON ? 1 : -1);
                    // 1 once the turn has passed to a different player than _playerId
                    _endTurn = Game.CurrentPlayer.Id != _playerId ? 1 : 0;

                    // score = average of Value * Rate() over all scorings; each scoring is first
                    // bound to PlayerController (this mutates the ScoreExt instances)
                    foreach (ScoreExt scoring in Scorings)
                    {
                        scoring.Controller = PlayerController;
                        _score            += scoring.Value * scoring.Rate();
                    }
                    // NOTE(review): assumes Scorings is non-empty, otherwise this divides by zero
                    _score     /= Scorings.Count;
                    TotalScore += _score;
                }
            }
        }
Exemplo n.º 10
0
 /// <summary>
 /// Creates a child node that inherits <paramref name="parent"/>'s player id and scorings.
 /// </summary>
 /// <param name="game">the game state the task is applied to (copied by the chained constructor)</param>
 /// <param name="task">the task leading from the parent to this node</param>
 /// <param name="parent">the parent node; must not be null, since it is dereferenced here</param>
 public MctsNode(POGame game, PlayerTask task, MctsNode parent)
     : this(parent.PlayerId, parent.Scorings, game, task, parent)
 {
 }
        /// <summary>
        /// Runs the player's MCTS, then refines the scores of the collected end-turn nodes. For each
        /// node it simulates the opponent's reply, based on predictions of the opponent's deck and
        /// hand, followed by the player's next turn. The forwarded games are processed breadth-first,
        /// up to <c>SimulationDepth</c> simulations.
        /// </summary>
        /// <param name="game">the current game; it is copied, not modified</param>
        /// <returns>the root child with the highest backed-up <c>TotalScore</c>, or the initial
        /// search's result when the root has no children</returns>
        public override MctsNode Simulate(POGame game)
        {
            Console.WriteLine("Current win rate is " + 0);
            POGame gameCopy = game.getCopy();

            // initial root node, scored by this agent's own scoring function
            var children = new List <MctsNode>();
            var parent   = new MctsNode(_playerId, new List <MctsNode.ScoreExt> {
                new MctsNode.ScoreExt(1, _scoring)
            }, gameCopy, null, null);

            // plain MCTS up to the _deltaTime deadline; the selected end-turn nodes end up in children
            MctsNode bestNode = Simulate(_deltaTime, parent, ref children);

            // initialise the opponent's history
            Initialize(gameCopy);

            var simulationQueue = new Queue <KeyValuePair <POGame, List <MctsNode> > >();

            simulationQueue.Enqueue(new KeyValuePair <POGame, List <MctsNode> >(gameCopy, children));

            int i = 0;

            while (i < _predictionParameters.SimulationDepth &&
                   simulationQueue.Count > 0)
            {
                // simulation i owns the window [_deltaTime + 2i * _deltaTime, _deltaTime + 2(i + 1) * _deltaTime]
                double lowerSimulationTimeBound = _deltaTime + i * (2 * _deltaTime);

                KeyValuePair <POGame, List <MctsNode> > simulation = simulationQueue.Dequeue();

                // keep the best-scored leaves whose game could be simulated (Take clamps to the count)
                List <MctsNode> leafs = simulation.Value
                                        .Where(l => l.Game != null)
                                        .OrderByDescending(l => l.Score)
                                        .Take(_predictionParameters.LeafCount)
                                        .ToList();

                if (leafs.Count == 0)
                {
                    // nothing to refine; skipping also avoids splitting the time window by zero below
                    i++;
                    continue;
                }

                Controller        opponent      = GetOpponent(simulation.Key);
                List <Prediction> predictionMap = GetPredictionMap(simulation.Key, opponent);
                var newSimulations = new Dictionary <POGame, List <MctsNode> >();

                // the window is split evenly between the leaves
                double timePerLeaf = (2 * _deltaTime) / leafs.Count;

                for (int j = 0; j < leafs.Count; j++)
                {
                    double   lowerLeafTimeBound = lowerSimulationTimeBound + j * timePerLeaf;
                    MctsNode leafNode           = leafs[j];

                    // leafNode.Game is non-null: leaves without a game were filtered out above
                    double leafScore = SimulateOpponentWithPrediction(lowerLeafTimeBound, timePerLeaf,
                                                                      leafNode.Game, opponent, predictionMap, ref newSimulations);
                    // back-propagate score
                    backpropagate(leafNode, leafScore);
                }

                // the forwarded games are simulated next
                foreach (KeyValuePair <POGame, List <MctsNode> > sim in newSimulations)
                {
                    simulationQueue.Enqueue(sim);
                }
                i++;
            }

            // a root without children cannot be ordered; fall back to the initial search's result
            MctsNode resultnode = parent.Children?.OrderByDescending(c => c.TotalScore)
                                                  .FirstOrDefault() ?? bestNode;

            Console.WriteLine("Current win rate is " + resultnode.TotalScore);
            return(resultnode);
        }
        /// <summary>
        /// Scores a leaf by forwarding its game under every prediction of the opponent's deck and hand.
        /// For each prediction, the opponent's turn is played with MCTS and then the player's
        /// following turn.
        /// </summary>
        /// <remarks>
        /// All deadlines are absolute milliseconds on <c>Watch</c>. The leaf's window
        /// [lowerTimeBound, lowerTimeBound + timePerLeaf] is split evenly between the predictions. Inside
        /// each prediction's slice, the first half goes to the opponent's search and the second half to
        /// the player's search.
        /// </remarks>
        /// <param name="lowerTimeBound">start of this leaf's time window</param>
        /// <param name="timePerLeaf">length of this leaf's time window</param>
        /// <param name="oppGame">the leaf's game state; it is copied per prediction, not modified</param>
        /// <param name="opponent">the opponent whose deck and hand zones are rebuilt from each prediction</param>
        /// <param name="predicitionMap">the predictions of the opponent's deck and hand</param>
        /// <param name="newSimulations">receives every forwarded (non-null) game together with the
        /// end-turn nodes the player's search collected in it</param>
        /// <returns>the sum over all predictions of the player's best score after the forwarded turn;
        /// 0 when there are no predictions</returns>
        private double SimulateOpponentWithPrediction(double lowerTimeBound, double timePerLeaf, POGame oppGame, Controller opponent,
                                                      IReadOnlyList <Prediction> predicitionMap, ref Dictionary <POGame, List <MctsNode> > newSimulations)
        {
            double predictionScore = 0;

            if (!(predicitionMap?.Any() ?? false))
            {
                return(predictionScore);
            }

            // the opponent's scoring functions, weighted by how often they occur among the predicted decks
            int denominator = predicitionMap.Count;
            var scorings    = predicitionMap.GroupBy(p => p.Deck.Scoring)
                              .Select(c => new MctsNode.ScoreExt(((double)c.Count() / denominator), c.Key))
                              .OrderByDescending(s => s.Value).ToList();

            // the simulation time for one prediction
            double timePerPrediction = timePerLeaf / predicitionMap.Count;

            for (int i = 0; i < predicitionMap.Count; i++)
            {
                Prediction prediction   = predicitionMap[i];
                var        setasideZone = new SetasideZone(opponent);

                // rebuild the opponent's deck from the prediction
                var deckZone = new DeckZone(opponent);
                createZone(opponent, prediction.Deck.Cards, deckZone, ref setasideZone);
                deckZone.Shuffle();

                // rebuild the opponent's hand from the prediction
                var handZone = new HandZone(opponent);
                createZone(opponent, prediction.Hand.Cards, handZone, ref setasideZone);

                // prediction i owns the slice [sliceStart, sliceStart + timePerPrediction]:
                // the opponent's search ends halfway, the player's search at the end of the slice
                double sliceStart        = lowerTimeBound + i * timePerPrediction;
                double oppSimulationTime = sliceStart + timePerPrediction / 2;
                double simulationTime    = oppSimulationTime + timePerPrediction / 2;

                var    oppLeafNodes = new List <MctsNode>();
                POGame forwardGame  = oppGame.getCopy();

                // play the opponent's moves
                while (forwardGame != null &&
                       forwardGame.State == SabberStoneCore.Enums.State.RUNNING &&
                       forwardGame.CurrentPlayer.Id == opponent.Id)
                {
                    var      oppRoot     = new MctsNode(opponent.Id, scorings, forwardGame, null, null);
                    MctsNode bestOppNode = Simulate(oppSimulationTime, oppRoot, ref oppLeafNodes);
                    forwardGame = ForwardGameBySolution(forwardGame, bestOppNode.GetSolution());
                }

                // play the player's following moves on the forwarded game
                double score = 0;
                var    leafs = new List <MctsNode>();

                while (forwardGame != null &&
                       forwardGame.State == SabberStoneCore.Enums.State.RUNNING &&
                       forwardGame.CurrentPlayer.Id == _playerId)
                {
                    var root = new MctsNode(_playerId, new List <MctsNode.ScoreExt> {
                        new MctsNode.ScoreExt(1.0, _scoring)
                    }, forwardGame, null, null);
                    MctsNode bestNode = Simulate(simulationTime, root, ref leafs);
                    forwardGame = ForwardGameBySolution(forwardGame, bestNode.GetSolution());
                    // TODO: maybe penalise forwardGame == null
                    score = bestNode.TotalScore;
                }
                predictionScore += score;

                if (forwardGame != null)
                {
                    newSimulations.Add(forwardGame, leafs);
                }
            }
            return(predictionScore);
        }

        /// <summary>
        /// Applies <paramref name="tasks"/> to <paramref name="game"/> one after another. It stops early
        /// when a task cannot be simulated (null game) or when the current player has a pending choice.
        /// </summary>
        /// <param name="game">the game to forward</param>
        /// <param name="tasks">the tasks to apply, in order</param>
        /// <returns>the forwarded game, or null when a task could not be simulated</returns>
        private static POGame ForwardGameBySolution(POGame game, List <PlayerTask> tasks)
        {
            for (int j = 0; j < tasks.Count && (game != null); j++)
            {
                PlayerTask task = tasks[j];
                Dictionary <PlayerTask, POGame> dir = game.Simulate(new List <PlayerTask> {
                    task
                });
                game = dir[task];

                if (game != null && game.CurrentPlayer.Choice != null)
                {
                    break;
                }
            }
            return(game);
        }
Exemplo n.º 13
0
 /// <summary>
 /// The expansion step of the MCTS: replaces the children of the selected node with one new child
 /// per task available in it, in task order.
 /// </summary>
 /// <param name="selectedNode">the selected leaf node; its <c>Children</c> list is replaced</param>
 protected void expand(MctsNode selectedNode)
 {
     // TODO: add something like an expansion threshold
     var expanded = new List <MctsNode>();
     foreach (var task in selectedNode.Tasks)
     {
         expanded.Add(new MctsNode(selectedNode.Game, task, selectedNode));
     }
     selectedNode.Children = expanded;
 }