Esempio n. 1
0
        /// <summary>
        /// Creates a node store sized to hold up to <paramref name="maxNodes"/> nodes.
        /// </summary>
        /// <param name="maxNodes">maximum number of nodes the store is sized for</param>
        /// <param name="priorMoves">position history preceding the root; when null the
        /// standard start position is used</param>
        public MCTSNodeStore(int maxNodes, PositionWithHistory priorMoves = null)
        {
            priorMoves ??= PositionWithHistory.StartPosition;

            MaxNodes = maxNodes;

            // Node storage is allocated for exactly the requested maximum number of nodes
            Nodes = new MCTSNodeStructStorage(maxNodes, null,
                                              MCTSParamsFixed.STORAGE_USE_INCREMENTAL_ALLOC,
                                              MCTSParamsFixed.STORAGE_LARGE_PAGES,
                                              MCTSParamsFixed.STORAGE_USE_EXISTING_SHARED_MEM);

            // Child storage is reserved from an average-children-per-node estimate,
            // computed in long arithmetic so large stores cannot overflow int.
            long childReservation = (long)AVG_CHILDREN_PER_NODE * maxNodes;
            Children = new MCTSNodeStructChildStorage(this, childReservation);

            // Store a copy of the prior moves rather than the caller's instance
            Nodes.PriorMoves = new PositionWithHistory(priorMoves);

            CeresEnvironment.LogInfo("NodeStore", "Init", $"MCTSNodeStore created with max {maxNodes} nodes, max {childReservation} children");

            MCTSNodeStruct.ValidateMCTSNodeStruct();
            RootIndex = new MCTSNodeStructIndex(1);
        }
        /// <summary>
        /// Makes the specified node within the tree the new root, reorganizing the node and child arrays.
        ///
        /// Critically, the reorganization is done in situ to avoid having to transiently allocate
        /// extremely large memory objects.
        ///
        /// The operation consists of 3 stages:
        ///   - traverse the subtree breadth-first starting at the new root,
        ///     building a bitmap of which nodes belong to the retained subtree.
        ///
        ///   - traverse the node array sequentially, processing each node that is a member of this bitmap by
        ///     moving it into its new position in the store (including modifying associated child and parent references).
        ///     A table of all the children that need to be moved is built along the way.
        ///
        ///   - using the child table built above, shift all children down. Because nodes may have children written
        ///     out of order, there is no guarantee that enough space is available. Therefore the table is sorted
        ///     by new position and the children are then shifted down, ensuring that children yet to be shifted
        ///     are not overwritten.
        ///
        /// Additionally the transposition roots dictionary may have to be recreated because
        /// the set of roots is smaller (the retained subtree only) and the node indices will change.
        ///
        /// If the requested node already is the root, the tree is left unchanged and only the
        /// cache indices are cleared.
        /// </summary>
        /// <param name="store">store whose tree is reorganized</param>
        /// <param name="policySoftmax">policy softmax value forwarded to the reorganization
        /// (used when extracting non-retained nodes into <paramref name="cacheNonRetainedNodes"/>)</param>
        /// <param name="newRootChild">node that is to become the new root</param>
        /// <param name="newPriorMoves">position history for the new root (forwarded to the reorganization)</param>
        /// <param name="cacheNonRetainedNodes">optional cache (may be null) that receives the nodes not retained</param>
        /// <param name="transpositionRoots">optional (may be null) transposition roots dictionary re-populated with new node indices</param>
        public static void MakeChildNewRoot(MCTSNodeStore store, float policySoftmax,
                                            ref MCTSNodeStruct newRootChild,
                                            PositionWithHistory newPriorMoves,
                                            PositionEvalCache cacheNonRetainedNodes,
                                            TranspositionRootsDict transpositionRoots)
        {
#if DEBUG
            store.Validate();
#endif

            COUNT++;

            bool alreadyRoot = newRootChild.Index == store.RootNode.Index;
            if (!alreadyRoot)
            {
                DoMakeChildNewRoot(store, policySoftmax, ref newRootChild, newPriorMoves, cacheNonRetainedNodes, transpositionRoots);
            }
            else
            {
                // Tree structure is unchanged; only the cached node references need flushing
                store.ClearAllCacheIndices();
            }

#if DEBUG
            store.Validate();
#endif
        }
Esempio n. 3
0
 /// <summary>
 /// Lays out the search tree rooted at <paramref name="node"/> and renders it.
 /// </summary>
 /// <param name="node">root node of the search tree to plot</param>
 public TreePlot(MCTSNodeStruct node)
 {
     rawRoot = node;
     (root, treeInfo) = DrawTreeNode.Layout(node);

     // All drawing dimensions are scaled up by the supersampling factor;
     // the output bitmap itself is created at the unscaled size below.
     canvasWidth            *= superSample;
     canvasHeight           *= superSample;
     leftMargin             *= superSample;
     rightMargin            *= superSample;
     topMargin              *= superSample;
     bottomMargin           *= superSample;
     titleMargin            *= superSample;
     horisontalSpacing      *= superSample;
     verticalSpacing        *= superSample;
     rightHistogramWidth    *= superSample;
     bottomHistogramHeight  *= superSample;
     pointRadius            *= superSample;
     edgeWidth              *= superSample;
     tickFontSize           *= superSample;
     histogramTitleFontSize *= superSample;

     // Drawable area left after subtracting margins, histograms and spacing
     plotAreaWidth  = canvasWidth - leftMargin - rightMargin - rightHistogramWidth - 2 * horisontalSpacing;
     plotAreaHeight = canvasHeight - topMargin - bottomMargin - bottomHistogramHeight - verticalSpacing - titleMargin;

     image = new Bitmap(canvasWidth / superSample, canvasHeight / superSample);
     Plot();
 }
Esempio n. 4
0
        /// <summary>
        /// Plots the search tree for the given root node and saves the image as a png file.
        /// </summary>
        /// <param name="rawNode">root node of the search tree to plot</param>
        /// <param name="fileName">path of the image file to write</param>
        public static void Save(MCTSNodeStruct rawNode, string fileName)
        {
            TreePlot treePlot = new TreePlot(rawNode);
            try
            {
                treePlot.Save(fileName);
            }
            finally
            {
                // Release the plot's resources even if saving fails
                treePlot.Dispose();
            }
        }
Esempio n. 5
0
        /// <summary>
        /// Creates the drawing node for <paramref name="node"/> and, recursively,
        /// drawing nodes for all of its expanded children.
        /// </summary>
        /// <param name="parent">drawing node of the parent, or null for the root</param>
        /// <param name="node">search-tree node represented by this drawing node</param>
        /// <param name="depth">depth in the tree (root is 0); used as the initial Y coordinate</param>
        /// <param name="siblingIndex">position of this node among its parent's sorted children</param>
        /// <param name="identifier">next free identifier; consumed and advanced for every node created</param>
        internal DrawTreeNode(DrawTreeNode parent, MCTSNodeStruct node, int depth, int siblingIndex, ref int identifier)
        {
            X  = -1.0f;
            Y  = (float)depth;
            id = identifier++;

            // The root has branch index -1; each child of the root starts a branch numbered by
            // its sibling index, and deeper nodes inherit the branch of their ancestor.
            if (parent is null)
            {
                BranchIndex = -1;
            }
            else if (parent.BranchIndex == -1)
            {
                BranchIndex = siblingIndex;
            }
            else
            {
                BranchIndex = parent.BranchIndex;
            }

            Children = new List <DrawTreeNode>();

            // Visit children in decreasing N so that the heaviest subtree is always drawn leftmost
            // (OrderBy is stable, so ties keep their original child order).
            IEnumerable <MCTSNodeStruct> childrenByVisits = Enumerable.Range(0, node.NumChildrenExpanded)
                                                                      .Select(i => node.ChildAtIndexRef(i))
                                                                      .OrderBy(c => - c.N);
            int position = 0;
            foreach (MCTSNodeStruct child in childrenByVisits)
            {
                Children.Add(new DrawTreeNode(this, child, depth + 1, position++, ref identifier));
            }

            Parent            = parent;
            mod               = 0.0f;
            change            = 0.0f;
            shift             = 0.0f;
            ancestor          = this;
            lmostSibling      = null;
            this.siblingIndex = siblingIndex;
            thread            = null;
        }
Esempio n. 6
0
        // --------------------------------------------------------------------------------------------
        /// <summary>
        /// Blends the prior policy (P) of each expanded child of <paramref name="node"/> with the
        /// empirical visit fraction (child N / node N):
        ///   newP = (1 - weightEmpirical) * P + weightEmpirical * (childN / N),
        /// clamped from below by the P of the first unexpanded child (if any).
        /// The blended values are encoded into local probability/index arrays and, when
        /// <paramref name="rewriteNodeInTree"/> is set and weightEmpirical is nonzero,
        /// written back as the children's P values in the tree.
        /// </summary>
        /// <param name="cache">NOTE(review): not referenced in the lines shown here; presumably the destination when saveToCache is true.</param>
        /// <param name="node">node whose children's policy values are blended</param>
        /// <param name="weightEmpirical">weight of the empirical visit distribution (0 keeps the prior only)</param>
        /// <param name="saveToCache">NOTE(review): not referenced in the lines shown here (see the TODO below).</param>
        /// <param name="rewriteNodeInTree">if true, the blended P values are written back into the tree</param>
        static void ProcessNode(PositionEvalCache cache, MCTSNode node, float weightEmpirical,
                                bool saveToCache, bool rewriteNodeInTree)
        {
            Span <MCTSNodeStructChild> children = node.Ref.Children;

            // TODO: optimize this away if saveToCache is false
            ushort[] probabilities = new ushort[node.NumPolicyMoves];
            ushort[] indices       = new ushort[node.NumPolicyMoves];

            // Compute empirical visit distribution (only expanded children have visits)
            // NOTE(review): if node.N can be 0 here this produces NaN/Infinity — confirm callers guarantee N > 0.
            float[] nodeFractions = new float[node.NumPolicyMoves];
            for (int i = 0; i < node.NumChildrenExpanded; i++)
            {
                nodeFractions[i] = (float)node.ChildAtIndex(i).N / (float)node.N;
            }

            // Determine P of first unexpanded node
            // We can't allow any child to have a new P less than this
            // since we need to keep them in order by P and the resorting logic below
            // can only operate over expanded nodes
            float minP = 0;

            if (node.NumChildrenExpanded < node.NumPolicyMoves)
            {
                minP = node.ChildAtIndexInfo(node.NumChildrenExpanded).p;
            }

            // Add each move to the policy vector with blend of prior and empirical values
            for (int i = 0; i < node.NumChildrenExpanded; i++)
            {
                (MCTSNode node, EncodedMove move, FP16 p)info = node.ChildAtIndexInfo(i);
                indices[i] = (ushort)info.move.IndexNeuralNet;

                float newValue = (1.0f - weightEmpirical) * info.p
                                 + weightEmpirical * nodeFractions[i];
                // Clamp so no expanded child drops below the first unexpanded child's P
                if (newValue < minP)
                {
                    newValue = minP;
                }
                probabilities[i] = CompressedPolicyVector.EncodedProbability(newValue);

                // With weightEmpirical == 0 the blend equals the prior, so nothing needs rewriting
                if (rewriteNodeInTree && weightEmpirical != 0)
                {
                    MCTSNodeStructChild thisChild = children[i];
                    if (thisChild.IsExpanded)
                    {
                        // NOTE(review): childNodeRef is never used; the write goes through thisChild.ChildRef directly.
                        ref MCTSNodeStruct childNodeRef = ref thisChild.ChildRef;
                        thisChild.ChildRef.P = (FP16)newValue;
                    }
                    else
                    {
                        // Unexpanded children keep their move and only receive the new policy value
                        node.Ref.ChildAtIndex(i).SetUnexpandedPolicyValues(thisChild.Move, (FP16)newValue);
                    }
                }
            }
        /// <summary>
        /// Prunes cache down to approximately specified target size, evicting the least recently
        /// accessed entries (lowest LastAccessedSequenceCounter) first.
        /// </summary>
        /// <param name="store">node store whose nodes have their CacheIndex cleared when evicted from this cache</param>
        /// <param name="targetSize">target number of nodes, or -1 to use default sizing (70% of capacity)</param>
        /// <returns>0 if the cache is already at or below the target size; otherwise presumably the number
        /// of entries pruned (the rest of the method is not shown here — TODO confirm)</returns>
        internal int Prune(MCTSNodeStore store, int targetSize)
        {
            int startNumInUse = numInUse;

            // Default target size is 70% of maximum.
            if (targetSize == -1)
            {
                targetSize = (nodes.Length * 70) / 100;
            }
            // NOTE(review): numInUse is read here outside lockObj; confirm concurrent callers cannot race.
            if (numInUse <= targetSize)
            {
                return(0);
            }

            lock (lockObj)
            {
                // Gather the access sequence numbers of all occupied slots
                int count = 0;
                for (int i = 0; i < nodes.Length; i++)
                {
                    // TODO: the long is cast to int, could we possibly overflow? make long?
                    if (nodes[i] != null)
                    {
                        pruneSequenceNums[count++] = (int)nodes[i].LastAccessedSequenceCounter;
                    }
                }

                Span <int> slice = new Span <int>(pruneSequenceNums).Slice(0, count);

                // Compute the minimum sequence number an entry must have
                // to be retained (to enforce LRU eviction)
                //float cutoff = KthSmallestValue.CalcKthSmallestValue(keyPrioritiesForSorting, numToPrune);
                int cutoff;

                cutoff = KthSmallestValueInt.CalcKthSmallestValue(slice, numInUse - targetSize);
                //Console.WriteLine(slice.Length + " " + (numInUse-targetSize) + " --> "
                //                 + cutoff + " correct " + slice[numInUse-targetSize] + " avg " + slice[numInUse/2]);

                // Before the first prune only slots 1..numInUse are scanned
                // (presumably slots are filled sequentially until then — TODO confirm).
                // Note this loop starts at 1 whereas the gathering loop above starts at 0.
                int maxEntries = pruneCount == 0 ? (numInUse + 1) : nodes.Length;
                for (int i = 1; i < maxEntries; i++)
                {
                    MCTSNode node = nodes[i];
                    // Entries strictly below the cutoff are evicted; ties with the cutoff are kept,
                    // which is why the resulting size is only approximately the target.
                    if (node != null && node.LastAccessedSequenceCounter < cutoff)
                    {
                        MCTSNodeStructIndex nodeIndex = new MCTSNodeStructIndex(node.Index);
                        nodes[i] = null;

                        // Clear the back-reference from the stored node to this cache slot
                        ref MCTSNodeStruct refNode = ref store.Nodes.nodes[nodeIndex.Index];
                        refNode.CacheIndex = 0;

                        numInUse--;
                    }
                }
                pruneCount++;
            }
Esempio n. 8
0
        /// <summary>
        /// Plots the search tree and opens it in the default image viewer.
        /// </summary>
        /// <param name="rawNode">root node of the search tree to plot</param>
        public static void Show(MCTSNodeStruct rawNode)
        {
            // Path.GetTempFileName() creates a zero-byte file on disk, so appending ".png" to its
            // name left that placeholder behind on every call. Build a unique .png path in the
            // temp directory instead, without creating any extra file.
            string file = Path.Combine(Path.GetTempPath(), Path.ChangeExtension(Path.GetRandomFileName(), ".png"));

            Save(rawNode, file);

            // Launch via the shell so the OS picks the viewer associated with .png files.
            // The Process handle is only needed to start it, so release it right away
            // (disposing does not terminate the viewer; Process.Start may return null).
            ProcessStartInfo startInfo = new ProcessStartInfo(file)
            {
                UseShellExecute = true
            };
            using Process viewer = Process.Start(startInfo);
        }
Esempio n. 9
0
        /// <summary>
        /// Tries to evaluate <paramref name="node"/> by looking up its position (by Zobrist hash)
        /// among the transposition roots of another search context (OtherContext).
        /// A matching node that is terminal in the other tree is counted as a miss
        /// and yields a default (empty) result.
        /// </summary>
        protected override LeafEvaluationResult DoTryEvaluate(MCTSNode node)
        {
            VerifyCompatibleNetworkDefinition(node);

            if (OtherContext.Tree.TranspositionRoots.TryGetValue(node.Ref.ZobristHash, out int nodeIndex))
            {
                // Access the other tree within an execution block for the other context
                using (new SearchContextExecutionBlock(OtherContext))
                {
                    ref MCTSNodeStruct       otherNodeRef = ref OtherContext.Tree.Store.Nodes.nodes[nodeIndex];
                    CompressedPolicyVector[] cpvArray     = new CompressedPolicyVector[1];

                    // Terminal positions are not reused from the other tree
                    if (otherNodeRef.Terminal != Chess.GameResult.Unknown)
                    {
                        NumMisses++;
                        return(default);
Esempio n. 10
0
        /// <summary>
        /// Calculates the tree layout for drawing.
        /// </summary>
        /// <param name="root">root node of the search tree to lay out</param>
        /// <returns>the laid-out drawing root, and the DrawTreeInfo instance passed to the final layout walk</returns>
        internal static (DrawTreeNode, DrawTreeInfo) Layout(MCTSNodeStruct root)
        {
            DrawTreeInfo treeInfo = new DrawTreeInfo();
            int          id       = 0;
            DrawTreeNode drawRoot = new DrawTreeNode(null, root, 0, 0, ref id);

            drawRoot.FirstWalk();
            float min = drawRoot.SecondWalk(0.0f, float.MaxValue);

            // Shift whole tree so that min x is 0.
            drawRoot.ThirdWalk(-min, treeInfo);
            return (drawRoot, treeInfo);
        }
Esempio n. 11
0
        /// <summary>
        /// Appends <paramref name="node"/> to the end of its transposition cluster (a chain linked
        /// through NextTranspositionLinked, starting at the root registered in TranspositionRoots
        /// under the node's PositionHashForCaching) and returns the index of the cluster root
        /// (possibly the node itself).
        /// </summary>
        /// <param name="node">node to add to its transposition cluster</param>
        /// <returns>index of the cluster's root node</returns>
        /// <exception cref="Exception">"Internal error" if no root is registered for the node's position hash</exception>
        public static int CheckAddToCluster(MCTSNode node)
        {
            int rootIndex;

            if (!node.Context.Tree.TranspositionRoots.TryGetValue(node.Annotation.PositionHashForCaching, out rootIndex))
            {
                // A root is expected to have been registered already; the former
                // "register as new root" logic below is compiled out.
                throw new Exception("Internal error");
#if NOT
                // We are the new root, just add to roots table and exit
                bool added = node.Context.TranspositionRoots.TryAdd(node.Annotation.PositionHashForCaching, node.Index);

                // If we failed to add, this means this node was already added in the interim
                // Therefore recursively call ourself so that we can get our self added to the end of the list
                if (!added)
                {
                    return(CheckAddToCluster(node));
                }
                else
                {
                    return(node.Index);
                }
#endif
            }
            else
            {
                // Cluster already exists. Append ourself.
                // Walk the chain to its last node (the one with NextTranspositionLinked == 0).
                ref MCTSNodeStruct traverseRef = ref node.Context.Tree.Store.Nodes.nodes[rootIndex];
                while (true)
                {
                    if (traverseRef.NextTranspositionLinked == 0)
                    {
                        break;
                    }

                    traverseRef = ref node.Context.Tree.Store.Nodes.nodes[traverseRef.NextTranspositionLinked];
                }

                // Tack ourself onto the end
                // TODO: could we more efficiently put ourself at beginning?
                // TODO: concurrency?
                // NOTE(review): only the tail is compared with node.Index; if node already appears
                // earlier in the chain it would be linked again, creating a cycle — confirm this cannot happen.
                if (traverseRef.Index.Index != node.Index)
                {
                    traverseRef.NextTranspositionLinked = node.Index;
                }

                return(rootIndex);
            }
Esempio n. 12
0
        /// <summary>
        /// Extracts up to NInFlight + NInFlight2 evaluations for <paramref name="node"/> from the subtree of a
        /// transposition root, using <paramref name="linkedVisitor"/> to enumerate subtree nodes, and combines
        /// them into a single result via AddResultToResults.
        /// </summary>
        /// <param name="node">node requesting evaluations</param>
        /// <param name="transpositionRootNode">root of the transposition subtree supplying the evaluations</param>
        /// <param name="numAlreadyLinked">number of subtree nodes already consumed; incremented per extraction</param>
        /// <param name="linkedVisitor">visitor over the transposition subtree (also records the root's N when visiting started)</param>
        private static LeafEvaluationResult ExtractTranspositionNodesFromSubtree(MCTSNode node, ref MCTSNodeStruct transpositionRootNode,
                                                                                 ref int numAlreadyLinked, MCTSNodeTranspositionVisitor linkedVisitor)
        {
            LeafEvaluationResult result = default;

            // Determine how many evaluations we should extract (based on number requested and number available)
            int numAvailable = linkedVisitor.TranspositionRootNWhenVisitsStarted - numAlreadyLinked;
            int numDesired   = node.NInFlight + node.NInFlight2;

            // Warn (at most 10 times overall) when more evaluations are requested than the subtree can supply
            if (numDesired > numAvailable && WARN_COUNT < 10)
            {
                Console.WriteLine(numDesired + " Warning: multiple nodes were requested from the transposition subtree, available " + numAvailable);
                WARN_COUNT++;
            }
            int numToFetch = Math.Min(numDesired, numAvailable);

            // Checked only in debug builds; in release builds a non-positive value simply skips the loop below
            Debug.Assert(numToFetch > 0);

            // Extract each evaluation
            for (int i = 0; i < numToFetch; i++)
            {
                MCTSNodeStructIndex transpositionSubnodeIndex = linkedVisitor.Visitor.GetNext();
                Debug.Assert(!transpositionSubnodeIndex.IsNull);

                NumExtractedAndNeverCloned++;
                numAlreadyLinked++;

                // Prepare the result to return
                // NOTE(review): Terminal comes from the subnode but WinP/LossP/MPosition come from the
                // transposition root — confirm the root's values (not the subnode's) are intended here.
                ref MCTSNodeStruct   transpositionSubnode = ref node.Context.Tree.Store.Nodes.nodes[transpositionSubnodeIndex.Index];
                LeafEvaluationResult thisResult           = new LeafEvaluationResult(transpositionSubnode.Terminal, transpositionRootNode.WinP,
                                                                                     transpositionRootNode.LossP, transpositionRootNode.MPosition);

                // Update our result node to include this node
                result = AddResultToResults(result, numToFetch, i, thisResult);

                if (VERBOSE)
                {
                    Console.WriteLine($"ProcessAlreadyLinked {node.Index} yields {result.WinP} {result.LossP} via linked subnode root {transpositionRootNode.Index.Index} {transpositionRootNode} chose {transpositionSubnode.Index.Index}");
                }

                node.Ref.NumNodesTranspositionExtracted++;
            }
Esempio n. 13
0
        /// <summary>
        /// Writes every node of the subtree rooted at <paramref name="node"/> to the console,
        /// one line per node prefixed with a 1-based sequence number.
        /// </summary>
        /// <param name="context">iterator whose tree store holds the nodes</param>
        /// <param name="node">root of the subtree to dump</param>
        /// <param name="type">traversal order (breadth-first by default)</param>
        /// <param name="childDetail">if true, each node line is followed by one line per child entry</param>
        public static void DumpAllNodes(MCTSIterator context, ref MCTSNodeStruct node,
                                        Base.DataType.Trees.TreeTraversalType type = Base.DataType.Trees.TreeTraversalType.BreadthFirst,
                                        bool childDetail = false)
        {
            int sequenceNumber = 1;

            // Print each visited node (the callback returns true for every node)
            node.Traverse(context.Tree.Store, (ref MCTSNodeStruct visited) =>
            {
                Console.WriteLine(sequenceNumber + " " + visited);
                sequenceNumber++;

                if (!childDetail)
                {
                    return(true);
                }

                int childNum = 0;
                foreach (MCTSNodeStructChild childInfo in visited.Children)
                {
                    Console.WriteLine($"  {childNum,3} {childInfo}");
                    childNum++;
                }
                return(true);
            }, type);
        }
        /// <summary>
        /// Performs the in-situ reorganization that makes <paramref name="newRootChild"/> the root
        /// (MakeChildNewRoot calls this only when the node is not already the root). Steps include:
        ///   - building a bit array of the nodes in the retained subtree (and counting them),
        ///   - optionally extracting the non-retained nodes into <paramref name="cacheNonRetainedNodes"/>,
        ///   - allocating a table that maps each retained node's old child block to its new node index,
        ///   - RewriteNodes, which compacts retained nodes toward the start of the node array in index order.
        /// </summary>
        static void DoMakeChildNewRoot(MCTSNodeStore store, float policySoftmax, ref MCTSNodeStruct newRootChild,
                                       PositionWithHistory newPriorMoves,
                                       PositionEvalCache cacheNonRetainedNodes,
                                       TranspositionRootsDict transpositionRoots)
        {
            ChildStartIndexToNodeIndex[] childrenToNodes;

            uint     numNodesUsed;
            uint     numChildrenUsed;
            BitArray includedNodes;

            int newRootChildIndex   = newRootChild.Index.Index;
            int newIndexOfNewParent = -1;

            // Retained nodes are packed starting at index 1
            int nextAvailableNodeIndex = 1;

            // Traverse this subtree, building a bit array of visited nodes
            includedNodes = MCTSNodeStructUtils.BitArrayNodesInSubtree(store, ref newRootChild, out numNodesUsed);

            //using (new TimingBlock("Build position cache "))
            if (cacheNonRetainedNodes != null)
            {
                // Size estimate for non-retained nodes: root visit count minus retained node count
                long estNumNodes = store.RootNode.N - numNodesUsed;
                cacheNonRetainedNodes.InitializeWithSize((int)estNumNodes);
                ExtractPositionCacheNonRetainedNodes(store, policySoftmax, includedNodes, in newRootChild, cacheNonRetainedNodes);
            }

            // We will construct a table indicating the starting index and length of
            // children associated with the nodes we are extracting
            childrenToNodes = GC.AllocateUninitializedArray <ChildStartIndexToNodeIndex>((int)numNodesUsed);

            void RewriteNodes()
            {
                // TODO: Consider that the above is possibly all we need to do in some case
                //       Suppose the subtree is very large relative to the whole
                //       This approach would be much faster, and orphan an only small part of the storage

                // Now scan all above nodes.
                // If they don't belong, ignore.
                // If they do belong, swap them down to the next available lower location
                // Note that this can't be parallelized, since we have to do it strictly in order of node index
                // NOTE(review): scanning starts at 2, presumably because index 1 is the current root,
                // which is never part of the retained subtree here — confirm.
                int numRewrittenNodesDone = 0;

                for (int i = 2; i < store.Nodes.nextFreeIndex; i++)
                {
                    if (includedNodes.Get(i))
                    {
                        ref MCTSNodeStruct thisNode = ref store.Nodes.nodes[i];

                        // Reset any cache entry
                        thisNode.CacheIndex = 0;

                        // Not possible to support transposition linked nodes,
                        // since the root may be in a part of the tree that is not retained
                        // and possibly already overwritten.
                        // We expect them to have already been materialized by the time we reach this point.
                        Debug.Assert(!thisNode.IsTranspositionLinked);
                        Debug.Assert(thisNode.NumNodesTranspositionExtracted == 0);

                        // Remember this location if this is the new parent
                        if (i == newRootChildIndex)
                        {
                            newIndexOfNewParent = nextAvailableNodeIndex;
                        }

                        // Move the actual node
                        MoveNodePosition(store, new MCTSNodeStructIndex(i), new MCTSNodeStructIndex(nextAvailableNodeIndex));

                        // Reset all transposition information
                        // NOTE(review): thisNode still refers to slot i (the old location), not to
                        // nextAvailableNodeIndex; unless the two coincide, this write and the reads below
                        // act on the old slot — confirm this is the intended target.
                        thisNode.NextTranspositionLinked = 0;

                        childrenToNodes[numRewrittenNodesDone] = new ChildStartIndexToNodeIndex(thisNode.childStartBlockIndex, nextAvailableNodeIndex, thisNode.NumPolicyMoves);

                        // Re-insert this into the transpositionRoots (with the updated node index)
                        if (transpositionRoots != null)
                        {
                            transpositionRoots.TryAdd(thisNode.ZobristHash, nextAvailableNodeIndex);
                        }

                        Debug.Assert(thisNode.NumNodesTranspositionExtracted == 0);

                        numRewrittenNodesDone++;
                        nextAvailableNodeIndex++;
                    }
                }
            }