Example #1
0
        /// <summary>
        /// Plays <c>NumberOfGames</c> games, each starting from <c>RootState</c>, and collects
        /// the node every game stopped on.
        /// </summary>
        /// <returns>One node per game, in the order the games were played.</returns>
        private IEnumerable <PVNetworkBasedMCTreeSearchNode <TState, TAction> > Selfplay()
        {
            var finishedGames = new List<PVNetworkBasedMCTreeSearchNode<TState, TAction>>();

            // Games are played one after another on the calling thread.
            for (int gameIndex = 0; gameIndex < NumberOfGames; gameIndex++)
            {
                Console.Title = $"Selfplay {gameIndex + 1}/{NumberOfGames}";

                var search     = new PVNetworkBasedMCTreeSearch<TGame, TState, TAction, TPlayer>(Expander);
                var prediction = Network.Predict(RootState);
                var current    = PVNetworkBasedMCTreeSearchNode<TState, TAction>.CreateRoot(RootState, prediction);

                // Keep searching and descending until the state reports itself as final.
                while (current.State.IsFinal == false)
                {
                    for (int playout = 0; playout < NumberOfPlayoutsPerMove; playout++)
                    {
                        MCTreeSearchRound<PVNetworkBasedMCTreeSearchNode<TState, TAction>, TState, TAction> details = search.RoundWithDetails(current);
                        var printableRound = details as MCTreeSearchRound<PVNetworkBasedMCTreeSearchNode<GameState, GameAction>, GameState, GameAction>;

                        // Only rounds of the GameState/GameAction instantiation are dumped.
                        if (printableRound == null)
                        {
                            continue;
                        }

                        // Print the round and wait for a line of input before the next playout.
                        printableRound.WriteTo(Console.Out);
                        Console.ReadLine();
                    }

                    // The move actually played is the child with the most visits.
                    current = current.GetMostVisitedChild();
                }

                finishedGames.Add(current);
            }

            return finishedGames;
        }
Example #2
0
        /// <summary>
        /// Prints <paramref name="node"/> via the single-argument overload, then prints
        /// <paramref name="trainingOutput"/>: entries 0..8 as three rows of three values,
        /// followed by entry 9 on its own line. All values use the <c>f3</c> format.
        /// </summary>
        /// <param name="node">Node to print first.</param>
        /// <param name="trainingOutput">Vector to print; indices 0 through 9 are read.</param>
        public static void WriteLine <TState, TAction>(PVNetworkBasedMCTreeSearchNode <TState, TAction> node, double[] trainingOutput)
            where TState : IPeriodState
        {
            WriteLine(node);

            // Walk the first nine entries in row-major order, breaking the line after every third one.
            for (int cell = 0; cell < 9; cell++)
            {
                Console.Write($"{trainingOutput[cell]:f3}  ");

                if (cell % 3 == 2)
                {
                    Console.WriteLine();
                }
            }

            Console.WriteLine($"{trainingOutput[9]:f3}");
        }
Example #3
0
        /// <summary>
        /// Builds one training sample for every node on the path from each final node up through
        /// its <c>Parent</c> chain, then passes all samples to <c>Model.Train</c>.
        /// </summary>
        /// <param name="finalNodes">Nodes the games ended on; each one is walked back via <c>Parent</c>.</param>
        /// <param name="epoches">Forwarded unchanged to <c>Model.Train</c>.</param>
        public void Train(IEnumerable <PVNetworkBasedMCTreeSearchNode <TState, TAction> > finalNodes, int epoches)
        {
            var sampleInputs  = new List<double[]>();
            var sampleTargets = new List<double[]>();

            foreach (var finalNode in finalNodes)
            {
                // Walk from the final node upwards until there is no parent left.
                for (PVNetworkBasedMCTreeSearchNode<TState, TAction> step = finalNode; step != null; step = step.Parent)
                {
                    double[] target = GetTrainingOutput(step, finalNode);

                    // Show the sample and wait for a line of input before recording it.
                    ConsoleUtility.WriteLine(step, target);
                    Console.ReadLine();

                    sampleInputs.Add(TransformInput(step.State));
                    sampleTargets.Add(target);
                }
            }

            Model.Train(sampleInputs.ToArray(), sampleTargets.ToArray(), epoches);
        }
Example #4
0
        /// <summary>
        /// Prints a node of the <c>GameState</c>/<c>GameAction</c> instantiation: the current player,
        /// the state, a 3x3 grid of child visit counts and a 3x3 grid of child values.
        /// Nodes of any other instantiation produce no output.
        /// </summary>
        /// <param name="node">Node to print.</param>
        public static void WriteLine <TState, TAction>(PVNetworkBasedMCTreeSearchNode <TState, TAction> node)
            where TState : IPeriodState
        {
            var board = node as PVNetworkBasedMCTreeSearchNode<GameState, GameAction>;

            // Nothing to print unless the node is of the GameState/GameAction kind.
            if (board == null)
            {
                return;
            }

            Console.WriteLine(board.State.CurrentPlayer);
            Console.WriteLine(board.State);

            // First grid: visit count per action; 0 where there is no child for the action.
            for (ushort row = 0; row < 3; row++)
            {
                for (ushort column = 0; column < 3; column++)
                {
                    PVNetworkBasedMCTreeSearchNode<GameState, GameAction> child = null;
                    board.Children?.TryGetValue(new GameAction(column, row), out child);
                    uint visitCount = child?.Visits ?? 0;
                    Console.Write($"{visitCount,5}  ");
                }

                Console.WriteLine();
            }

            // Second grid: child value per action; 0 where there is no child for the action.
            for (ushort row = 0; row < 3; row++)
            {
                for (ushort column = 0; column < 3; column++)
                {
                    PVNetworkBasedMCTreeSearchNode<GameState, GameAction> child = null;
                    board.Children?.TryGetValue(new GameAction(column, row), out child);
                    double childValue = child?.Value ?? 0;
                    Console.Write($"{childValue,5:f1}  ");
                }

                Console.WriteLine();
            }
        }
Example #5
0
        /// <summary>
        /// Builds the training target for <paramref name="node"/>.
        /// Slot <c>X + Y * 3</c> of each child action holds the network probability for that action,
        /// averaged with the child's <c>Value / Visits</c> when the child has been visited.
        /// Slot 9 holds the outcome from the point of view of the node's current player:
        /// 1 for a win, 0 for a loss, 0.5 when <paramref name="finalNode"/> has no winner.
        /// </summary>
        /// <param name="node">Node whose target vector is produced.</param>
        /// <param name="finalNode">Node whose state supplies the winner.</param>
        /// <returns>A new array of length <c>TicTacToePVNetworkOutput.OutputSize</c>.</returns>
        protected override double[] GetTrainingOutput(PVNetworkBasedMCTreeSearchNode <GameState, GameAction> node, PVNetworkBasedMCTreeSearchNode <GameState, GameAction> finalNode)
        {
            // A freshly allocated double[] is already all zeros, so no explicit clearing is needed.
            var target = new double[TicTacToePVNetworkOutput.OutputSize];

            if (node.Children != null)
            {
                foreach (var entry in node.Children)
                {
                    var slot = entry.Key.X + entry.Key.Y * 3;
                    double prior = node.NetworkOutput.GetProbability(entry.Key);

                    if (entry.Value.Visits == 0)
                    {
                        // Unvisited child: the network probability is used as-is.
                        target[slot] = prior;
                        continue;
                    }

                    // Visited child: average the network probability with the child's value per visit.
                    double meanValue = entry.Value.Value / entry.Value.Visits;
                    target[slot] = (prior + meanValue) / 2.0;
                }
            }

            Player winner = finalNode.State.GetWinner();

            if (winner == null)
            {
                target[9] = 0.5;
            }
            else if (winner == node.State.CurrentPlayer)
            {
                target[9] = 1;
            }
            else
            {
                target[9] = 0;
            }

            return target;
        }
Example #6
0
 /// <summary>
 /// Produces the training target vector for a single node of a finished game.
 /// </summary>
 /// <param name="node">Node for which the target is produced.</param>
 /// <param name="finalNode">Final node of the game that <paramref name="node"/> belongs to.</param>
 /// <returns>The target vector for <paramref name="node"/>; its layout is defined by the implementing class.</returns>
 protected abstract double[] GetTrainingOutput(PVNetworkBasedMCTreeSearchNode <TState, TAction> node, PVNetworkBasedMCTreeSearchNode <TState, TAction> finalNode);