/// <summary>
/// Creates a leaf selector that operates over the given MCTSIterator.
/// </summary>
/// <param name="context">search context; its ParamsSearch settings size the leaf buffer</param>
/// <param name="selectorID">index of this selector (asserted to be below ILeafSelector.MAX_SELECTORS)</param>
/// <param name="priorSequence">position history stored in PriorSequence</param>
/// <param name="guessNumLeaves">estimated number of leaves per batch, used for buffer capacity</param>
public LeafSelectorMulti(MCTSIterator context, int selectorID, PositionWithHistory priorSequence, int guessNumLeaves)
{
  Debug.Assert(selectorID < ILeafSelector.MAX_SELECTORS);

  if (USE_CUSTOM_THREADPOOL)
  {
    tpm = tpmPool.Value.GetFromPool();
  }

  SelectorID = selectorID;
  PriorSequence = priorSequence;
  paramsExecution = context.ParamsSearch.Execution;

  // Room reserved for root preload nodes (only when preloading is enabled).
  int preloadCapacity = 0;
  if (context.ParamsSearch.Execution.RootPreloadDepth > 0)
  {
    preloadCapacity = MCTSSearchFlow.MAX_PRELOAD_NODES_PER_BATCH;
  }

  // Room reserved for extra nodes requested by padded batch sizing.
  var searchParams = context.ParamsSearch;
  int paddingCapacity = searchParams.PaddedBatchSizing
    ? searchParams.PaddedExtraNodesBase + (int)(searchParams.PaddedExtraNodesMultiplier * guessNumLeaves)
    : 0;

  int leafCapacity = guessNumLeaves + preloadCapacity + paddingCapacity;
  leafs = new ListBounded <MCTSNode>(leafCapacity);

  Context = context;
}
/// <summary>
/// Applies the evaluation results for a list of nodes that came from
/// the given selector and updates the applied-node/batch counters.
/// An empty list is ignored and does not count as a batch.
/// </summary>
/// <param name="selectorID">identifier of the selector the nodes originated from</param>
/// <param name="nodes">nodes whose results should be applied</param>
internal void Apply(int selectorID, ListBounded <MCTSNode> nodes)
{
  if (nodes.Count <= 0)
  {
    return;
  }

  DoApply(selectorID, nodes);

  // Counters are updated only after the apply step has completed.
  int appliedCount = nodes.Count;
  NumNodesApplied += appliedCount;
  NumBatchesApplied++;
  TotalNumNodesApplied += appliedCount;
}
/// <summary>
/// Evaluates a batch of nodes with the underlying evaluator,
/// first recording the batch in the evaluation counters.
/// </summary>
/// <param name="context">context forwarded to Evaluator.BatchGenerate</param>
/// <param name="nodes">batch of nodes to evaluate; the first element is always read, so the list must not be empty</param>
/// <returns>the same list instance that was passed in</returns>
public ListBounded <MCTSNode> Evaluate(MCTSIterator context, ListBounded <MCTSNode> nodes)
{
  // Statistics are recorded on the context of the first node in the batch.
  MCTSIterator statsContext = nodes[0].Context;
  int batchSize = nodes.Count;

  statsContext.NumNNBatches++;
  statsContext.NumNNNodes += batchSize;
  NUM_EVALUATED += batchSize;

  // Nodes sent here are expected to still have a null result in the slot being targeted.
  if (resultTarget == LeafEvaluatorNN.EvalResultTarget.PrimaryEvalResult)
  {
    Debug.Assert(nodes[0].EvalResult.IsNull);
  }
  else if (resultTarget == LeafEvaluatorNN.EvalResultTarget.SecondaryEvalResult)
  {
    Debug.Assert(nodes[0].EvalResultSecondary.IsNull);
  }

  Evaluator.BatchGenerate(context, nodes.AsSpan, resultTarget);

  return nodes;
}
/// <summary>
/// Applies the evaluation results for every node in a batch by running
/// the set-policies step and the backup step. The two steps run side by side
/// (via Parallel.Invoke) when the batch exceeds SetPoliciesNumPoliciesPerThread,
/// otherwise sequentially with set-policies first.
/// </summary>
/// <param name="selectorID">identifier of the originating selector (forwarded to DoApplyBackup)</param>
/// <param name="batchlet">nodes whose results are to be applied</param>
void DoApply(int selectorID, ListBounded <MCTSNode> batchlet)
{
  DebugVerifyNoDuplicatesAndInFlight(batchlet);

  int numNodes = batchlet.Count;
  TOTAL_APPLIED += numNodes;

  if (numNodes == 0)
  {
    return;
  }

  MCTSIterator context = batchlet[0].Context;
  var parallelThreshold = context.ParamsSearch.Execution.SetPoliciesNumPoliciesPerThread;
  bool runInParallel = numNodes > parallelThreshold;

  if (!runInParallel)
  {
    DoApplySetPolicies(batchlet);
    DoApplyBackup(selectorID, batchlet);
  }
  else
  {
    Action setPolicies = () =>
    {
      DoApplySetPolicies(batchlet);
    };

    // The backup step is wrapped in a SearchContextExecutionBlock for the captured context.
    Action backup = () =>
    {
      using (new SearchContextExecutionBlock(context))
      {
        DoApplyBackup(selectorID, batchlet);
      }
    };

    Parallel.Invoke(setPolicies, backup);
  }

  // Legacy variant retained as-is; compiled only when the CRASHES symbol is defined.
#if CRASHES
  // The main two operations to be performed are independently and
  // can possibly be performed in parallel
  const int PARALLEL_THRESHOLD = MCTSParamsFixed.APPLY_NUM_POLICIES_PER_THREAD + (MCTSParamsFixed.APPLY_NUM_POLICIES_PER_THREAD / 2);
  if (false && MCTSParamsFixed.APPLY_PARALLEL_ENABLED && batchlet.Count > PARALLEL_THRESHOLD)
  {
    Parallel.Invoke
    (
      () => DoApplySetPolicies(batchlet),
      () => DoApplyBackup(batchlet)
    );
  }
  else
  {
    DoApplySetPolicies(batchlet); // must go first
    DoApplyBackup(batchlet);
  }

  //foreach (var node in batchlet.Nodes) node.EvalResult = null;
#endif
}
/// <summary>
/// Consistency check over a list of nodes about to be applied.
/// Throws if the same node appears more than once, or if a non-terminal node
/// with action type MCTSApply has zero for both NInFlight and NInFlight2.
/// </summary>
/// <param name="nodes">nodes to verify</param>
/// <exception cref="Exception">thrown when either consistency rule is violated</exception>
public static void DebugVerifyNoDuplicatesAndInFlight(ListBounded <MCTSNode> nodes)
{
  HashSet <MCTSNode> nodesSet = new HashSet <MCTSNode>(nodes.Count);

  foreach (MCTSNode node in nodes)
  {
    // A non-terminal node slated for apply must be marked in flight on at least one counter.
    if (node.ActionType == MCTSNode.NodeActionType.MCTSApply
     && !node.Terminal.IsTerminal()
     && node.NInFlight == 0
     && node.NInFlight2 == 0)
    {
      throw new Exception("Internal error: node was generated but not marked in flight");
    }

    // HashSet.Add returns false when the element is already present,
    // so one lookup both detects the duplicate and records the node.
    if (!nodesSet.Add(node))
    {
      throw new Exception($"Internal error: duplicate node found in Apply { node }");
    }
  }
}
/// <summary>
/// Main search loop: repeatedly selects a batch of leaf nodes, launches their
/// evaluation and applies the results, until the manager's stop status is no longer
/// Continue or the node limit has (nearly) been reached.
///
/// When FlowDirectOverlapped is enabled and the root has enough visits, the evaluation
/// of the current batch is started on a separate task (Task.Run) and the method then waits
/// for and applies the previously launched batch, alternating between two selectors.
/// Otherwise each batch is evaluated and applied synchronously.
/// </summary>
/// <param name="manager">manager passed to the progress callback, evaluation launch,
/// preload, secondary-net and maintenance helpers. NOTE(review): the body also uses the
/// Manager and Context members; presumably these refer to the same search - confirm.</param>
/// <param name="hardLimitNumNodes">limit used to cap each batch (limit minus nodes already
/// applied or in flight); a value of 0 is replaced by 1, and a negative value disables the cap</param>
/// <param name="startingBatchSequenceNum">initial batch sequence number; incremented once per
/// iteration and passed to RunPeriodicMaintenance</param>
/// <param name="forceBatchSize">if it has a value, replaces the computed target batch size
/// (the hard node limit is still applied afterwards)</param>
//
// Algorithm outline (pseudo-code sketch):
//   priorEvaluateTask <- null
//   priorNodesNN <- null
//   do
//   {
//     // Select new nodes
//     newNodes <- Select()
//
//     // Remove any which may have been already selected by alternate selector
//     newNodes <- Deduplicate(newNodes, priorNodesNN)
//
//     // Check for those that can be immediately evaluated and split them out
//     (newNodesNN, newNodesImm) <- TryEvalImmediateAndPartition(newNodes)
//     if (OUT_OF_ORDER_ENABLED) BackupApply(newNodesImm)
//
//     // Launch evaluation of new nodes which need NN evaluation
//     newEvaluateTask <- new Task(Evaluate(newNodesNN))
//
//     // Wait for prior NN evaluation to finish and apply nodes
//     if (priorEvaluateTask != null)
//     {
//       priorNodesNN <- Wait(priorEvaluateTask)
//       BackupApply(priorNodesNN)
//     }
//
//     if (!OUT_OF_ORDER_ENABLED) BackupApply(newNodesImm)
//
//     // Prepare to cycle again
//     priorEvaluateTask <- newEvaluateTask
//   } until (end of search)
//
//   // Finalize last batch
//   priorNodesNN <- Wait(priorEvaluateTask)
//   BackupApply(priorNodesNN)
public void ProcessDirectOverlapped(MCTSManager manager, int hardLimitNumNodes, int startingBatchSequenceNum, int?forceBatchSize)
{
  Debug.Assert(!manager.Root.IsInFlight);

  // A limit of zero is treated as a limit of one node.
  if (hardLimitNumNodes == 0)
  {
    hardLimitNumNodes = 1;
  }

  bool overlappingAllowed = Context.ParamsSearch.Execution.FlowDirectOverlapped;

  // Root visit count at entry; used below to compute how many nodes this call has applied.
  int initialRootN = Context.Root.N;

  int guessMaxNumLeaves = MCTSParamsFixed.MAX_NN_BATCH_SIZE;

  // The second selector exists only when overlapping is allowed; it uses
  // selector ID 1 if FlowDualSelectors is set, otherwise it shares ID 0.
  ILeafSelector selector1;
  ILeafSelector selector2;
  selector1 = new LeafSelectorMulti(Context, 0, Context.StartPosAndPriorMoves, guessMaxNumLeaves);
  int secondSelectorID = Context.ParamsSearch.Execution.FlowDualSelectors ? 1 : 0;
  selector2 = overlappingAllowed ? new LeafSelectorMulti(Context, secondSelectorID, Context.StartPosAndPriorMoves, guessMaxNumLeaves) : null;

  // One selected-node set per selector (two when overlapping is allowed, else one).
  MCTSNodesSelectedSet[] nodesSelectedSets = new MCTSNodesSelectedSet[overlappingAllowed ? 2 : 1];
  for (int i = 0; i < nodesSelectedSets.Length; i++)
  {
    nodesSelectedSets[i] = new MCTSNodesSelectedSet(Context, i == 0 ? (LeafSelectorMulti)selector1 : (LeafSelectorMulti)selector2, guessMaxNumLeaves, guessMaxNumLeaves, BlockApply, Context.ParamsSearch.Execution.InFlightThisBatchLinkageEnabled, Context.ParamsSearch.Execution.InFlightOtherBatchLinkageEnabled);
  }

  int selectorID = 0;
  int batchSequenceNum = startingBatchSequenceNum;

  // Task evaluating the most recently launched overlapped batch (null until the first launch).
  Task <MCTSNodesSelectedSet> overlappingTask = null;
  MCTSNodesSelectedSet pendingOverlappedNodes = null;

  // NOTE(review): assigned below but never read within this method.
  int numOverlappedNodesImmediateApplied = 0;
  int iterationCount = 0;

  // NOTE(review): never referenced again within this method.
  int numSelected = 0;

  // Root N at the time secondary network evaluations were last run.
  int nodesLastSecondaryNetEvaluation = 0;

  while (true)
  {
    // Only start overlapping once the tree is large enough because
    // CPU latency will be very small at small tree sizes,
    // obviating the overlapping benefits of hiding this latency.
    // NOTE(review): the original comment said "past 1000 nodes" but the code
    // tests Root.N > 2000 - confirm which threshold is intended.
    bool overlapThisSet = overlappingAllowed && Context.Root.N > 2000;

    iterationCount++;

    ILeafSelector selector = selectorID == 0 ? selector1 : selector2;

    float thisBatchDynamicVLossBoost = batchingManagers[selectorID].VLossDynamicBoostForSelector();

    // Call progress callback and check if reached search limit
    Context.ProgressCallback?.Invoke(manager);
    Manager.UpdateSearchStopStatus();
    if (Manager.StopStatus != MCTSManager.SearchStopStatus.Continue)
    {
      break;
    }

    int numCurrentlyOverlapped = Context.Root.NInFlight + Context.Root.NInFlight2;

    int numApplied = Context.Root.N - initialRootN;

    int hardLimitNumNodesThisBatch = int.MaxValue;
    if (hardLimitNumNodes > 0)
    {
      // Subtract out number already applied or in flight
      hardLimitNumNodesThisBatch = hardLimitNumNodes - (numApplied + numCurrentlyOverlapped);

      // Stop search if we have already exceeded search limit
      // or if remaining number is very small relative to full search
      // (this avoids incurring latency with a few small batches at end of a search).
      if (hardLimitNumNodesThisBatch <= numApplied / 1000)
      {
        break;
      }
    }

    // Console.WriteLine($"Remap {targetThisBatch} ==> {Context.Root.N} {TargetBatchSize(Context.EstimatedNumSearchNodes, Context.Root.N)}");

    // Determine the target batch size: start from the calculator's value, cap it by the
    // time-exhaustion limit, let forceBatchSize override it, then cap by the hard node limit.
    int targetThisBatch = OptimalBatchSizeCalculator.CalcOptimalBatchSize(Manager.EstimatedNumSearchNodes, Context.Root.N, overlapThisSet, Context.ParamsSearch.Execution.FlowDualSelectors, Context.ParamsSearch.Execution.MaxBatchSize, Context.ParamsSearch.BatchSizeMultiplier);

    targetThisBatch = Math.Min(targetThisBatch, Manager.MaxBatchSizeDueToPossibleNearTimeExhaustion);
    if (forceBatchSize.HasValue)
    {
      targetThisBatch = forceBatchSize.Value;
    }
    if (targetThisBatch > hardLimitNumNodesThisBatch)
    {
      targetThisBatch = hardLimitNumNodesThisBatch;
    }

    // Running total of leaves requested from the selector in this iteration (passed to UpdateStatistics).
    int thisBatchTotalNumLeafsTargeted = 0;

    // Compute number of dynamic nodes to add (do not add any when tree is very small and impure child selection is particularly deleterious)
    int numNodesPadding = 0;
    if (manager.Root.N > 50 && manager.Context.ParamsSearch.PaddedBatchSizing)
    {
      numNodesPadding = manager.Context.ParamsSearch.PaddedExtraNodesBase + (int)(targetThisBatch * manager.Context.ParamsSearch.PaddedExtraNodesMultiplier);
    }

    // Number of visits to attempt: target plus padding, scaled by the selector's dynamic batch size factor.
    int numVisitsTryThisBatch = targetThisBatch + numNodesPadding;
    numVisitsTryThisBatch = (int)(numVisitsTryThisBatch * batchingManagers[selectorID].BatchSizeDynamicScaleForSelector());

    // Select a batch using this selector
    // It will select a set of Leafs completely independent of what a possibly other selector already selected
    // It may find some unevaluated leafs in the tree (extant but N = 0) due to action of the other selector
    // These leafs will nevertheless be recorded but specifically ignored later
    MCTSNodesSelectedSet nodesSelectedSet = nodesSelectedSets[selectorID];
    nodesSelectedSet.Reset(pendingOverlappedNodes);

    // Select the batch of nodes
    if (numVisitsTryThisBatch < 5 || !Context.ParamsSearch.Execution.FlowSplitSelects)
    {
      // Small batch, or split selection disabled: select everything in a single pass.
      thisBatchTotalNumLeafsTargeted += numVisitsTryThisBatch;
      ListBounded <MCTSNode> selectedNodes = selector.SelectNewLeafBatchlet(Context.Root, numVisitsTryThisBatch, thisBatchDynamicVLossBoost);
      nodesSelectedSet.AddSelectedNodes(selectedNodes, true);
    }
    else
    {
      // Set default assumed max batch size
      nodesSelectedSet.MaxNodesNN = numVisitsTryThisBatch;

      // In first attempt try to get 60% of target (at least one node)
      int numTry1 = Math.Max(1, (int)(numVisitsTryThisBatch * 0.60f));
      int numTry2 = (int)(numVisitsTryThisBatch * 0.40f);
      thisBatchTotalNumLeafsTargeted += numTry1;

      ListBounded <MCTSNode> selectedNodes1 = selector.SelectNewLeafBatchlet(Context.Root, numTry1, thisBatchDynamicVLossBoost);
      nodesSelectedSet.AddSelectedNodes(selectedNodes1, true);
      int numGot1 = nodesSelectedSet.NumNewLeafsAddedNonDuplicates;
      nodesSelectedSet.ApplyImmeditateNotYetApplied();

      // Before the second try (which targets the remaining 40%), optionally lower
      // MaxNodesNN to a nearby batch size break for the evaluator.
      if (Context.ParamsSearch.Execution.SmartSizeBatches && Context.EvaluatorDef.NumDevices == 1 && Context.NNEvaluators.PerfStatsPrimary != null) // TODO: somehow handle this for multiple GPUs
      {
        // Prefer breaks from the primary performance stats; otherwise ask the context for the first device.
        int[] optimalBatchSizeBreaks;
        if (Context.NNEvaluators.PerfStatsPrimary.Breaks != null)
        {
          optimalBatchSizeBreaks = Context.NNEvaluators.PerfStatsPrimary.Breaks;
        }
        else
        {
          optimalBatchSizeBreaks = Context.GetOptimalBatchSizeBreaks(Context.EvaluatorDef.DeviceIndices[0]);
        }

        // Make an educated guess about the total number of NN nodes that will be sent
        // to the NN (resulting from both try1 and try2)
        // We base this on the fraction of nodes in try1 which actually are going to NN
        // then discounted by 0.8 because the yield on the second try is typically lower
        const float TRY2_SUCCESS_DISCOUNT_FACTOR = 0.8f;
        float fracNodesFirstTryGoingToNN = (float)nodesSelectedSet.NodesNN.Count / (float)numTry1;
        int estimatedAdditionalNNNodesTry2 = (int)(numTry2 * fracNodesFirstTryGoingToNN * TRY2_SUCCESS_DISCOUNT_FACTOR);

        int estimatedTotalNNNodes = nodesSelectedSet.NodesNN.Count + estimatedAdditionalNNNodesTry2;

        const float NEARBY_BREAK_FRACTION = 0.20f;
        int? closeByBreak = NearbyBreak(optimalBatchSizeBreaks, estimatedTotalNNNodes, NEARBY_BREAK_FRACTION);
        if (closeByBreak is not null)
        {
          nodesSelectedSet.MaxNodesNN = closeByBreak.Value;
        }
      }

      // Only try to collect the second half of the batch if the first one yielded
      // a good fraction of desired nodes (otherwise too many collisions to profitably continue)
      // The second try always proceeds when the first try targeted fewer than 10 nodes.
      const float THRESHOLD_SUCCESS_TRY1 = 0.667f;
      bool shouldProcessTry2 = numTry1 < 10 || ((float)numGot1 / (float)numTry1) >= THRESHOLD_SUCCESS_TRY1;
      if (shouldProcessTry2)
      {
        thisBatchTotalNumLeafsTargeted += numTry2;
        ListBounded <MCTSNode> selectedNodes2 = selector.SelectNewLeafBatchlet(Context.Root, numTry2, thisBatchDynamicVLossBoost);

        // TODO: clean this up
        //  - Note that ideally we might not apply immediate nodes here (i.e. pass false instead of true in next line)
        //  - This is because once done selecting nodes for this batch, we want to get it launched as soon as possible,
        //    we could defer and call ApplyImmeditateNotYetApplied only later (below)
        // *** WARNING*** However, setting this to false causes NInFlight errors (seen when running test matches within 1 or 2 minutes)
        nodesSelectedSet.AddSelectedNodes(selectedNodes2, true); // MUST BE true; see above
      }
    }

    // Possibly pad with "preload nodes"
    if (rootPreloader != null && nodesSelectedSet.NodesNN.Count <= MCTSRootPreloader.PRELOAD_THRESHOLD_BATCH_SIZE)
    {
      // TODO: do we need to update thisBatchTotalNumLeafsTargeted ?
      TryAddRootPreloadNodes(manager, MAX_PRELOAD_NODES_PER_BATCH, nodesSelectedSet, selector);
    }

    // Run secondary network evaluations whenever a secondary network is configured
    // and the root has gained more than 500 visits since the last such run.
    // TODO: make flow private below
    if (Context.EvaluatorDef.SECONDARY_NETWORK_ID != null && (manager.Root.N - nodesLastSecondaryNetEvaluation > 500))
    {
      manager.RunSecondaryNetEvaluations(8, manager.flow.BlockNNEvalSecondaryNet);
      nodesLastSecondaryNetEvaluation = manager.Root.N;
    }

    // Update statistics
    UpdateStatistics(selectorID, thisBatchTotalNumLeafsTargeted, nodesSelectedSet);

    // Convert any excess nodes to CacheOnly
    // (currently not supported: padded batch sizing always throws here)
    if (Context.ParamsSearch.PaddedBatchSizing)
    {
      throw new Exception("Needs remediation");
      // Mark nodes not eligible to be applied as "cache only"
      //for (int i = numApplyThisBatch; i < selectedNodes.Count; i++)
      //  selectedNodes[i].ActionType = MCTSNode.NodeActionType.CacheOnly;
    }

    CeresEnvironment.LogInfo("MCTS", "Batch", $"Batch Target={numVisitsTryThisBatch} " + $"yields NN={nodesSelectedSet.NodesNN.Count} Immediate= {nodesSelectedSet.NodesImmediateNotYetApplied.Count} " + $"[CacheOnly={nodesSelectedSet.NumCacheOnly} None={nodesSelectedSet.NumNotApply}]", manager.InstanceID);

    // Now launch NN evaluation on the non-immediate nodes
    bool isPrimary = selectorID == 0;
    if (overlapThisSet)
    {
      // Remember the previously launched task so it can be awaited after the new one is started.
      Task <MCTSNodesSelectedSet> priorOverlappingTask = overlappingTask;

      numOverlappedNodesImmediateApplied = nodesSelectedSet.NodesImmediateNotYetApplied.Count;

      // Launch a new task to preprocess and evaluate these nodes
      overlappingTask = Task.Run(() => LaunchEvaluate(manager, targetThisBatch, isPrimary, nodesSelectedSet));

      // While that evaluation runs: apply the immediate nodes of this batch,
      // then wait for the prior batch's evaluation and apply it.
      nodesSelectedSet.ApplyImmeditateNotYetApplied();
      pendingOverlappedNodes = nodesSelectedSet;

      WaitEvaluationDoneAndApply(priorOverlappingTask, nodesSelectedSet.NodesNN.Count);
    }
    else
    {
      // Non-overlapped: evaluate on this thread and apply everything immediately.
      LaunchEvaluate(manager, targetThisBatch, isPrimary, nodesSelectedSet);
      nodesSelectedSet.ApplyAll();
      //Console.WriteLine("applied " + selector.Leafs.Count + " " + manager.Root);
    }

    RunPeriodicMaintenance(manager, batchSequenceNum, iterationCount);

    // Advance (rotate) selector
    if (overlappingAllowed)
    {
      selectorID = (selectorID + 1) % 2;
    }

    batchSequenceNum++;
  }

  // Finalize: wait for and apply the last overlapped batch (the task may be null if none was launched).
  WaitEvaluationDoneAndApply(overlappingTask);

  // Debug.Assert(!manager.Root.IsInFlight);

  // Diagnostic (printed at most once, guarded by haveWarned): the root should have
  // nothing in flight after the search; if it does, dump up to 20 in-flight nodes
  // that have no visited children.
  if ((manager.Root.NInFlight != 0 || manager.Root.NInFlight2 != 0) && !haveWarned)
  {
    Console.WriteLine($"Internal error: search ended with N={manager.Root.N} NInFlight={manager.Root.NInFlight} NInFlight2={manager.Root.NInFlight2} " + manager.Root);
    int count = 0;
    manager.Root.Ref.TraverseSequential(manager.Root.Context.Tree.Store, delegate(ref MCTSNodeStruct node, MCTSNodeStructIndex index)
    {
      if (node.IsInFlight && node.NumChildrenVisited == 0 && count++ < 20)
      {
        Console.WriteLine(" " + index.Index + " " + node.Terminal + " " + node.N + " " + node.IsTranspositionLinked + " " + node.NumNodesTranspositionExtracted);
      }
      return(true);
    });
    haveWarned = true;
  }

  // Shut down both selectors (the second may be null when overlapping is not allowed).
  selector1.Shutdown();
  selector2?.Shutdown();
}