/// <summary>
/// Scalar (non-vectorized) child score: an exploitation term W/N plus an
/// exploration bonus cpuct * p * UCTParentMultiplier(N) / (n + 1)^uctDenominatorPower.
/// Presumably a reference counterpart of the vectorized ComputeChildScores path — TODO confirm.
/// </summary>
/// <param name="uctNumeratorPower">Exponent handed to ParamsSelect.UCTParentMultiplier for the parent term.</param>
/// <param name="uctDenominatorPower">Exponent applied to (n + 1); MathF.Pow is skipped when it is exactly 1.</param>
/// <param name="cpuct">Exploration coefficient.</param>
/// <param name="N">Parent visit count passed to ParamsSelect.UCTParentMultiplier.</param>
/// <param name="p">Child prior.</param>
/// <param name="w">Child accumulated value; divided by n to form the exploitation term.</param>
/// <param name="n">Child visit count.</param>
/// <returns>Exploitation term plus exploration term.</returns>
/// <remarks>
/// NOTE(review): when n == 0 the exploitation term is a float division by zero
/// (NaN for w == 0, otherwise ±Infinity); no exception is thrown. Confirm callers never pass n == 0.
/// </remarks>
static float ScoreCalc(float uctNumeratorPower, float uctDenominatorPower, float cpuct, int N, float p, float w, int n)
{
  int visitsPlusOne = n + 1;

  float explorationDivisor;
  if (uctDenominatorPower == 1.0f)
  {
    // Fast path: avoid MathF.Pow when the exponent is exactly one.
    explorationDivisor = visitsPlusOne;
  }
  else
  {
    explorationDivisor = MathF.Pow(visitsPlusOne, uctDenominatorPower);
  }

  float exploitation = w / n;
  float parentFactor = ParamsSelect.UCTParentMultiplier(N, uctNumeratorPower);
  float exploration = (cpuct * p) * (parentFactor / explorationDivisor);

  return exploitation + exploration;
}
/// <summary>
/// Worker method that coordinates looping over all the requested visits,
/// including a performance optimization that attempts to
/// detect the condition where many consecutive visits will be made to the same child.
/// </summary>
/// <remarks>
/// Each loop iteration scores all children via ComputeChildScores, selects the
/// highest-scoring child, and records one visit to it. Recording a visit means
/// incrementing the local parent in-flight count and nInFlight[child], and
/// outputChildVisitCounts[child] when that span is supplied.
/// Right after the first visit, if enough visits remain, batches of
/// NUM_ADDITIONAL_TRY_VISITS_PER_ITERATION visits are speculatively assigned
/// to the same child for as long as it stays the best child, which reduces
/// the number of ComputeChildScores calls.
/// </remarks>
/// <param name="parentN">Visit count of the parent; contributes max(1, parentN - 1) to the parent visit total.</param>
/// <param name="parentNInFlight">Parent in-flight visit count on entry; advanced locally as visits are assigned.</param>
/// <param name="p">Per-child priors, forwarded to ComputeChildScores.</param>
/// <param name="w">Per-child accumulated values, forwarded to ComputeChildScores.</param>
/// <param name="n">Per-child visit counts, forwarded to ComputeChildScores.</param>
/// <param name="nInFlight">Per-child in-flight counts; updated in place for every visit assigned here.</param>
/// <param name="numChildren">Number of leading children considered when picking the best child.</param>
/// <param name="numVisitsToCompute">Total number of visits to distribute among the children.</param>
/// <param name="outputScores">
/// If not default, receives the first numChildren computed scores
/// (Debug builds assert numVisits &lt;= 1 when this is written).
/// </param>
/// <param name="outputChildVisitCounts">If not default, incremented at each child's index for every visit assigned to it.</param>
/// <param name="numBlocks">Forwarded to ComputeChildScores.</param>
/// <param name="virtualLossMultiplier">Broadcast into a Vector256 and forwarded to ComputeChildScores.</param>
/// <param name="uctParentPower">Exponent passed to ParamsSelect.UCTParentMultiplier for the parent visit term.</param>
/// <param name="cpuctValue">Exploration coefficient multiplied into the parent term.</param>
/// <param name="qWhenNoChildren">Forwarded to ComputeChildScores (presumably Q used for unvisited children — TODO confirm).</param>
/// <param name="uctDenominatorPower">Forwarded to ComputeChildScores.</param>
private static void Compute(float parentN, float parentNInFlight,
                            Span<float> p, Span<float> w, Span<float> n, Span<float> nInFlight,
                            int numChildren, int numVisitsToCompute,
                            Span<float> outputScores, Span<short> outputChildVisitCounts,
                            int numBlocks, float virtualLossMultiplier, float uctParentPower,
                            float cpuctValue, float qWhenNoChildren, float uctDenominatorPower)
{
  // Size of each speculative "jump ahead" batch (see below).
  const int NUM_ADDITIONAL_TRY_VISITS_PER_ITERATION = 10;

  // Load the vectors that do not change
  Vector256<float> vVirtualLossMultiplier = Vector256.Create(virtualLossMultiplier);

  // Contribution of parentN to the parent visit total, clamped to at least 1.
  // parentN never changes inside this method, so compute it once.
  float parentNFromCompletedVisits = (parentN < 2) ? 1 : parentN - 1;

  // Make sure ThreadStatics are initialized, and get local copies for efficient access
  float[] localResultAVXScratch = childScoresTempBuffer;
  if (localResultAVXScratch == null)
  {
    InitializedForThread();
    localResultAVXScratch = childScoresTempBuffer;
  }

  int numVisits = 0;
  while (numVisits < numVisitsToCompute)
  {
    // Parent term for the current state (completed + in-flight visits).
    float numVisitsByParentToChildren = parentNInFlight + parentNFromCompletedVisits;
    float cpuctSqrtParentN = cpuctValue * ParamsSelect.UCTParentMultiplier(numVisitsByParentToChildren, uctParentPower);
    ComputeChildScores(p, w, n, nInFlight, numBlocks, qWhenNoChildren, vVirtualLossMultiplier,
                       localResultAVXScratch, cpuctSqrtParentN, uctDenominatorPower);

    // Save back to output scores (if these were requested)
    if (outputScores != default)
    {
      Debug.Assert(numVisits <= 1);
      Span<float> scoresSpan = new Span<float>(localResultAVXScratch).Slice(0, numChildren);
      scoresSpan.CopyTo(outputScores);
    }

    // Find the best child and record this visit
    int maxIndex = ArrayUtils.IndexOfElementWithMaxValue(localResultAVXScratch, numChildren);

    // Update 3 or 4 items to reflect this visit
    parentNInFlight += 1;
    nInFlight[maxIndex] += 1;
    numVisits += 1;
    if (outputChildVisitCounts != default)
    {
      outputChildVisitCounts[maxIndex] += 1;
    }

    int numRemainingVisits = numVisitsToCompute - numVisits;

    // If we just found our first child we repeatedly try to
    // "jump ahead" by NUM_ADDITIONAL_TRY_VISITS_PER_ITERATION visits at a time
    // as long as the top child remains the best child.
    // This optimizes for the common case that one child is dominant,
    // and empirically reduces the number of calls to ComputeChildScores by more than 30%.
    //
    // Note that it would be possible to try this "jump ahead" technique
    // after not only the first visit, but in practice this did not improve performance.
    if (numVisits == 1 && numRemainingVisits > NUM_ADDITIONAL_TRY_VISITS_PER_ITERATION + 5)
    {
      int numSuccessfulVisitsAllIterations = 0;
      do
      {
        // Modify state to simulate additional visits to this top child
        nInFlight[maxIndex] += NUM_ADDITIONAL_TRY_VISITS_PER_ITERATION;

        // The parent must see exactly the same tentative visits as the child:
        // everything already in parentNInFlight, plus earlier successful batches
        // (not yet folded into parentNInFlight), plus this batch.
        // Do not add the child's nInFlight here: it also contains the child's
        // pre-existing in-flight visits and the visit just recorded, which
        // parentNInFlight already accounts for.
        float tentativeParentNInFlight = parentNInFlight
                                       + numSuccessfulVisitsAllIterations
                                       + NUM_ADDITIONAL_TRY_VISITS_PER_ITERATION;

        // Compute new child scores
        numVisitsByParentToChildren = tentativeParentNInFlight + parentNFromCompletedVisits;
        cpuctSqrtParentN = cpuctValue * ParamsSelect.UCTParentMultiplier(numVisitsByParentToChildren, uctParentPower);
        ComputeChildScores(p, w, n, nInFlight, numBlocks, qWhenNoChildren, vVirtualLossMultiplier,
                           localResultAVXScratch, cpuctSqrtParentN, uctDenominatorPower);

        // Check if the best child was still the same
        if (maxIndex == ArrayUtils.IndexOfElementWithMaxValue(localResultAVXScratch, numChildren))
        {
          // Child remained same, increment successful count
          numSuccessfulVisitsAllIterations += NUM_ADDITIONAL_TRY_VISITS_PER_ITERATION;
        }
        else
        {
          // Failed, back out the last update to nInFlight and stop iterating
          nInFlight[maxIndex] -= NUM_ADDITIONAL_TRY_VISITS_PER_ITERATION;
          break;
        }
      } while (numRemainingVisits - numSuccessfulVisitsAllIterations > NUM_ADDITIONAL_TRY_VISITS_PER_ITERATION);

      if (numSuccessfulVisitsAllIterations > 0)
      {
        // The nInFlight have already been kept continuously up to date
        // but need to update the other items to reflect these visits
        parentNInFlight += numSuccessfulVisitsAllIterations;
        numVisits += numSuccessfulVisitsAllIterations;
        if (outputChildVisitCounts != default)
        {
          outputChildVisitCounts[maxIndex] += (short)numSuccessfulVisitsAllIterations;
        }
      }
    }
  }
}