示例#1
0
        // --------------------------------------------------------------------------------------------
        public void RecordVisitToTopLevelMove(MCTSNode leafNode, MCTSNodeStructIndex indexOfChildDescendentFromRoot, LeafEvaluationResult evalResult)
        {
#if NOT
            if (indexOfChildDescendentFromRoot == -1)
            {
                return;                                 // if the root
            }
            int childIndex = -1;
            Span <MCTSNodeStructChild> children = Root.Ref.Children;
            for (int i = 0; i < Root.NumChildrenExpanded; i++)
            {
                if (children[i].ChildIndex.Index == indexOfChildDescendentFromRoot)
                {
                    childIndex = i;
                    break;
                }
            }
            if (childIndex == -1)
            {
                throw new Exception("Internal error");
            }

            if (rootChildrenMovingAverageValues == null)
            {
                rootChildrenMovingAverageN             = new float[Root.NumPolicyMoves];
                rootChildrenMovingAverageValues        = new float[Root.NumPolicyMoves];
                rootChildrenMovingAverageSquaredValues = new float[Root.NumPolicyMoves];
            }

            float v    = evalResult.V * (leafNode.Depth % 2 == 0 ? -1 : 1);
            float diff = v - (float)Root.Ref.Children[childIndex].ChildRef.Q;

            const float C1 = 0.99f;
            const float C2 = 1.0f - C1;
示例#2
0
        /// <summary>
        /// Creates a node store sized to hold at most the specified number of nodes.
        /// </summary>
        /// <param name="maxNodes">maximum number of nodes the store is sized for</param>
        /// <param name="priorMoves">starting position (with history); when null, PositionWithHistory.StartPosition is used</param>
        public MCTSNodeStore(int maxNodes, PositionWithHistory priorMoves = null)
        {
            PositionWithHistory startingMoves = priorMoves ?? PositionWithHistory.StartPosition;

            MaxNodes = maxNodes;

            // Node storage is allocated for exactly maxNodes entries, with allocation
            // behavior controlled by the fixed MCTS storage parameters.
            Nodes = new MCTSNodeStructStorage(maxNodes, null,
                                              MCTSParamsFixed.STORAGE_USE_INCREMENTAL_ALLOC,
                                              MCTSParamsFixed.STORAGE_LARGE_PAGES,
                                              MCTSParamsFixed.STORAGE_USE_EXISTING_SHARED_MEM);

            // Child storage is reserved as a multiple of the node count (computed in 64 bits).
            long reserveChildren = maxNodes * (long)AVG_CHILDREN_PER_NODE;
            Children = new MCTSNodeStructChildStorage(this, reserveChildren);

            // Keep a private copy of the starting moves rather than the caller's instance.
            Nodes.PriorMoves = new PositionWithHistory(startingMoves);

            CeresEnvironment.LogInfo("NodeStore", "Init", $"MCTSNodeStore created with max {maxNodes} nodes, max {reserveChildren} children");

            MCTSNodeStruct.ValidateMCTSNodeStruct();

            // The root occupies index 1.
            RootIndex = new MCTSNodeStructIndex(1);
        }
示例#3
0
 /// <summary>
 /// Initializes enumeration state over the children of the specified node.
 /// </summary>
 /// <param name="store">store holding the node and its children</param>
 /// <param name="nodeIndex">index of the node whose children are enumerated</param>
 /// <param name="maxIndex">last child position to be enumerated</param>
 internal ChildEnumeratorImplState(MCTSNodeStore store, MCTSNodeStructIndex nodeIndex, int maxIndex)
 {
     Store     = store;
     childSpan = store.Children.SpanForNode(nodeIndex);
     nodes     = (MCTSNodeStruct *)store.Nodes.nodes.RawMemory;
     endIndex  = maxIndex;

     // Cursor starts at UInt32.MaxValue, presumably so the first increment wraps it to 0 — confirm against MoveNext.
     index     = UInt32.MaxValue;
 }
示例#4
0
        /// <summary>
        /// Creates a visitor that begins iterating at the given node of the store.
        /// </summary>
        /// <param name="store">node store whose nodes are visited</param>
        /// <param name="root">index of the node at which iteration starts</param>
        public MCTSNodeSequentialVisitor(MCTSNodeStore store, MCTSNodeStructIndex root)
        {
            // Iteration starts at the root itself, with no branches deferred yet.
            pendingBranches = new SortedSet<MCTSNodeStructIndex>();
            currentNode     = root;

            Root  = root;
            Store = store;
        }
示例#5
0
        /// <summary>
        /// Returns the MCTSNode having the specified index if it is currently
        /// stored in the cache, otherwise null.
        /// </summary>
        /// <param name="nodeIndex">index of the node to look up</param>
        /// <returns>the cached node, or null if not currently cached</returns>
        public MCTSNode Lookup(MCTSNodeStructIndex nodeIndex)
        {
            // Use the TryGetValue result directly instead of storing it in an unused local.
            return nodeCache.TryGetValue(nodeIndex.Index, out MCTSNode cachedItem) ? cachedItem : null;
        }
        /// <summary>
        /// Prunes cache down to approximately specified target size.
        /// </summary>
        /// <param name="store"></param>
        /// <param name="targetSize">target numer of nodes, or -1 to use default sizing</param>
        /// <returns></returns>
        internal int Prune(MCTSNodeStore store, int targetSize)
        {
            int startNumInUse = numInUse;

            // Default target size is 70% of maximum.
            if (targetSize == -1)
            {
                targetSize = (nodes.Length * 70) / 100;
            }
            if (numInUse <= targetSize)
            {
                return(0);
            }

            lock (lockObj)
            {
                int count = 0;
                for (int i = 0; i < nodes.Length; i++)
                {
                    // TODO: the long is cast to int, could we possibly overflow? make long?
                    if (nodes[i] != null)
                    {
                        pruneSequenceNums[count++] = (int)nodes[i].LastAccessedSequenceCounter;
                    }
                }

                Span <int> slice = new Span <int>(pruneSequenceNums).Slice(0, count);

                // Compute the minimum sequence number an entry must have
                // to be retained (to enforce LRU eviction)
                //float cutoff = KthSmallestValue.CalcKthSmallestValue(keyPrioritiesForSorting, numToPrune);
                int cutoff;

                cutoff = KthSmallestValueInt.CalcKthSmallestValue(slice, numInUse - targetSize);
                //Console.WriteLine(slice.Length + " " + (numInUse-targetSize) + " --> "
                //                 + cutoff + " correct " + slice[numInUse-targetSize] + " avg " + slice[numInUse/2]);

                int maxEntries = pruneCount == 0 ? (numInUse + 1) : nodes.Length;
                for (int i = 1; i < maxEntries; i++)
                {
                    MCTSNode node = nodes[i];
                    if (node != null && node.LastAccessedSequenceCounter < cutoff)
                    {
                        MCTSNodeStructIndex nodeIndex = new MCTSNodeStructIndex(node.Index);
                        nodes[i] = null;

                        ref MCTSNodeStruct refNode = ref store.Nodes.nodes[nodeIndex.Index];
                        refNode.CacheIndex = 0;

                        numInUse--;
                    }
                }
                pruneCount++;
            }
示例#7
0
        /// <summary>
        /// Creates an MCTSNode wrapper around the raw node stored at the given index
        /// of the context tree's node store.
        /// </summary>
        /// <param name="context">iterator context; its Tree supplies the node store</param>
        /// <param name="index">index of the raw node within the store</param>
        /// <param name="parent">optionally the parent node (may be null)</param>
        internal MCTSNode(MCTSIterator context, MCTSNodeStructIndex index, MCTSNode parent = null)
        {
            Debug.Assert(context.Tree.Store.Nodes != null);
            Debug.Assert(index.Index <= context.Tree.Store.Nodes.MaxNodes);

            Context     = context;
            Tree        = context.Tree;
            this.parent = parent;

            // Capture a raw pointer to this node's slot in the store.
            // NOTE(review): the pointer presumably assumes the node storage is never relocated — confirm.
            Span<MCTSNodeStruct> storeNodes = context.Tree.Store.Nodes.Span;
            ptr = (MCTSNodeStruct *)Unsafe.AsPointer(ref storeNodes[index.Index]);

            this.index = index;
        }
示例#8
0
        private static LeafEvaluationResult ExtractTranspositionNodesFromSubtree(MCTSNode node, ref MCTSNodeStruct transpositionRootNode,
                                                                                 ref int numAlreadyLinked, MCTSNodeTranspositionVisitor linkedVisitor)
        {
            LeafEvaluationResult result = default;

            // Determine how many evaluations we should extract (based on number requested and number available)
            int numAvailable = linkedVisitor.TranspositionRootNWhenVisitsStarted - numAlreadyLinked;
            int numDesired   = node.NInFlight + node.NInFlight2;

            if (numDesired > numAvailable && WARN_COUNT < 10)
            {
                Console.WriteLine(numDesired + " Warning: multiple nodes were requested from the transposition subtree, available " + numAvailable);
                WARN_COUNT++;
            }
            int numToFetch = Math.Min(numDesired, numAvailable);

            Debug.Assert(numToFetch > 0);

            // Extract each evaluation
            for (int i = 0; i < numToFetch; i++)
            {
                MCTSNodeStructIndex transpositionSubnodeIndex = linkedVisitor.Visitor.GetNext();
                Debug.Assert(!transpositionSubnodeIndex.IsNull);

                NumExtractedAndNeverCloned++;
                numAlreadyLinked++;

                // Prepare the result to return
                ref MCTSNodeStruct   transpositionSubnode = ref node.Context.Tree.Store.Nodes.nodes[transpositionSubnodeIndex.Index];
                LeafEvaluationResult thisResult           = new LeafEvaluationResult(transpositionSubnode.Terminal, transpositionRootNode.WinP,
                                                                                     transpositionRootNode.LossP, transpositionRootNode.MPosition);

                // Update our result node to include this node
                result = AddResultToResults(result, numToFetch, i, thisResult);

                if (VERBOSE)
                {
                    Console.WriteLine($"ProcessAlreadyLinked {node.Index} yields {result.WinP} {result.LossP} via linked subnode root {transpositionRootNode.Index.Index} {transpositionRootNode} chose {transpositionSubnode.Index.Index}");
                }

                node.Ref.NumNodesTranspositionExtracted++;
            }
示例#9
0
        /// <summary>
        /// Returns the MCTSNode for the specified node index. The cache is consulted
        /// first (when <paramref name="checkCache"/> is true and a cache exists). On a miss,
        /// a new MCTSNode is constructed and added to the cache (if any). Never returns null.
        /// </summary>
        /// <param name="nodeIndex">index of the node to retrieve</param>
        /// <param name="parent">optional parent node; if null and the node is not the root,
        /// the parent is obtained by a recursive call to this method</param>
        /// <param name="checkCache">if false, the cache is not consulted before constructing a new node</param>
        /// <returns>the MCTSNode for <paramref name="nodeIndex"/></returns>
        public MCTSNode GetNode(MCTSNodeStructIndex nodeIndex, MCTSNode parent = null, bool checkCache = true)
        {
            // The null-conditional lookup matches cache?.Add below, which already tolerates a null cache.
            MCTSNode ret = checkCache ? cache?.Lookup(nodeIndex) : null;

            if (ret is null)
            {
                if (parent is null && !nodeIndex.IsRoot)
                {
                    // Resolve the parent first so the new node can be linked to it.
                    parent = GetNode(nodeIndex.Ref.ParentIndex, null);
                }

                ret = new MCTSNode(Context, nodeIndex, parent);
                cache?.Add(ret);
                NUM_MISSES++;
            }
            else
            {
                NUM_HITS++;

                // Refresh the access stamp (presumably consulted by LRU-style cache pruning).
                ret.LastAccessedSequenceCounter = SEQUENCE_COUNTER++;
            }

            Debug.Assert(ret.Index == nodeIndex.Index);
            return(ret);
        }
示例#10
0
 /// <summary>
 /// Looks up the MCTSNode cached for the given node index.
 /// </summary>
 /// <param name="nodeIndex">index of the node in the store</param>
 /// <returns>the cached node, or null if no node is currently cached for that index</returns>
 public MCTSNode Lookup(MCTSNodeStructIndex nodeIndex)
 {
     int slot = nodeIndex.Index;
     return nodes[slot];
 }
示例#11
0
 /// <summary>
 /// Creates an enumerator over the expanded children of the specified node.
 /// </summary>
 /// <param name="store">store holding the node</param>
 /// <param name="nodeIndex">index of the node whose children are enumerated</param>
 /// <param name="overrideMaxIndex">optional cap on the highest child position enumerated</param>
 public ChildEnumeratorImpl(MCTSNodeStore store, MCTSNodeStructIndex nodeIndex, int overrideMaxIndex = int.MaxValue)
 {
     // Highest valid position is one less than the number of expanded children.
     int lastExpandedChildIndex = store.Nodes.nodes[nodeIndex.Index].NumChildrenExpanded - 1;

     Store     = store;
     NodeIndex = nodeIndex;
     MaxIndex  = overrideMaxIndex < lastExpandedChildIndex ? overrideMaxIndex : lastExpandedChildIndex;
 }
示例#12
0
 /// <summary>
 /// Returns an enumerator over the expanded children of the specified node.
 ///
 /// NOTE: this is not currently used but possibly could/should be used.
 /// </summary>
 /// <param name="nodeIndex">index of the node whose expanded children are enumerated</param>
 /// <param name="overrideMaxIndex">optional cap on the highest child position; int.MaxValue means no extra cap</param>
 /// <returns>a new ChildEnumeratorImpl positioned on the node</returns>
 public ChildEnumeratorImpl ChildrenExpandedEnumerator(MCTSNodeStructIndex nodeIndex, int overrideMaxIndex = int.MaxValue)
 {
     return new ChildEnumeratorImpl(this, nodeIndex, overrideMaxIndex);
 }
 /// <summary>
 /// Returns the MCTSNode having the specified index if it is present in the
 /// sub-cache responsible for that index, otherwise null.
 /// </summary>
 /// <param name="nodeIndex">index of the node to look up</param>
 /// <returns>the cached node, or null if not currently cached</returns>
 public MCTSNode Lookup(MCTSNodeStructIndex nodeIndex)
 {
     // Node indices are distributed across sub-caches by modulo.
     int subCacheIndex = nodeIndex.Index % MAX_SUBCACHES;
     return subCaches[subCacheIndex].Lookup(nodeIndex);
 }
示例#14
0
        /// <summary>
        /// Returns the next branch to be iterated.
        /// </summary>
        /// <param name="thisNodeIndex"></param>
        /// <param name="boundSeqNum"></param>
        /// <returns></returns>
        MCTSNodeStructIndex GetNextBranch(MCTSNodeStructIndex thisNodeIndex, int boundSeqNum)
        {
            int bestIndex1 = -1;
            int bestSeq    = int.MaxValue;

            ref MCTSNodeStruct thisNode = ref Store.Nodes.nodes[thisNodeIndex.Index];