/// <summary>
/// Entry point for the score calculation.
///
/// Given information about the current node and the number of visits to be
/// made to its subtree, derives the selection constants and forwards them,
/// together with the per-child statistics, to <c>Compute</c>. The constants are:
/// - the virtual loss multiplier,
/// - the CPUCT coefficient,
/// - the value used for children with no visits (FPU),
/// - the UCT numerator and denominator exponents.
/// The method itself returns nothing. Results are delivered through the
/// <paramref name="outputChildVisitCounts"/> and optional
/// <paramref name="outputScores"/> spans, presumably written by <c>Compute</c>.
/// </summary>
/// <param name="dualSelectorMode">Passed through to <c>paramsSelect.CalcCPUCT</c>.</param>
/// <param name="paramsSelect">Selection parameters. Supplies the virtual loss, CPUCT, FPU and UCT exponent settings.</param>
/// <param name="selectorID">Passed through to <c>paramsSelect.CalcCPUCT</c>.</param>
/// <param name="dynamicVLossBoost">Added to the relative virtual loss. Used only when <c>ParamsSelect.VLossRelative</c> is set.</param>
/// <param name="parentIsRoot">Selects root vs. non-root CPUCT/FPU/exponent settings. Also enables the root-only CPUCT extra multiplier.</param>
/// <param name="parentN">Parent visit count. Used for CPUCT, for the root CPUCT extra multiplier, and forwarded to <c>Compute</c>.</param>
/// <param name="parentNInFlight">Forwarded unchanged to <c>Compute</c>.</param>
/// <param name="qParent">Parent Q value. It is the base for the relative virtual loss and for the FPU-reduction value.</param>
/// <param name="parentSumPVisited">Enters the FPU-reduction formula through its square root. Presumably the summed prior of already-visited children; confirm with caller.</param>
/// <param name="p">Per-child values; length must equal MAX_CHILDREN (presumably child priors).</param>
/// <param name="w">Per-child values; length must equal MAX_CHILDREN (presumably child W totals).</param>
/// <param name="n">Per-child values; length must equal MAX_CHILDREN (presumably child visit counts).</param>
/// <param name="nInFlight">Per-child values; length must equal MAX_CHILDREN (presumably in-flight visit counts).</param>
/// <param name="numChildren">Number of valid child entries; must not exceed MAX_CHILDREN.</param>
/// <param name="numVisitsToCompute">Number of visits to distribute among the children.</param>
/// <param name="outputScores">Optional (pass <c>default</c> to skip). If supplied, its length must be at least <paramref name="numChildren"/>, and only a single visit may be computed.</param>
/// <param name="outputChildVisitCounts">Receives per-child visit counts; length must be at least <paramref name="numChildren"/>.</param>
public static void ScoreCalcMulti(bool dualSelectorMode, ParamsSelect paramsSelect, int selectorID, float dynamicVLossBoost,
                                  bool parentIsRoot, float parentN, float parentNInFlight,
                                  float qParent, float parentSumPVisited,
                                  Span<float> p, Span<float> w, Span<float> n, Span<float> nInFlight,
                                  int numChildren, int numVisitsToCompute,
                                  Span<float> outputScores, Span<short> outputChildVisitCounts)
{
    // Per-child scores are meaningful only when computing at most one visit.
    Debug.Assert(outputScores == default || numVisitsToCompute <= 1);

    Debug.Assert(p.Length == MAX_CHILDREN);
    Debug.Assert(w.Length == MAX_CHILDREN);
    Debug.Assert(n.Length == MAX_CHILDREN);
    Debug.Assert(nInFlight.Length == MAX_CHILDREN);
    Debug.Assert(numChildren <= MAX_CHILDREN);
    Debug.Assert(outputScores == default || outputScores.Length >= numChildren);
    Debug.Assert(outputChildVisitCounts.Length >= numChildren);

    // Children are handed to Compute in groups of 8. A trailing partial group
    // counts as a full block.
    int blockCount = numChildren / 8 + (numChildren % 8 != 0 ? 1 : 0);

    float virtualLossMultiplier = ParamsSelect.VLossRelative
        ? qParent + paramsSelect.VirtualLossDefaultRelative + dynamicVLossBoost
        : paramsSelect.VirtualLossDefaultAbsolute;

    float cpuctValue = paramsSelect.CalcCPUCT(parentIsRoot, dualSelectorMode, selectorID, parentN);

    // Value assigned to children with no visits. In FPU "Reduction" mode it is an
    // offset from the parent's Q, scaled by sqrt(parentSumPVisited). Otherwise it
    // is the (negated) FPU value itself.
    // TODO: to be more precise, parentSumPVisited should possibly be updated as we visit children
    float negatedFpu = -paramsSelect.CalcFPUValue(parentIsRoot);
    float qWhenNoChildren = paramsSelect.GetFPUMode(parentIsRoot) == ParamsSelect.FPUType.Reduction
        ? qParent + negatedFpu * MathF.Sqrt(parentSumPVisited)
        : negatedFpu;

    // At the root, once the visit count exceeds the configured divisor, CPUCT is
    // scaled by (parentN / divisor) ^ exponent.
    bool applyRootExtraMultiplier = parentIsRoot
                                 && parentN > paramsSelect.RootCPUCTExtraMultiplierDivisor
                                 && paramsSelect.RootCPUCTExtraMultiplierExponent != 0;
    if (applyRootExtraMultiplier)
    {
        float scaledParentN = parentN / paramsSelect.RootCPUCTExtraMultiplierDivisor;
        cpuctValue *= MathF.Pow(scaledParentN, paramsSelect.RootCPUCTExtraMultiplierExponent);
    }

    var uctNumeratorExponent = parentIsRoot ? paramsSelect.UCTRootNumeratorExponent
                                            : paramsSelect.UCTNonRootNumeratorExponent;
    var uctDenominatorExponent = parentIsRoot ? paramsSelect.UCTRootDenominatorExponent
                                              : paramsSelect.UCTNonRootDenominatorExponent;

    Compute(parentN, parentNInFlight, p, w, n, nInFlight, numChildren, numVisitsToCompute,
            outputScores, outputChildVisitCounts, blockCount, virtualLossMultiplier,
            uctNumeratorExponent, cpuctValue, qWhenNoChildren, uctDenominatorExponent);
}