Beispiel #1
0
        /// <summary>
        /// Creates a leaf selector bound to the given MCTSIterator and
        /// allocates the bounded buffer that will hold the selected leafs.
        /// </summary>
        /// <param name="context">Search iterator; its ParamsSearch settings size the leaf buffer and it is stored in Context.</param>
        /// <param name="selectorID">Identifier of this selector; debug-asserted to be below ILeafSelector.MAX_SELECTORS.</param>
        /// <param name="priorSequence">Position history stored in PriorSequence.</param>
        /// <param name="guessNumLeaves">Estimated number of leafs per batch, used as the base capacity of the leaf buffer.</param>
        public LeafSelectorMulti(MCTSIterator context, int selectorID, PositionWithHistory priorSequence, int guessNumLeaves)
        {
            Debug.Assert(selectorID < ILeafSelector.MAX_SELECTORS);

            if (USE_CUSTOM_THREADPOOL)
            {
                tpm = tpmPool.Value.GetFromPool();
            }

            SelectorID = selectorID;
            PriorSequence = priorSequence;
            paramsExecution = context.ParamsSearch.Execution;

            // Buffer capacity starts at the caller's estimate ...
            int leafCapacity = guessNumLeaves;

            // ... plus room for root preload nodes when preloading is enabled ...
            if (context.ParamsSearch.Execution.RootPreloadDepth > 0)
            {
                leafCapacity += MCTSSearchFlow.MAX_PRELOAD_NODES_PER_BATCH;
            }

            // ... plus the extra padded nodes when padded batch sizing is enabled.
            if (context.ParamsSearch.PaddedBatchSizing)
            {
                leafCapacity += context.ParamsSearch.PaddedExtraNodesBase;
                leafCapacity += (int)(context.ParamsSearch.PaddedExtraNodesMultiplier * guessNumLeaves);
            }

            leafs = new ListBounded <MCTSNode>(leafCapacity);

            Context = context;
        }
Beispiel #2
0
        /// <summary>
        /// Applies the results for a set of nodes that originated from the
        /// specified selector and updates the applied-node/batch counters.
        /// An empty set is ignored (nothing is applied or counted).
        /// </summary>
        /// <param name="selectorID">Identifier of the selector which produced the nodes; forwarded to DoApply.</param>
        /// <param name="nodes">Nodes whose results are to be applied.</param>
        internal void Apply(int selectorID, ListBounded <MCTSNode> nodes)
        {
            // Nothing to apply and no batch to count.
            if (nodes.Count <= 0)
            {
                return;
            }

            DoApply(selectorID, nodes);

            // Count is read after DoApply, as in the original accounting.
            int numNodesThisBatch = nodes.Count;

            NumBatchesApplied++;
            NumNodesApplied += numNodesThisBatch;
            TotalNumNodesApplied += numNodesThisBatch;
        }
Beispiel #3
0
        /// <summary>
        /// Sends a batch of nodes to the evaluator and updates the
        /// batch/node counters on the search context of the nodes.
        /// </summary>
        /// <param name="context">Search iterator passed through to Evaluator.BatchGenerate.</param>
        /// <param name="nodes">Nodes to be evaluated; an empty list is returned unchanged without touching any counters.</param>
        /// <returns>The same list instance that was passed in.</returns>
        public ListBounded <MCTSNode> Evaluate(MCTSIterator context, ListBounded <MCTSNode> nodes)
        {
            // An empty batch has no first node from which to take the context
            // and nothing to evaluate, so there is nothing to do.
            if (nodes.Count == 0)
            {
                return nodes;
            }

            // Statistics are kept on the context of the nodes themselves
            // (taken from the first node), not on the context argument.
            MCTSIterator nodesContext = nodes[0].Context;
            nodesContext.NumNNBatches++;
            nodesContext.NumNNNodes += nodes.Count;

            NUM_EVALUATED += nodes.Count;

            // Sanity check (debug builds only): the result slot being targeted
            // on the first node is expected to still be null.
            if (resultTarget == LeafEvaluatorNN.EvalResultTarget.PrimaryEvalResult)
            {
                Debug.Assert(nodes[0].EvalResult.IsNull); // null evaluator indicates should have been sent here
            }
            else if (resultTarget == LeafEvaluatorNN.EvalResultTarget.SecondaryEvalResult)
            {
                Debug.Assert(nodes[0].EvalResultSecondary.IsNull); // null evaluator indicates should have been sent here
            }

            Evaluator.BatchGenerate(context, nodes.AsSpan, resultTarget);

            return nodes;
        }
Beispiel #4
0
        /// <summary>
        /// Coordinates (possibly parallelized) application of
        /// evaluation results for all nodes in a specified batch.
        /// </summary>
        /// <param name="selectorID">Identifier of the selector which produced the batch; forwarded to DoApplyBackup.</param>
        /// <param name="batchlet">Nodes whose evaluation results are to be applied; an empty batch is a no-op.</param>
        void DoApply(int selectorID, ListBounded <MCTSNode> batchlet)
        {
            DebugVerifyNoDuplicatesAndInFlight(batchlet);

            TOTAL_APPLIED += batchlet.Count;
            if (batchlet.Count == 0)
            {
                return;
            }

            // The search context is taken from the first node of the batch.
            MCTSIterator context = batchlet[0].Context;

            if (batchlet.Count > context.ParamsSearch.Execution.SetPoliciesNumPoliciesPerThread)
            {
                // Batch is large enough: run policy setting and backup concurrently.
                // The backup delegate establishes its own SearchContextExecutionBlock
                // for the duration of the call.
                Parallel.Invoke(
                    () => { DoApplySetPolicies(batchlet); },
                    () => { using (new SearchContextExecutionBlock(context)) DoApplyBackup(selectorID, batchlet); });
            }
            else
            {
                // Small batch: run sequentially, policies first and then backup.
                DoApplySetPolicies(batchlet);
                DoApplyBackup(selectorID, batchlet);
            }
        }
Beispiel #5
0
        /// <summary>
        /// Diagnostic consistency check over a batch of nodes.
        /// Throws if any node with action type MCTSApply is non-terminal
        /// yet has zero NInFlight and zero NInFlight2, or if the same node
        /// appears more than once in the batch.
        /// </summary>
        /// <param name="nodes">Batch of nodes to verify.</param>
        /// <exception cref="Exception">A node is not marked in flight, or a duplicate node was found.</exception>
        public static void DebugVerifyNoDuplicatesAndInFlight(ListBounded <MCTSNode> nodes)
        {
            HashSet <MCTSNode> nodesSet = new HashSet <MCTSNode>(nodes.Count);

            foreach (MCTSNode node in nodes)
            {
                if (node.ActionType == MCTSNode.NodeActionType.MCTSApply &&
                    !node.Terminal.IsTerminal() &&
                    node.NInFlight == 0 &&
                    node.NInFlight2 == 0)
                {
                    throw new Exception("Internal error: node was generated but not marked in flight");
                }

                // HashSet.Add returns false when the element is already present,
                // so a single call both records the node and detects duplicates.
                if (!nodesSet.Add(node))
                {
                    throw new Exception($"Internal error: duplicate node found in Apply  { node }");
                }
            }
        }
Beispiel #6
0
        // Algorithm (pseudocode outline of ProcessDirectOverlapped below)
        //   priorEvaluateTask <- null
        //   priorNodesNN <- null
        //   do
        //   {
        //     // Select new nodes
        //     newNodes <- Select()
        //
        //     // Remove any which may have been already selected by alternate selector
        //     newNodes <- Deduplicate(newNodes, priorNodesNN)
        //
        //     // Check for those that can be immediately evaluated and split them out
        //     (newNodesNN, newNodesImm) <- TryEvalImmediateAndPartition(newNodes)
        //     if (OUT_OF_ORDER_ENABLED) BackupApply(newNodesImm)
        //
        //     // Launch evaluation of new nodes which need NN evaluation
        //     newEvaluateTask <- new Task(Evaluate(newNodesNN))
        //
        //     // Wait for prior NN evaluation to finish and apply nodes
        //     if (priorEvaluateTask != null)
        //     {
        //       priorNodesNN <- Wait(priorEvaluateTask)
        //       BackupApply(priorNodesNN)
        //     }
        //
        //     if (!OUT_OF_ORDER_ENABLED) BackupApply(newNodesImm)
        //
        //     // Prepare to cycle again
        //     priorEvaluateTask <- newEvaluateTask
        //   } until (end of search)
        //
        //   // Finalize last batch
        //   priorNodesNN <- Wait(priorEvaluateTask)
        //   BackupApply(priorNodesNN)


        /// <summary>
        /// Runs the main search loop: repeatedly selects a batch of leaf nodes,
        /// launches their evaluation and applies the results, until the stop
        /// status is no longer Continue or the node limit is (nearly) reached.
        /// When overlapping is enabled and the root has more than 2000 visits,
        /// evaluation of the current batch is started on a separate task and
        /// the previously launched batch is then awaited and applied.
        /// </summary>
        /// <param name="manager">Search manager; its Root and Context are consulted and it is passed to the
        /// progress callback, LaunchEvaluate, TryAddRootPreloadNodes and RunPeriodicMaintenance.</param>
        /// <param name="hardLimitNumNodes">Limit on the number of nodes (applied plus in flight) for this call.
        /// A value of 0 is treated as 1; a negative value skips the limit check.</param>
        /// <param name="startingBatchSequenceNum">Initial value of the batch sequence number passed to RunPeriodicMaintenance.</param>
        /// <param name="forceBatchSize">If it has a value, replaces the computed target batch size
        /// (the result is still capped by the remaining hard limit).</param>
        public void ProcessDirectOverlapped(MCTSManager manager, int hardLimitNumNodes, int startingBatchSequenceNum, int?forceBatchSize)
        {
            Debug.Assert(!manager.Root.IsInFlight);
            // A limit of zero is bumped to one node.
            if (hardLimitNumNodes == 0)
            {
                hardLimitNumNodes = 1;
            }

            bool overlappingAllowed = Context.ParamsSearch.Execution.FlowDirectOverlapped;
            // Root visit count at entry; used to compute how many nodes this call has applied.
            int  initialRootN       = Context.Root.N;

            int guessMaxNumLeaves = MCTSParamsFixed.MAX_NN_BATCH_SIZE;

            ILeafSelector selector1;
            ILeafSelector selector2;

            selector1 = new LeafSelectorMulti(Context, 0, Context.StartPosAndPriorMoves, guessMaxNumLeaves);
            // The second selector gets ID 1 only with dual selectors; otherwise it shares ID 0.
            int secondSelectorID = Context.ParamsSearch.Execution.FlowDualSelectors ? 1 : 0;

            // A second selector exists only when overlapping is allowed.
            selector2 = overlappingAllowed ? new LeafSelectorMulti(Context, secondSelectorID, Context.StartPosAndPriorMoves, guessMaxNumLeaves) : null;

            // One selected-nodes set per selector (index 0 -> selector1, index 1 -> selector2).
            MCTSNodesSelectedSet[] nodesSelectedSets = new MCTSNodesSelectedSet[overlappingAllowed ? 2 : 1];
            for (int i = 0; i < nodesSelectedSets.Length; i++)
            {
                nodesSelectedSets[i] = new MCTSNodesSelectedSet(Context,
                                                                i == 0 ? (LeafSelectorMulti)selector1
                                                               : (LeafSelectorMulti)selector2,
                                                                guessMaxNumLeaves, guessMaxNumLeaves, BlockApply,
                                                                Context.ParamsSearch.Execution.InFlightThisBatchLinkageEnabled,
                                                                Context.ParamsSearch.Execution.InFlightOtherBatchLinkageEnabled);
            }

            // Index of the selector (and selected-nodes set) used in the current iteration.
            int selectorID       = 0;
            int batchSequenceNum = startingBatchSequenceNum;

            // Task evaluating the most recently launched overlapped batch (null if none yet).
            Task <MCTSNodesSelectedSet> overlappingTask        = null;
            // Set launched in the prior overlapped iteration; passed to Reset of the next set.
            MCTSNodesSelectedSet        pendingOverlappedNodes = null;
            // NOTE(review): assigned below but never read within this method.
            int numOverlappedNodesImmediateApplied             = 0;

            int iterationCount = 0;
            // NOTE(review): never read or updated within this method.
            int numSelected    = 0;
            // Root N at the time secondary network evaluations were last run.
            int nodesLastSecondaryNetEvaluation = 0;

            while (true)
            {
                // Only start overlapping past 2000 nodes because
                // CPU latency will be very small at small tree sizes,
                // obviating the overlapping benefits of hiding this latency
                bool overlapThisSet = overlappingAllowed && Context.Root.N > 2000;

                iterationCount++;

                ILeafSelector selector = selectorID == 0 ? selector1 : selector2;

                float thisBatchDynamicVLossBoost = batchingManagers[selectorID].VLossDynamicBoostForSelector();

                // Call progress callback and check if reached search limit
                Context.ProgressCallback?.Invoke(manager);
                Manager.UpdateSearchStopStatus();
                if (Manager.StopStatus != MCTSManager.SearchStopStatus.Continue)
                {
                    break;
                }

                // Visits currently in flight at the root (across both in-flight counters).
                int numCurrentlyOverlapped = Context.Root.NInFlight + Context.Root.NInFlight2;

                int numApplied = Context.Root.N - initialRootN;
                int hardLimitNumNodesThisBatch = int.MaxValue;
                if (hardLimitNumNodes > 0)
                {
                    // Subtract out number already applied or in flight
                    hardLimitNumNodesThisBatch = hardLimitNumNodes - (numApplied + numCurrentlyOverlapped);

                    // Stop search if we have already exceeded search limit
                    // or if remaining number is very small relative to full search
                    // (this avoids incurring latency with a few small batches at end of a search).
                    // Note the threshold uses integer division (numApplied / 1000).
                    if (hardLimitNumNodesThisBatch <= numApplied / 1000)
                    {
                        break;
                    }
                }

                //          Console.WriteLine($"Remap {targetThisBatch} ==> {Context.Root.N} {TargetBatchSize(Context.EstimatedNumSearchNodes, Context.Root.N)}");
                int targetThisBatch = OptimalBatchSizeCalculator.CalcOptimalBatchSize(Manager.EstimatedNumSearchNodes, Context.Root.N,
                                                                                      overlapThisSet,
                                                                                      Context.ParamsSearch.Execution.FlowDualSelectors,
                                                                                      Context.ParamsSearch.Execution.MaxBatchSize,
                                                                                      Context.ParamsSearch.BatchSizeMultiplier);

                // Cap the target by the time-exhaustion limit, then apply any forced
                // batch size, and finally cap by the remaining hard node limit.
                targetThisBatch = Math.Min(targetThisBatch, Manager.MaxBatchSizeDueToPossibleNearTimeExhaustion);
                if (forceBatchSize.HasValue)
                {
                    targetThisBatch = forceBatchSize.Value;
                }
                if (targetThisBatch > hardLimitNumNodesThisBatch)
                {
                    targetThisBatch = hardLimitNumNodesThisBatch;
                }

                // Running total of leafs requested from the selector in this iteration.
                int thisBatchTotalNumLeafsTargeted = 0;

                // Compute number of dynamic nodes to add (do not add any when tree is very small and impure child selection is particularly deleterious)
                int numNodesPadding = 0;
                if (manager.Root.N > 50 && manager.Context.ParamsSearch.PaddedBatchSizing)
                {
                    numNodesPadding = manager.Context.ParamsSearch.PaddedExtraNodesBase
                                      + (int)(targetThisBatch * manager.Context.ParamsSearch.PaddedExtraNodesMultiplier);
                }
                int numVisitsTryThisBatch = targetThisBatch + numNodesPadding;

                // Scale the number of visits to attempt by the per-selector dynamic factor.
                numVisitsTryThisBatch = (int)(numVisitsTryThisBatch * batchingManagers[selectorID].BatchSizeDynamicScaleForSelector());

                // Select a batch using this selector
                // It will select a set of Leafs completely independent of what a possibly other selector already selected
                // It may find some unevaluated leafs in the tree (extant but N = 0) due to action of the other selector
                // These leafs will nevertheless be recorded but specifically ignored later
                MCTSNodesSelectedSet nodesSelectedSet = nodesSelectedSets[selectorID];
                nodesSelectedSet.Reset(pendingOverlappedNodes);

                // Select the batch of nodes
                // (in a single pass if the batch is tiny or split selects are disabled)
                if (numVisitsTryThisBatch < 5 || !Context.ParamsSearch.Execution.FlowSplitSelects)
                {
                    thisBatchTotalNumLeafsTargeted += numVisitsTryThisBatch;
                    ListBounded <MCTSNode> selectedNodes = selector.SelectNewLeafBatchlet(Context.Root, numVisitsTryThisBatch, thisBatchDynamicVLossBoost);
                    nodesSelectedSet.AddSelectedNodes(selectedNodes, true);
                }
                else
                {
                    // Set default assumed max batch size
                    nodesSelectedSet.MaxNodesNN = numVisitsTryThisBatch;

                    // In first attempt try to get 60% of target
                    // (at least 1, so numTry1 is safe to use as a divisor below)
                    int numTry1 = Math.Max(1, (int)(numVisitsTryThisBatch * 0.60f));
                    int numTry2 = (int)(numVisitsTryThisBatch * 0.40f);
                    thisBatchTotalNumLeafsTargeted += numTry1;

                    ListBounded <MCTSNode> selectedNodes1 = selector.SelectNewLeafBatchlet(Context.Root, numTry1, thisBatchDynamicVLossBoost);
                    nodesSelectedSet.AddSelectedNodes(selectedNodes1, true);
                    int numGot1 = nodesSelectedSet.NumNewLeafsAddedNonDuplicates;
                    nodesSelectedSet.ApplyImmeditateNotYetApplied();

                    // In second try target remaining 40%
                    // Before that, possibly adjust the maximum number of NN nodes
                    // to a nearby optimal batch size break.
                    if (Context.ParamsSearch.Execution.SmartSizeBatches &&
                        Context.EvaluatorDef.NumDevices == 1 &&
                        Context.NNEvaluators.PerfStatsPrimary != null) // TODO: somehow handle this for multiple GPUs
                    {
                        // Prefer breaks from the primary performance statistics,
                        // otherwise fall back to the breaks for the first device.
                        int[] optimalBatchSizeBreaks;
                        if (Context.NNEvaluators.PerfStatsPrimary.Breaks != null)
                        {
                            optimalBatchSizeBreaks = Context.NNEvaluators.PerfStatsPrimary.Breaks;
                        }
                        else
                        {
                            optimalBatchSizeBreaks = Context.GetOptimalBatchSizeBreaks(Context.EvaluatorDef.DeviceIndices[0]);
                        }

                        // Make an educated guess about the total number of NN nodes that will be sent
                        // to the NN (resulting from both try1 and try2)
                        // We base this on the fraction of nodes in try1 which actually are going to NN
                        // then discounted by 0.8 because the yield on the second try is typically lower
                        const float TRY2_SUCCESS_DISCOUNT_FACTOR   = 0.8f;
                        float       fracNodesFirstTryGoingToNN     = (float)nodesSelectedSet.NodesNN.Count / (float)numTry1;
                        int         estimatedAdditionalNNNodesTry2 = (int)(numTry2 * fracNodesFirstTryGoingToNN * TRY2_SUCCESS_DISCOUNT_FACTOR);

                        int estimatedTotalNNNodes = nodesSelectedSet.NodesNN.Count + estimatedAdditionalNNNodesTry2;

                        // If NearbyBreak reports a break close to the estimate,
                        // use it as the maximum number of NN nodes for this set.
                        const float NEARBY_BREAK_FRACTION = 0.20f;
                        int?        closeByBreak          = NearbyBreak(optimalBatchSizeBreaks, estimatedTotalNNNodes, NEARBY_BREAK_FRACTION);
                        if (closeByBreak is not null)
                        {
                            nodesSelectedSet.MaxNodesNN = closeByBreak.Value;
                        }
                    }

                    // Only try to collect the second half of the batch if the first one yielded
                    // a good fraction of desired nodes (otherwise too many collisions to profitably continue)
                    // (the second try is always made when the first try requested fewer than 10 nodes)
                    const float THRESHOLD_SUCCESS_TRY1 = 0.667f;
                    bool        shouldProcessTry2      = numTry1 < 10 || ((float)numGot1 / (float)numTry1) >= THRESHOLD_SUCCESS_TRY1;
                    if (shouldProcessTry2)
                    {
                        thisBatchTotalNumLeafsTargeted += numTry2;
                        ListBounded <MCTSNode> selectedNodes2 = selector.SelectNewLeafBatchlet(Context.Root, numTry2, thisBatchDynamicVLossBoost);

                        // TODO: clean this up
                        //  - Note that ideally we might not apply immediate nodes here (i.e. pass false instead of true in next line)
                        //  - This is because once done selecting nodes for this batch, we want to get it launched as soon as possible,
                        //    we could defer and call ApplyImmeditateNotYetApplied only later (below)
                        // *** WARNING*** However, setting this to false causes NInFlight errors (seen when running test matches within 1 or 2 minutes)
                        nodesSelectedSet.AddSelectedNodes(selectedNodes2, true); // MUST BE true; see above
                    }
                }

                // Possibly pad with "preload nodes"
                // (only when a root preloader exists and the NN batch is at or below the preload threshold)
                if (rootPreloader != null && nodesSelectedSet.NodesNN.Count <= MCTSRootPreloader.PRELOAD_THRESHOLD_BATCH_SIZE)
                {
                    // TODO: do we need to update thisBatchTotalNumLeafsTargeted ?
                    TryAddRootPreloadNodes(manager, MAX_PRELOAD_NODES_PER_BATCH, nodesSelectedSet, selector);
                }

                // TODO: make flow private below
                // Run secondary network evaluations whenever a secondary network is configured
                // and the root has gained more than 500 visits since they were last run.
                if (Context.EvaluatorDef.SECONDARY_NETWORK_ID != null && (manager.Root.N - nodesLastSecondaryNetEvaluation > 500))
                {
                    manager.RunSecondaryNetEvaluations(8, manager.flow.BlockNNEvalSecondaryNet);
                    nodesLastSecondaryNetEvaluation = manager.Root.N;
                }

                // Update statistics
                UpdateStatistics(selectorID, thisBatchTotalNumLeafsTargeted, nodesSelectedSet);

                // Convert any excess nodes to CacheOnly
                // NOTE(review): as written, this method always throws here when
                // PaddedBatchSizing is enabled; the conversion itself is not implemented.
                if (Context.ParamsSearch.PaddedBatchSizing)
                {
                    throw new Exception("Needs remediation");
                    // Mark nodes not eligible to be applied as "cache only"
                    //for (int i = numApplyThisBatch; i < selectedNodes.Count; i++)
                    //  selectedNodes[i].ActionType = MCTSNode.NodeActionType.CacheOnly;
                }

                CeresEnvironment.LogInfo("MCTS", "Batch", $"Batch Target={numVisitsTryThisBatch} "
                                         + $"yields NN={nodesSelectedSet.NodesNN.Count} Immediate= {nodesSelectedSet.NodesImmediateNotYetApplied.Count} "
                                         + $"[CacheOnly={nodesSelectedSet.NumCacheOnly} None={nodesSelectedSet.NumNotApply}]", manager.InstanceID);

                // Now launch NN evaluation on the non-immediate nodes
                bool isPrimary = selectorID == 0;
                if (overlapThisSet)
                {
                    // Remember the task launched in the prior iteration so it can be awaited below.
                    Task <MCTSNodesSelectedSet> priorOverlappingTask = overlappingTask;

                    numOverlappedNodesImmediateApplied = nodesSelectedSet.NodesImmediateNotYetApplied.Count;

                    // Launch a new task to preprocess and evaluate these nodes
                    overlappingTask = Task.Run(() => LaunchEvaluate(manager, targetThisBatch, isPrimary, nodesSelectedSet));
                    // While that task runs, apply this set's immediate nodes on the current thread.
                    nodesSelectedSet.ApplyImmeditateNotYetApplied();
                    pendingOverlappedNodes = nodesSelectedSet;

                    // Wait for the previously launched batch (if any) and apply it.
                    WaitEvaluationDoneAndApply(priorOverlappingTask, nodesSelectedSet.NodesNN.Count);
                }
                else
                {
                    // Non-overlapped: evaluate synchronously and apply everything right away.
                    LaunchEvaluate(manager, targetThisBatch, isPrimary, nodesSelectedSet);
                    nodesSelectedSet.ApplyAll();
                    //Console.WriteLine("applied " + selector.Leafs.Count + " " + manager.Root);
                }

                RunPeriodicMaintenance(manager, batchSequenceNum, iterationCount);

                // Advance (rotate) selector
                if (overlappingAllowed)
                {
                    selectorID = (selectorID + 1) % 2;
                }
                batchSequenceNum++;
            }

            // Finalize the last overlapped batch (overlappingTask may be null if none was launched).
            WaitEvaluationDoneAndApply(overlappingTask);

            //      Debug.Assert(!manager.Root.IsInFlight);

            // Diagnostic: report (once, guarded by haveWarned) if the root still shows in-flight visits
            // and list up to 20 in-flight nodes that have no visited children.
            if ((manager.Root.NInFlight != 0 || manager.Root.NInFlight2 != 0) && !haveWarned)
            {
                Console.WriteLine($"Internal error: search ended with N={manager.Root.N} NInFlight={manager.Root.NInFlight} NInFlight2={manager.Root.NInFlight2} " + manager.Root);
                int count = 0;
                manager.Root.Ref.TraverseSequential(manager.Root.Context.Tree.Store, delegate(ref MCTSNodeStruct node, MCTSNodeStructIndex index)
                {
                    if (node.IsInFlight && node.NumChildrenVisited == 0 && count++ < 20)
                    {
                        Console.WriteLine("  " + index.Index + " " + node.Terminal + " " + node.N + " " + node.IsTranspositionLinked + " " + node.NumNodesTranspositionExtracted);
                    }
                    return(true);
                });
                haveWarned = true;
            }

            // NOTE(review): these calls are not in a finally block, so they are skipped
            // if an exception propagates out of the loop above.
            selector1.Shutdown();
            selector2?.Shutdown();
        }