/// <summary>
/// Creates a node store sized for at most the specified number of nodes,
/// together with the associated child storage.
/// </summary>
/// <param name="maxNodes">maximum number of nodes the store is allocated for</param>
/// <param name="priorMoves">
/// position (with history) recorded with the store;
/// PositionWithHistory.StartPosition is used when null
/// </param>
public MCTSNodeStore(int maxNodes, PositionWithHistory priorMoves = null)
{
    // Fall back to the start position when the caller supplied no history.
    PositionWithHistory effectivePriorMoves = priorMoves == null
                                              ? PositionWithHistory.StartPosition
                                              : priorMoves;

    MaxNodes = maxNodes;

    // Node storage is allocated for exactly maxNodes entries.
    Nodes = new MCTSNodeStructStorage(maxNodes, null,
                                      MCTSParamsFixed.STORAGE_USE_INCREMENTAL_ALLOC,
                                      MCTSParamsFixed.STORAGE_LARGE_PAGES,
                                      MCTSParamsFixed.STORAGE_USE_EXISTING_SHARED_MEM);

    // Child storage is reserved proportionally to the node count
    // (computed as long so the product is not truncated to int).
    long childSlotsToReserve = maxNodes * (long)AVG_CHILDREN_PER_NODE;
    Children = new MCTSNodeStructChildStorage(this, childSlotsToReserve);

    // The store keeps its own copy of the prior moves.
    Nodes.PriorMoves = new PositionWithHistory(effectivePriorMoves);

    CeresEnvironment.LogInfo("NodeStore", "Init",
                             $"MCTSNodeStore created with max {maxNodes} nodes, max {childSlotsToReserve} children");

    MCTSNodeStruct.ValidateMCTSNodeStruct();

    // The root is placed at index 1.
    RootIndex = new MCTSNodeStructIndex(1);
}
/// <summary>
/// Makes a node within the tree the new root, reorganizing the node and child arrays.
///
/// Critically, this operation is done in situ to avoid having to transiently allocate
/// extremely large memory objects.
///
/// The operation consists of 3 stages:
///   - traverse the subtree breadth-first starting at the new root,
///     building a bitmap of which nodes belong to the retained subtree.
///
///   - traverse the node array sequentially, processing each node that is a member of this bitmap
///     by moving it into its new position in the store (including modifying associated child and
///     parent references). A table of all the children that need to be moved is built as well.
///
///   - using the child table built above, shift all children down. Because nodes may have children
///     written out of order, it is not certain that enough space is available. Therefore the table
///     is sorted by new position and the children are then shifted down, ensuring that children
///     yet to be shifted are not overwritten.
///
/// Additionally the transposition roots dictionary may have to be recreated because
/// the set of roots is smaller (the retained subtree only) and the node indices will change.
/// </summary>
/// <param name="store">store containing the tree to be re-rooted</param>
/// <param name="policySoftmax">forwarded unchanged to DoMakeChildNewRoot</param>
/// <param name="newRootChild">node which is to become the new root</param>
/// <param name="newPriorMoves">forwarded unchanged to DoMakeChildNewRoot</param>
/// <param name="cacheNonRetainedNodes">forwarded unchanged to DoMakeChildNewRoot (may be null)</param>
/// <param name="transpositionRoots">forwarded unchanged to DoMakeChildNewRoot (may be null)</param>
public static void MakeChildNewRoot(MCTSNodeStore store, float policySoftmax,
                                    ref MCTSNodeStruct newRootChild,
                                    PositionWithHistory newPriorMoves,
                                    PositionEvalCache cacheNonRetainedNodes,
                                    TranspositionRootsDict transpositionRoots)
{
#if DEBUG
    store.Validate();
#endif

    COUNT++;

    bool alreadyRoot = newRootChild.Index == store.RootNode.Index;
    if (!alreadyRoot)
    {
        DoMakeChildNewRoot(store, policySoftmax, ref newRootChild, newPriorMoves,
                           cacheNonRetainedNodes, transpositionRoots);
    }
    else
    {
        // The requested node is already the root: the tree itself is unchanged,
        // only the cache references need to be flushed.
        store.ClearAllCacheIndices();
    }

#if DEBUG
    store.Validate();
#endif
}
/// <summary>
/// Constructor. Lays out the tree rooted at the given node, scales all drawing
/// dimensions by the supersampling factor, allocates the output image and draws the plot.
/// </summary>
/// <param name="node">root node of the tree to be plotted</param>
public TreePlot(MCTSNodeStruct node)
{
    rawRoot = node;
    (root, treeInfo) = DrawTreeNode.Layout(node);

    // Canvas dimensions.
    canvasWidth *= superSample;
    canvasHeight *= superSample;

    // Drawing primitives.
    pointRadius *= superSample;
    edgeWidth *= superSample;

    // Margins and spacing.
    leftMargin *= superSample;
    rightMargin *= superSample;
    topMargin *= superSample;
    bottomMargin *= superSample;
    titleMargin *= superSample;
    horisontalSpacing *= superSample;
    verticalSpacing *= superSample;

    // Histogram areas.
    rightHistogramWidth *= superSample;
    bottomHistogramHeight *= superSample;

    // Fonts.
    tickFontSize *= superSample;
    histogramTitleFontSize *= superSample;

    // The plot area is whatever remains of the (scaled) canvas
    // after margins, histograms and spacing are subtracted.
    plotAreaHeight = canvasHeight - topMargin - bottomMargin - bottomHistogramHeight - verticalSpacing - titleMargin;
    plotAreaWidth = canvasWidth - leftMargin - rightMargin - rightHistogramWidth - 2 * horisontalSpacing;

    // The output image has the unscaled canvas size.
    image = new Bitmap(canvasWidth / superSample, canvasHeight / superSample);

    Plot();
}
/// <summary>
/// Plots the search tree for the given root node and saves the image as a png file.
/// </summary>
/// <param name="rawNode">root node of the tree to plot</param>
/// <param name="fileName">name of the image file to be written</param>
public static void Save(MCTSNodeStruct rawNode, string fileName)
{
    TreePlot treePlot = new TreePlot(rawNode);
    try
    {
        treePlot.Save(fileName);
    }
    finally
    {
        // Always release the plot, even if writing the file throws.
        treePlot.Dispose();
    }
}
/// <summary>
/// Constructor. Creates the draw node for the given tree node and, recursively,
/// draw nodes for all of its expanded children.
/// </summary>
/// <param name="parent">draw node of the parent, or null for the root</param>
/// <param name="node">tree node represented by this draw node</param>
/// <param name="depth">depth of this node (stored as the Y coordinate)</param>
/// <param name="siblingIndex">position of this node among its siblings in drawing order</param>
/// <param name="identifier">running counter; its current value becomes this node's id and it is then incremented</param>
internal DrawTreeNode(DrawTreeNode parent, MCTSNodeStruct node, int depth, int siblingIndex, ref int identifier)
{
    X = -1.0f;
    Y = (float)depth;
    id = identifier++;

    // A node inherits the branch index of its parent when the parent has one;
    // otherwise (parent has branch index -1) its own sibling index starts a new branch.
    if (parent is null)
    {
        BranchIndex = -1;
    }
    else if (parent.BranchIndex != -1)
    {
        BranchIndex = parent.BranchIndex;
    }
    else
    {
        BranchIndex = siblingIndex;
    }

    // Layout bookkeeping starts out in a neutral state.
    Parent = parent;
    this.siblingIndex = siblingIndex;
    mod = 0.0f;
    change = 0.0f;
    shift = 0.0f;
    ancestor = this;
    lmostSibling = null;
    thread = null;

    // Children are ordered by descending N so that the heaviest subtree is always drawn leftmost.
    Children = new List<DrawTreeNode>();
    var childrenHeaviestFirst = Enumerable.Range(0, node.NumChildrenExpanded)
                                          .Select(ind => node.ChildAtIndexRef(ind))
                                          .OrderBy(c => -c.N);

    int nextSiblingIndex = 0;
    foreach (MCTSNodeStruct child in childrenHeaviestFirst)
    {
        Children.Add(new DrawTreeNode(this, child, depth + 1, nextSiblingIndex, ref identifier));
        nextSiblingIndex++;
    }
}
// -------------------------------------------------------------------------------------------- static void ProcessNode(PositionEvalCache cache, MCTSNode node, float weightEmpirical, bool saveToCache, bool rewriteNodeInTree) { Span <MCTSNodeStructChild> children = node.Ref.Children; // TODO: optimize this away if saveToCache is false ushort[] probabilities = new ushort[node.NumPolicyMoves]; ushort[] indices = new ushort[node.NumPolicyMoves]; // Compute empirical visit distribution float[] nodeFractions = new float[node.NumPolicyMoves]; for (int i = 0; i < node.NumChildrenExpanded; i++) { nodeFractions[i] = (float)node.ChildAtIndex(i).N / (float)node.N; } // Determine P of first unexpanded node // We can't allow any child to have a new P less than this // since we need to keep them in order by P and the resorting logic below // can only operate over expanded nodes float minP = 0; if (node.NumChildrenExpanded < node.NumPolicyMoves) { minP = node.ChildAtIndexInfo(node.NumChildrenExpanded).p; } // Add each move to the policy vector with blend of prior and empirical values for (int i = 0; i < node.NumChildrenExpanded; i++) { (MCTSNode node, EncodedMove move, FP16 p)info = node.ChildAtIndexInfo(i); indices[i] = (ushort)info.move.IndexNeuralNet; float newValue = (1.0f - weightEmpirical) * info.p + weightEmpirical * nodeFractions[i]; if (newValue < minP) { newValue = minP; } probabilities[i] = CompressedPolicyVector.EncodedProbability(newValue); if (rewriteNodeInTree && weightEmpirical != 0) { MCTSNodeStructChild thisChild = children[i]; if (thisChild.IsExpanded) { ref MCTSNodeStruct childNodeRef = ref thisChild.ChildRef; thisChild.ChildRef.P = (FP16)newValue; } else { node.Ref.ChildAtIndex(i).SetUnexpandedPolicyValues(thisChild.Move, (FP16)newValue); } } }
/// <summary> /// Prunes cache down to approximately specified target size. /// </summary> /// <param name="store"></param> /// <param name="targetSize">target numer of nodes, or -1 to use default sizing</param> /// <returns></returns> internal int Prune(MCTSNodeStore store, int targetSize) { int startNumInUse = numInUse; // Default target size is 70% of maximum. if (targetSize == -1) { targetSize = (nodes.Length * 70) / 100; } if (numInUse <= targetSize) { return(0); } lock (lockObj) { int count = 0; for (int i = 0; i < nodes.Length; i++) { // TODO: the long is cast to int, could we possibly overflow? make long? if (nodes[i] != null) { pruneSequenceNums[count++] = (int)nodes[i].LastAccessedSequenceCounter; } } Span <int> slice = new Span <int>(pruneSequenceNums).Slice(0, count); // Compute the minimum sequence number an entry must have // to be retained (to enforce LRU eviction) //float cutoff = KthSmallestValue.CalcKthSmallestValue(keyPrioritiesForSorting, numToPrune); int cutoff; cutoff = KthSmallestValueInt.CalcKthSmallestValue(slice, numInUse - targetSize); //Console.WriteLine(slice.Length + " " + (numInUse-targetSize) + " --> " // + cutoff + " correct " + slice[numInUse-targetSize] + " avg " + slice[numInUse/2]); int maxEntries = pruneCount == 0 ? (numInUse + 1) : nodes.Length; for (int i = 1; i < maxEntries; i++) { MCTSNode node = nodes[i]; if (node != null && node.LastAccessedSequenceCounter < cutoff) { MCTSNodeStructIndex nodeIndex = new MCTSNodeStructIndex(node.Index); nodes[i] = null; ref MCTSNodeStruct refNode = ref store.Nodes.nodes[nodeIndex.Index]; refNode.CacheIndex = 0; numInUse--; } } pruneCount++; }
/// <summary>
/// Plots the search tree, saves it as a png file in the temporary directory
/// and opens that file via the operating system shell.
/// </summary>
/// <param name="rawNode">root node of the tree to plot</param>
public static void Show(MCTSNodeStruct rawNode)
{
    // Only a unique name is needed here. Path.GetTempFileName would additionally
    // create an empty placeholder file which is never used (the image is written
    // to a different name ending in ".png") and therefore never cleaned up.
    string file = Path.Combine(Path.GetTempPath(), Path.GetRandomFileName() + ".png");
    Save(rawNode, file);

    // Disposing the Process object only releases the local handle;
    // it does not terminate the launched viewer.
    using (Process viewer = new Process { StartInfo = new ProcessStartInfo(file) { UseShellExecute = true } })
    {
        viewer.Start();
    }
}
protected override LeafEvaluationResult DoTryEvaluate(MCTSNode node) { VerifyCompatibleNetworkDefinition(node); if (OtherContext.Tree.TranspositionRoots.TryGetValue(node.Ref.ZobristHash, out int nodeIndex)) { using (new SearchContextExecutionBlock(OtherContext)) { ref MCTSNodeStruct otherNodeRef = ref OtherContext.Tree.Store.Nodes.nodes[nodeIndex]; CompressedPolicyVector[] cpvArray = new CompressedPolicyVector[1]; if (otherNodeRef.Terminal != Chess.GameResult.Unknown) { NumMisses++; return(default);
/// <summary>
/// Calculates the tree layout for the tree rooted at the given node.
/// </summary>
/// <param name="root">root node of the tree to lay out</param>
/// <returns>the root draw node together with the tree info filled in during the final walk</returns>
internal static (DrawTreeNode, DrawTreeInfo) Layout(MCTSNodeStruct root)
{
    DrawTreeInfo treeInfo = new DrawTreeInfo();

    // Build the draw tree; ids are assigned from a running counter starting at 0.
    int id = 0;
    DrawTreeNode drawRoot = new DrawTreeNode(null, root, 0, 0, ref id);

    drawRoot.FirstWalk();
    float min = drawRoot.SecondWalk(0.0f, float.MaxValue);

    // Shift whole tree so that min x is 0.
    drawRoot.ThirdWalk(-min, treeInfo);

    return (drawRoot, treeInfo);
}
/// <summary> /// Returns the node which is the root of the cluster (possibly same as node). /// </summary> /// <param name="node"></param> /// <returns></returns> public static int CheckAddToCluster(MCTSNode node) { int rootIndex; if (!node.Context.Tree.TranspositionRoots.TryGetValue(node.Annotation.PositionHashForCaching, out rootIndex)) { throw new Exception("Internal error"); #if NOT // We are the new root, just add to roots table and exit bool added = node.Context.TranspositionRoots.TryAdd(node.Annotation.PositionHashForCaching, node.Index); // If we failed to add, this means this node was already added in the interim // Therefore recursively call ourself so that we can get our self added to the end of the list if (!added) { return(CheckAddToCluster(node)); } else { return(node.Index); } #endif } else { // Cluster already exists. Apppend ourself ref MCTSNodeStruct traverseRef = ref node.Context.Tree.Store.Nodes.nodes[rootIndex]; while (true) { if (traverseRef.NextTranspositionLinked == 0) { break; } traverseRef = ref node.Context.Tree.Store.Nodes.nodes[traverseRef.NextTranspositionLinked]; } // Tack ourself onto the end // TODO: could we more efficiently put ourself at beginning? // TODO: concurrency? if (traverseRef.Index.Index != node.Index) { traverseRef.NextTranspositionLinked = node.Index; } return(rootIndex); }
private static LeafEvaluationResult ExtractTranspositionNodesFromSubtree(MCTSNode node, ref MCTSNodeStruct transpositionRootNode, ref int numAlreadyLinked, MCTSNodeTranspositionVisitor linkedVisitor) { LeafEvaluationResult result = default; // Determine how many evaluations we should extract (based on number requested and number available) int numAvailable = linkedVisitor.TranspositionRootNWhenVisitsStarted - numAlreadyLinked; int numDesired = node.NInFlight + node.NInFlight2; if (numDesired > numAvailable && WARN_COUNT < 10) { Console.WriteLine(numDesired + " Warning: multiple nodes were requested from the transposition subtree, available " + numAvailable); WARN_COUNT++; } int numToFetch = Math.Min(numDesired, numAvailable); Debug.Assert(numToFetch > 0); // Extract each evaluation for (int i = 0; i < numToFetch; i++) { MCTSNodeStructIndex transpositionSubnodeIndex = linkedVisitor.Visitor.GetNext(); Debug.Assert(!transpositionSubnodeIndex.IsNull); NumExtractedAndNeverCloned++; numAlreadyLinked++; // Prepare the result to return ref MCTSNodeStruct transpositionSubnode = ref node.Context.Tree.Store.Nodes.nodes[transpositionSubnodeIndex.Index]; LeafEvaluationResult thisResult = new LeafEvaluationResult(transpositionSubnode.Terminal, transpositionRootNode.WinP, transpositionRootNode.LossP, transpositionRootNode.MPosition); // Update our result node to include this node result = AddResultToResults(result, numToFetch, i, thisResult); if (VERBOSE) { Console.WriteLine($"ProcessAlreadyLinked {node.Index} yields {result.WinP} {result.LossP} via linked subnode root {transpositionRootNode.Index.Index} {transpositionRootNode} chose {transpositionSubnode.Index.Index}"); } node.Ref.NumNodesTranspositionExtracted++; }
/// <summary>
/// Writes every node visited by a traversal starting at the given node to the Console,
/// one line per node prefixed by a running sequence number (starting at 1).
/// </summary>
/// <param name="context">search context whose tree store is traversed</param>
/// <param name="node">node at which the traversal starts</param>
/// <param name="type">order in which nodes are traversed</param>
/// <param name="childDetail">if true, each child entry of every visited node is also written</param>
public static void DumpAllNodes(MCTSIterator context, ref MCTSNodeStruct node,
                                Base.DataType.Trees.TreeTraversalType type = Base.DataType.Trees.TreeTraversalType.BreadthFirst,
                                bool childDetail = false)
{
    int sequenceNum = 1;

    // Visit all nodes, writing each one out.
    node.Traverse(context.Tree.Store, (ref MCTSNodeStruct visitedNode) =>
    {
        Console.WriteLine(sequenceNum + " " + visitedNode);

        if (childDetail)
        {
            int childNum = 0;
            foreach (MCTSNodeStructChild childEntry in visitedNode.Children)
            {
                Console.WriteLine($" {childNum,3} {childEntry}");
                childNum++;
            }
        }

        sequenceNum++;

        // Returning true for every node, as the original callback did.
        return true;
    }, type);
}
static void DoMakeChildNewRoot(MCTSNodeStore store, float policySoftmax, ref MCTSNodeStruct newRootChild, PositionWithHistory newPriorMoves, PositionEvalCache cacheNonRetainedNodes, TranspositionRootsDict transpositionRoots) { ChildStartIndexToNodeIndex[] childrenToNodes; uint numNodesUsed; uint numChildrenUsed; BitArray includedNodes; int newRootChildIndex = newRootChild.Index.Index; int newIndexOfNewParent = -1; int nextAvailableNodeIndex = 1; // Traverse this subtree, building a bit array of visited nodes includedNodes = MCTSNodeStructUtils.BitArrayNodesInSubtree(store, ref newRootChild, out numNodesUsed); //using (new TimingBlock("Build position cache ")) if (cacheNonRetainedNodes != null) { long estNumNodes = store.RootNode.N - numNodesUsed; cacheNonRetainedNodes.InitializeWithSize((int)estNumNodes); ExtractPositionCacheNonRetainedNodes(store, policySoftmax, includedNodes, in newRootChild, cacheNonRetainedNodes); } // We will constract a table indicating the starting index and length of // children associated with the nodes we are extracting childrenToNodes = GC.AllocateUninitializedArray <ChildStartIndexToNodeIndex>((int)numNodesUsed); void RewriteNodes() { // TODO: Consider that the above is possibly all we need to do in some case // Suppose the subtree is very large relative to the whole // This approach would be much faster, and orphan an only small part of the storage // Now scan all above nodes. // If they don't belong, ignore. 
// If they do belong, swap them down to the next available lower location // Note that this can't be parallelized, since we have to do it strictly in order of node index int numRewrittenNodesDone = 0; for (int i = 2; i < store.Nodes.nextFreeIndex; i++) { if (includedNodes.Get(i)) { ref MCTSNodeStruct thisNode = ref store.Nodes.nodes[i]; // Reset any cache entry thisNode.CacheIndex = 0; // Not possible to support transposition linked nodes, // since the root may be in a part of the tree that is not retained // and possibly already overwritten. // We expect them to have already been materialized by the time we reach this point. Debug.Assert(!thisNode.IsTranspositionLinked); Debug.Assert(thisNode.NumNodesTranspositionExtracted == 0); // Remember this location if this is the new parent if (i == newRootChildIndex) { newIndexOfNewParent = nextAvailableNodeIndex; } // Move the actual node MoveNodePosition(store, new MCTSNodeStructIndex(i), new MCTSNodeStructIndex(nextAvailableNodeIndex)); // Reset all transposition information thisNode.NextTranspositionLinked = 0; childrenToNodes[numRewrittenNodesDone] = new ChildStartIndexToNodeIndex(thisNode.childStartBlockIndex, nextAvailableNodeIndex, thisNode.NumPolicyMoves); // Re-insert this into the transpositionRoots (with the updated node index) if (transpositionRoots != null) { transpositionRoots.TryAdd(thisNode.ZobristHash, nextAvailableNodeIndex); } Debug.Assert(thisNode.NumNodesTranspositionExtracted == 0); numRewrittenNodesDone++; nextAvailableNodeIndex++; } } }