示例#1
0
        /// <summary>
        /// Entry point for the score calculation.
        ///
        /// Given information about the current node and the number of visits to be
        /// made to its subtree, determines how many of those visits should go to each
        /// child. The per-child visit counts end up in
        /// <paramref name="outputChildVisitCounts"/> via the downstream <c>Compute</c> routine.
        /// </summary>
        /// <param name="dualSelectorMode">Forwarded to <c>ParamsSelect.CalcCPUCT</c>.</param>
        /// <param name="paramsSelect">Selection parameters supplying CPUCT, FPU, virtual-loss and UCT exponent settings.</param>
        /// <param name="selectorID">Forwarded to <c>ParamsSelect.CalcCPUCT</c>.</param>
        /// <param name="dynamicVLossBoost">Added to the virtual-loss multiplier when <c>ParamsSelect.VLossRelative</c> is set.</param>
        /// <param name="parentIsRoot">Selects the root-specific (vs. non-root) CPUCT, FPU and UCT exponent settings.</param>
        /// <param name="parentN">Parent visit count; used for CPUCT, the root CPUCT extra multiplier, and passed to <c>Compute</c>.</param>
        /// <param name="parentNInFlight">Passed through to <c>Compute</c>.</param>
        /// <param name="qParent">Parent Q value; base of the relative virtual loss and of the FPU-reduction value.</param>
        /// <param name="parentSumPVisited">Presumably the summed prior of already-visited children; its square root scales the FPU reduction.</param>
        /// <param name="p">Per-child values (presumably priors); must have length MAX_CHILDREN.</param>
        /// <param name="w">Per-child values (presumably accumulated value W); must have length MAX_CHILDREN.</param>
        /// <param name="n">Per-child values (presumably visit counts N); must have length MAX_CHILDREN.</param>
        /// <param name="nInFlight">Per-child values (presumably in-flight visits); must have length MAX_CHILDREN.</param>
        /// <param name="numChildren">Number of children to consider; must not exceed MAX_CHILDREN.</param>
        /// <param name="numVisitsToCompute">Number of visits to distribute among the children.</param>
        /// <param name="outputScores">Optional (may be <c>default</c>); only allowed when at most one visit is computed. Length must be at least <paramref name="numChildren"/>.</param>
        /// <param name="outputChildVisitCounts">Destination for per-child visit counts; length must be at least <paramref name="numChildren"/>.</param>
        public static void ScoreCalcMulti(bool dualSelectorMode, ParamsSelect paramsSelect,
                                          int selectorID, float dynamicVLossBoost,
                                          bool parentIsRoot, float parentN, float parentNInFlight,
                                          float qParent, float parentSumPVisited,
                                          Span<float> p, Span<float> w, Span<float> n, Span<float> nInFlight,
                                          int numChildren, int numVisitsToCompute,
                                          Span<float> outputScores, Span<short> outputChildVisitCounts)
        {
            // Per-child scores can only be captured when no more than one visit is being computed.
            Debug.Assert(outputScores == default || numVisitsToCompute <= 1);

            // Input spans are fixed-size buffers of MAX_CHILDREN; numChildren must fit within them.
            Debug.Assert(p.Length == MAX_CHILDREN);
            Debug.Assert(w.Length == MAX_CHILDREN);
            Debug.Assert(n.Length == MAX_CHILDREN);
            Debug.Assert(nInFlight.Length == MAX_CHILDREN);
            Debug.Assert(numChildren <= MAX_CHILDREN);

            // Output spans must be large enough to hold one entry per child.
            Debug.Assert(outputScores == default || outputScores.Length >= numChildren);
            Debug.Assert(outputChildVisitCounts.Length >= numChildren);

            // Number of 8-wide blocks covering numChildren, rounding up for a partial final block.
            int leftover = numChildren % 8;
            int numBlocks = numChildren / 8 + (leftover != 0 ? 1 : 0);

            // Virtual loss is either relative to the parent's Q (plus any dynamic boost)
            // or a fixed absolute amount, depending on the global VLossRelative switch.
            float virtualLossMultiplier = ParamsSelect.VLossRelative
                ? qParent + paramsSelect.VirtualLossDefaultRelative + dynamicVLossBoost
                : paramsSelect.VirtualLossDefaultAbsolute;

            float cpuctValue = paramsSelect.CalcCPUCT(parentIsRoot, dualSelectorMode, selectorID, parentN);

            // The configured FPU value is negated before use.
            float fpuValue = -paramsSelect.CalcFPUValue(parentIsRoot);

            // Q assigned to children without visits: either an offset from the parent's Q
            // (FPU reduction mode) or the FPU value itself.
            // TODO: to be more precise, parentSumPVisited should possibly be updated as we visit children
            float qWhenNoChildren;
            if (paramsSelect.GetFPUMode(parentIsRoot) == ParamsSelect.FPUType.Reduction)
            {
                qWhenNoChildren = +qParent + fpuValue * MathF.Sqrt(parentSumPVisited);
            }
            else
            {
                qWhenNoChildren = fpuValue;
            }

            // At the root, once parentN exceeds the configured divisor, CPUCT is scaled by
            // (parentN / divisor) ^ exponent. A zero exponent disables the adjustment.
            bool boostRootCpuct = parentIsRoot
                               && parentN > paramsSelect.RootCPUCTExtraMultiplierDivisor
                               && paramsSelect.RootCPUCTExtraMultiplierExponent != 0;
            if (boostRootCpuct)
            {
                float rootScale = MathF.Pow(parentN / paramsSelect.RootCPUCTExtraMultiplierDivisor,
                                            paramsSelect.RootCPUCTExtraMultiplierExponent);
                cpuctValue *= rootScale;
            }

            // UCT numerator/denominator exponents differ between root and non-root nodes.
            var uctNumeratorExponent = parentIsRoot ? paramsSelect.UCTRootNumeratorExponent
                                                    : paramsSelect.UCTNonRootNumeratorExponent;
            var uctDenominatorExponent = parentIsRoot ? paramsSelect.UCTRootDenominatorExponent
                                                      : paramsSelect.UCTNonRootDenominatorExponent;

            Compute(parentN, parentNInFlight, p, w, n, nInFlight, numChildren, numVisitsToCompute, outputScores,
                    outputChildVisitCounts, numBlocks, virtualLossMultiplier, uctNumeratorExponent,
                    cpuctValue, qWhenNoChildren, uctDenominatorExponent);
        }